rahul7star committed on
Commit 05fcd0f · verified · 1 Parent(s): f3c56e6

Migrated from GitHub

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. .dockerignore +6 -0
  2. Dockerfile +28 -0
  3. LICENSE +201 -0
  4. ORIGINAL_README.md +83 -0
  5. diffusers_helper/bucket_tools.py +97 -0
  6. diffusers_helper/clip_vision.py +12 -0
  7. diffusers_helper/dit_common.py +53 -0
  8. diffusers_helper/gradio/progress_bar.py +86 -0
  9. diffusers_helper/hf_login.py +21 -0
  10. diffusers_helper/hunyuan.py +163 -0
  11. diffusers_helper/k_diffusion/uni_pc_fm.py +144 -0
  12. diffusers_helper/k_diffusion/wrapper.py +51 -0
  13. diffusers_helper/lora_utils.py +194 -0
  14. diffusers_helper/memory.py +134 -0
  15. diffusers_helper/models/hunyuan_video_packed.py +1062 -0
  16. diffusers_helper/models/mag_cache.py +219 -0
  17. diffusers_helper/models/mag_cache_ratios.py +71 -0
  18. diffusers_helper/pipelines/k_diffusion_hunyuan.py +120 -0
  19. diffusers_helper/thread_utils.py +76 -0
  20. diffusers_helper/utils.py +613 -0
  21. docker-compose.yml +25 -0
  22. install.bat +208 -0
  23. modules/__init__.py +4 -0
  24. modules/generators/__init__.py +32 -0
  25. modules/generators/base_generator.py +281 -0
  26. modules/generators/f1_generator.py +235 -0
  27. modules/generators/original_generator.py +213 -0
  28. modules/generators/original_with_endframe_generator.py +15 -0
  29. modules/generators/video_base_generator.py +613 -0
  30. modules/generators/video_f1_generator.py +189 -0
  31. modules/generators/video_generator.py +239 -0
  32. modules/grid_builder.py +78 -0
  33. modules/interface.py +0 -0
  34. modules/llm_captioner.py +66 -0
  35. modules/llm_enhancer.py +191 -0
  36. modules/pipelines/__init__.py +45 -0
  37. modules/pipelines/base_pipeline.py +85 -0
  38. modules/pipelines/f1_pipeline.py +140 -0
  39. modules/pipelines/metadata_utils.py +329 -0
  40. modules/pipelines/original_pipeline.py +138 -0
  41. modules/pipelines/original_with_endframe_pipeline.py +157 -0
  42. modules/pipelines/video_f1_pipeline.py +143 -0
  43. modules/pipelines/video_pipeline.py +143 -0
  44. modules/pipelines/video_tools.py +57 -0
  45. modules/pipelines/worker.py +1150 -0
  46. modules/prompt_handler.py +164 -0
  47. modules/settings.py +88 -0
  48. modules/toolbox/RIFE/IFNet_HDv3.py +136 -0
  49. modules/toolbox/RIFE/RIFE_HDv3.py +98 -0
  50. modules/toolbox/RIFE/__int__.py +0 -0
.dockerignore ADDED
@@ -0,0 +1,6 @@
+ hf_download/
+ .framepack/
+ loras/
+ outputs/
+ modules/toolbox/model_esrgan/
+ modules/toolbox/model_rife/
Dockerfile ADDED
@@ -0,0 +1,28 @@
+ ARG CUDA_VERSION=12.4.1
+
+ FROM nvidia/cuda:${CUDA_VERSION}-runtime-ubuntu22.04
+
+ ARG CUDA_VERSION
+
+ RUN apt-get update && apt-get install -y \
+     python3 python3-pip git ffmpeg wget curl && \
+     pip3 install --upgrade pip
+
+ WORKDIR /app
+
+ # This allows caching pip install if only code has changed
+ COPY requirements.txt .
+
+ # Install dependencies
+ RUN pip3 install --no-cache-dir -r requirements.txt
+ RUN export CUDA_SHORT_VERSION=$(echo "${CUDA_VERSION}" | sed 's/\.//g' | cut -c 1-3) && \
+     pip3 install --force-reinstall --no-cache-dir torch torchvision torchaudio --index-url "https://download.pytorch.org/whl/cu${CUDA_SHORT_VERSION}"
+
+ # Copy the source code to /app
+ COPY . .
+
+ VOLUME [ "/app/.framepack", "/app/outputs", "/app/loras", "/app/hf_download", "/app/modules/toolbox/model_esrgan", "/app/modules/toolbox/model_rife" ]
+
+ EXPOSE 7860
+
+ CMD ["python3", "studio.py"]
LICENSE ADDED
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
ORIGINAL_README.md ADDED
@@ -0,0 +1,83 @@
+ <h1 align="center">FramePack Studio</h1>
+
+ [![Discord](https://img.shields.io/badge/Discord-%235865F2.svg?style=for-the-badge&logo=discord&logoColor=white)](https://discord.gg/MtuM7gFJ3V)[![Patreon](https://img.shields.io/badge/Patreon-F96854?style=for-the-badge&logo=patreon&logoColor=white)](https://www.patreon.com/ColinU)
+
+ [![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/colinurbs/FramePack-Studio)
+
+ FramePack Studio is an AI video generation application based on FramePack that strives to provide everything you need to create high-quality video projects.
+
+ ![screencapture-127-0-0-1-7860-2025-06-12-19_50_37](https://github.com/user-attachments/assets/b86a8422-f4ce-452b-80eb-2ba91945f2ea)
+ ![screencapture-127-0-0-1-7860-2025-06-12-19_52_33](https://github.com/user-attachments/assets/ebfb31ca-85b7-4354-87c6-aaab6d1c77b1)
+
+ ## Current Features
+
+ - **F1, Original and Video Extension Generations**: Run all in a single queue
+ - **End Frame Control for 'Original' Model**: Provides greater control over generations
+ - **Upscaling and Post-processing**
+ - **Timestamped Prompts**: Define different prompts for specific time segments in your video
+ - **Prompt Blending**: Define the blending time between timestamped prompts
+ - **LoRA Support**: Works with most (all?) Hunyuan Video LoRAs
+ - **Queue System**: Process multiple generation jobs without blocking the interface. Import and export queues.
+ - **Metadata Saving/Import**: Prompt and seed are encoded into the output PNG, all other generation metadata is saved in a JSON file that can be imported later for similar generations.
+ - **Custom Presets**: Allow quick switching between named groups of parameters. A custom Startup Preset can also be set.
+ - **I2V and T2V**: Works with or without an input image to allow for more flexibility when working with standard Hunyuan Video LoRAs
+ - **Latent Image Options**: When using T2V you can generate based on a black, white, green screen, or pure noise image
+
+ ## Prerequisites
+
+ - CUDA-compatible GPU with at least 8GB VRAM (16GB+ recommended)
+ - 16GB System Memory (32GB+ strongly recommended)
+ - 80GB+ of storage (including ~25GB for each model family: Original and F1)
+
+ ## Documentation
+
+ For information on installation, configuration, and usage, please visit our [documentation site](https://docs.framepackstudio.com/).
+
+ ## Installation
+
+ Please see [this guide](https://docs.framepackstudio.com/docs/get_started/) on our documentation site to get FP-Studio installed.
+
+ ## LoRAs
+
+ Add LoRAs to the /loras/ folder at the root of the installation. Select the LoRAs you wish to load and set the weights for each generation. Most Hunyuan LoRAs were originally trained for T2V, so it's often helpful to run a T2V generation to ensure they're working before using input images.
+
+ NOTE: Slow LoRA loading is a known issue.
+
+ ## Working with Timestamped Prompts
+
+ You can create videos with changing prompts over time using the following syntax:
+
+ ```
+ [0s: A serene forest with sunlight filtering through the trees ]
+ [5s: A deer appears in the clearing ]
+ [10s: The deer drinks from a small stream ]
+ ```
+
+ Each timestamp defines when that prompt should start influencing the generation. The system will (hopefully) smoothly transition between prompts for a cohesive video.
+
+ ## Credits
+
+ Many thanks to [Lvmin Zhang](https://github.com/lllyasviel) for the absolutely amazing work on the original [FramePack](https://github.com/lllyasviel/FramePack) code!
+
+ Thanks to [Rickard Edén](https://github.com/neph1) for the LoRA code and their general contributions to this growing FramePack scene!
+
+ Thanks to [Zehong Ma](https://github.com/Zehong-Ma) for [MagCache](https://github.com/Zehong-Ma/MagCache): Fast Video Generation with Magnitude-Aware Cache!
+
+ Thanks to everyone who has joined the Discord, reported a bug, submitted a PR, or helped with testing!
+
+ @article{zhang2025framepack,
+     title={Packing Input Frame Contexts in Next-Frame Prediction Models for Video Generation},
+     author={Lvmin Zhang and Maneesh Agrawala},
+     journal={Arxiv},
+     year={2025}
+ }
+
+ @misc{zhang2025packinginputframecontext,
+     title={Packing Input Frame Context in Next-Frame Prediction Models for Video Generation},
+     author={Lvmin Zhang and Maneesh Agrawala},
+     year={2025},
+     eprint={2504.12626},
+     archivePrefix={arXiv},
+     primaryClass={cs.CV},
+     url={https://arxiv.org/abs/2504.12626}
+ }
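The README above documents the `[Ns: prompt]` syntax but not its parsing. Below is a minimal, hypothetical sketch of how such bracketed timestamps could be parsed; the regex and the `TimedPrompt` structure are illustrative only and are not the parser FramePack Studio actually ships (that lives in `modules/prompt_handler.py` in this commit).

```python
import re
from dataclasses import dataclass


@dataclass
class TimedPrompt:
    start_seconds: float
    text: str


def parse_timestamped_prompts(prompt: str):
    """Parse '[0s: ...] [5s: ...]' style prompts into ordered (start, text) pairs."""
    pattern = re.compile(r"\[(\d+(?:\.\d+)?)s\s*:\s*(.+?)\s*\]", re.DOTALL)
    segments = [TimedPrompt(float(t), text) for t, text in pattern.findall(prompt)]
    return sorted(segments, key=lambda s: s.start_seconds)


example = """
[0s: A serene forest with sunlight filtering through the trees ]
[5s: A deer appears in the clearing ]
[10s: The deer drinks from a small stream ]
"""
for seg in parse_timestamped_prompts(example):
    print(seg.start_seconds, seg.text)
```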
diffusers_helper/bucket_tools.py ADDED
@@ -0,0 +1,97 @@
+ bucket_options = {
+     128: [
+         (96, 160),
+         (112, 144),
+         (128, 128),
+         (144, 112),
+         (160, 96),
+     ],
+     256: [
+         (192, 320),
+         (224, 288),
+         (256, 256),
+         (288, 224),
+         (320, 192),
+     ],
+     384: [
+         (256, 512),
+         (320, 448),
+         (384, 384),
+         (448, 320),
+         (512, 256),
+     ],
+     512: [
+         (352, 704),
+         (384, 640),
+         (448, 576),
+         (512, 512),
+         (576, 448),
+         (640, 384),
+         (704, 352),
+     ],
+     640: [
+         (416, 960),
+         (448, 864),
+         (480, 832),
+         (512, 768),
+         (544, 704),
+         (576, 672),
+         (608, 640),
+         (640, 640),
+         (640, 608),
+         (672, 576),
+         (704, 544),
+         (768, 512),
+         (832, 480),
+         (864, 448),
+         (960, 416),
+     ],
+     768: [
+         (512, 1024),
+         (576, 896),
+         (640, 832),
+         (704, 768),
+         (768, 768),
+         (768, 704),
+         (832, 640),
+         (896, 576),
+         (1024, 512),
+     ],
+ }
+
+
+ def find_nearest_bucket(h, w, resolution=640):
+     # Use the provided resolution or find the closest available bucket size
+     # print(f"find_nearest_bucket called with h={h}, w={w}, resolution={resolution}")
+
+     # Convert resolution to int if it's not already
+     resolution = int(resolution) if not isinstance(resolution, int) else resolution
+
+     if resolution not in bucket_options:
+         # Find the closest available resolution
+         available_resolutions = list(bucket_options.keys())
+         closest_resolution = min(available_resolutions, key=lambda x: abs(x - resolution))
+         # print(f"Resolution {resolution} not found in bucket options, using closest available: {closest_resolution}")
+         resolution = closest_resolution
+     # else:
+     #     print(f"Resolution {resolution} found in bucket options")
+
+     # Calculate the aspect ratio of the input image
+     input_aspect_ratio = w / h if h > 0 else 1.0
+     # print(f"Input aspect ratio: {input_aspect_ratio:.4f}")
+
+     min_diff = float('inf')
+     best_bucket = None
+
+     # Find the bucket size with the closest aspect ratio to the input image
+     for (bucket_h, bucket_w) in bucket_options[resolution]:
+         bucket_aspect_ratio = bucket_w / bucket_h if bucket_h > 0 else 1.0
+         # Calculate the difference in aspect ratios
+         diff = abs(bucket_aspect_ratio - input_aspect_ratio)
+         if diff < min_diff:
+             min_diff = diff
+             best_bucket = (bucket_h, bucket_w)
+         # print(f"  Checking bucket ({bucket_h}, {bucket_w}), aspect ratio={bucket_aspect_ratio:.4f}, diff={diff:.4f}, current best={best_bucket}")
+
+     # print(f"Using resolution {resolution}, selected bucket: {best_bucket}")
+     return best_bucket
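A small usage sketch of `find_nearest_bucket` as defined above; the input dimensions are arbitrary examples, not values taken from the project:

```python
from diffusers_helper.bucket_tools import bucket_options, find_nearest_bucket

h, w = 1920, 1080  # portrait input, aspect ratio w/h = 0.5625
print(find_nearest_bucket(h, w, resolution=640))  # bucket whose w/h ratio is closest to the input's
print(find_nearest_bucket(h, w, resolution=600))  # an unknown key falls back to the nearest family (640)
print(sorted(bucket_options.keys()))              # supported resolution families
```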
diffusers_helper/clip_vision.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+
4
+ def hf_clip_vision_encode(image, feature_extractor, image_encoder):
5
+ assert isinstance(image, np.ndarray)
6
+ assert image.ndim == 3 and image.shape[2] == 3
7
+ assert image.dtype == np.uint8
8
+
9
+ preprocessed = feature_extractor.preprocess(images=image, return_tensors="pt").to(device=image_encoder.device, dtype=image_encoder.dtype)
10
+ image_encoder_output = image_encoder(**preprocessed)
11
+
12
+ return image_encoder_output
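An illustrative usage sketch for `hf_clip_vision_encode`. The checkpoint name and encoder classes here are assumptions (a SigLIP image encoder of the kind FramePack-style pipelines typically use), not something this file dictates; the function itself only requires an HWC uint8 array plus any compatible feature extractor and vision model.

```python
import numpy as np
from transformers import SiglipImageProcessor, SiglipVisionModel  # assumed encoder classes

from diffusers_helper.clip_vision import hf_clip_vision_encode

# Hypothetical checkpoint; substitute whatever image encoder the pipeline actually loads.
feature_extractor = SiglipImageProcessor.from_pretrained("lllyasviel/flux_redux_bfl", subfolder="feature_extractor")
image_encoder = SiglipVisionModel.from_pretrained("lllyasviel/flux_redux_bfl", subfolder="image_encoder")

image = np.zeros((480, 640, 3), dtype=np.uint8)  # HWC uint8, as the asserts above require
output = hf_clip_vision_encode(image, feature_extractor, image_encoder)
print(output.last_hidden_state.shape)
```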
diffusers_helper/dit_common.py ADDED
@@ -0,0 +1,53 @@
+ import torch
+ import accelerate.accelerator
+
+ from diffusers.models.normalization import RMSNorm, LayerNorm, FP32LayerNorm, AdaLayerNormContinuous
+
+
+ accelerate.accelerator.convert_outputs_to_fp32 = lambda x: x
+
+
+ def LayerNorm_forward(self, x):
+     return torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps).to(x)
+
+
+ LayerNorm.forward = LayerNorm_forward
+ torch.nn.LayerNorm.forward = LayerNorm_forward
+
+
+ def FP32LayerNorm_forward(self, x):
+     origin_dtype = x.dtype
+     return torch.nn.functional.layer_norm(
+         x.float(),
+         self.normalized_shape,
+         self.weight.float() if self.weight is not None else None,
+         self.bias.float() if self.bias is not None else None,
+         self.eps,
+     ).to(origin_dtype)
+
+
+ FP32LayerNorm.forward = FP32LayerNorm_forward
+
+
+ def RMSNorm_forward(self, hidden_states):
+     input_dtype = hidden_states.dtype
+     variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+     hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
+
+     if self.weight is None:
+         return hidden_states.to(input_dtype)
+
+     return hidden_states.to(input_dtype) * self.weight.to(input_dtype)
+
+
+ RMSNorm.forward = RMSNorm_forward
+
+
+ def AdaLayerNormContinuous_forward(self, x, conditioning_embedding):
+     emb = self.linear(self.silu(conditioning_embedding))
+     scale, shift = emb.chunk(2, dim=1)
+     x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
+     return x
+
+
+ AdaLayerNormContinuous.forward = AdaLayerNormContinuous_forward
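The module above monkey-patches the normalization forwards globally as an import side effect. A minimal sanity-check sketch (illustrative only, assuming the project's `diffusers` and `accelerate` dependencies are installed and that `RMSNorm` takes `dim`/`eps`/`elementwise_affine` as in current `diffusers`) comparing the patched RMSNorm against the reference formula it is meant to implement:

```python
import torch

import diffusers_helper.dit_common  # noqa: F401  # applying the patches is a side effect of the import
from diffusers.models.normalization import RMSNorm

norm = RMSNorm(16, eps=1e-6, elementwise_affine=True)
x = torch.randn(2, 4, 16, dtype=torch.float16)

# Reference: x / sqrt(mean(x^2) + eps) computed in fp32, cast back, then scaled by the weight.
ref = x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + 1e-6)
ref = ref.to(x.dtype) * norm.weight.to(x.dtype)

print(torch.allclose(norm(x), ref, atol=1e-3))
```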
diffusers_helper/gradio/progress_bar.py ADDED
@@ -0,0 +1,86 @@
+ progress_html = '''
+ <div class="loader-container">
+   <div class="loader"></div>
+   <div class="progress-container">
+     <progress value="*number*" max="100"></progress>
+   </div>
+   <span>*text*</span>
+ </div>
+ '''
+
+ css = '''
+ .loader-container {
+   display: flex; /* Use flex to align items horizontally */
+   align-items: center; /* Center items vertically within the container */
+   white-space: nowrap; /* Prevent line breaks within the container */
+ }
+
+ .loader {
+   border: 8px solid #f3f3f3; /* Light grey */
+   border-top: 8px solid #3498db; /* Blue */
+   border-radius: 50%;
+   width: 30px;
+   height: 30px;
+   animation: spin 2s linear infinite;
+ }
+
+ @keyframes spin {
+   0% { transform: rotate(0deg); }
+   100% { transform: rotate(360deg); }
+ }
+
+ /* Style the progress bar */
+ progress {
+   appearance: none; /* Remove default styling */
+   height: 20px; /* Set the height of the progress bar */
+   border-radius: 5px; /* Round the corners of the progress bar */
+   background-color: #f3f3f3; /* Light grey background */
+   width: 100%;
+   vertical-align: middle !important;
+ }
+
+ /* Style the progress bar container */
+ .progress-container {
+   margin-left: 20px;
+   margin-right: 20px;
+   flex-grow: 1; /* Allow the progress container to take up remaining space */
+ }
+
+ /* Set the color of the progress bar fill */
+ progress::-webkit-progress-value {
+   background-color: #3498db; /* Blue color for the fill */
+ }
+
+ progress::-moz-progress-bar {
+   background-color: #3498db; /* Blue color for the fill in Firefox */
+ }
+
+ /* Style the text on the progress bar */
+ progress::after {
+   content: attr(value '%'); /* Display the progress value followed by '%' */
+   position: absolute;
+   top: 50%;
+   left: 50%;
+   transform: translate(-50%, -50%);
+   color: white; /* Set text color */
+   font-size: 14px; /* Set font size */
+ }
+
+ /* Style other texts */
+ .loader-container > span {
+   margin-left: 5px; /* Add spacing between the progress bar and the text */
+ }
+
+ .no-generating-animation > .generating {
+   display: none !important;
+ }
+
+ '''
+
+
+ def make_progress_bar_html(number, text):
+     return progress_html.replace('*number*', str(number)).replace('*text*', text)
+
+
+ def make_progress_bar_css():
+     return css
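An illustrative sketch of how these two helpers could be wired into a Gradio UI; the component layout is an assumption, not the app's actual interface code (see `modules/interface.py`).

```python
import gradio as gr

from diffusers_helper.gradio.progress_bar import make_progress_bar_css, make_progress_bar_html

with gr.Blocks(css=make_progress_bar_css()) as demo:
    # The "no-generating-animation" class suppresses Gradio's own spinner, per the CSS above.
    progress = gr.HTML(make_progress_bar_html(0, "Idle"), elem_classes="no-generating-animation")
    # During generation, a worker would update the component with e.g.:
    #   make_progress_bar_html(42, "Sampling step 42/100")

print(make_progress_bar_html(42, "Sampling step 42/100"))
```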
diffusers_helper/hf_login.py ADDED
@@ -0,0 +1,21 @@
+ import os
+
+
+ def login(token):
+     from huggingface_hub import login
+     import time
+
+     while True:
+         try:
+             login(token)
+             print('HF login ok.')
+             break
+         except Exception as e:
+             print(f'HF login failed: {e}. Retrying')
+             time.sleep(0.5)
+
+
+ hf_token = os.environ.get('HF_TOKEN', None)
+
+ if hf_token is not None:
+     login(hf_token)
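The module logs in as an import side effect whenever `HF_TOKEN` is set in the environment. A small sketch of a guarded import (the surrounding startup code is an assumption, not part of this file):

```python
import os

if os.environ.get("HF_TOKEN"):
    import diffusers_helper.hf_login  # noqa: F401  # retries until the Hugging Face login succeeds
else:
    print("HF_TOKEN not set; skipping Hugging Face login.")
```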
diffusers_helper/hunyuan.py ADDED
@@ -0,0 +1,163 @@
+ import torch
+
+ from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video import DEFAULT_PROMPT_TEMPLATE
+ from diffusers_helper.utils import crop_or_pad_yield_mask
+
+
+ @torch.no_grad()
+ def encode_prompt_conds(prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2, max_length=256):
+     assert isinstance(prompt, str)
+
+     prompt = [prompt]
+
+     # LLAMA
+
+     # Check if there's a custom system prompt template in settings
+     custom_template = None
+     try:
+         from modules.settings import Settings
+         settings = Settings()
+         override_system_prompt = settings.get("override_system_prompt", False)
+         custom_template_str = settings.get("system_prompt_template")
+
+         if override_system_prompt and custom_template_str:
+             try:
+                 # Convert the string representation to a dictionary
+                 # Extract template and crop_start directly from the string using regex
+                 import re
+
+                 # Try to extract the template value
+                 template_match = re.search(r"['\"]template['\"]\s*:\s*['\"](.+?)['\"](?=\s*,|\s*})", custom_template_str, re.DOTALL)
+                 crop_start_match = re.search(r"['\"]crop_start['\"]\s*:\s*(\d+)", custom_template_str)
+
+                 if template_match and crop_start_match:
+                     template_value = template_match.group(1)
+                     crop_start_value = int(crop_start_match.group(1))
+
+                     # Unescape any escaped characters in the template
+                     template_value = template_value.replace("\\n", "\n").replace("\\\"", "\"").replace("\\'", "'")
+
+                     custom_template = {
+                         "template": template_value,
+                         "crop_start": crop_start_value
+                     }
+                     print(f"Using custom system prompt template from settings: {custom_template}")
+                 else:
+                     print(f"Could not extract template or crop_start from system prompt template string")
+                     print(f"Falling back to default template")
+                     custom_template = None
+             except Exception as e:
+                 print(f"Error parsing custom system prompt template: {e}")
+                 print(f"Falling back to default template")
+                 custom_template = None
+         else:
+             if not override_system_prompt:
+                 print(f"Override system prompt is disabled, using default template")
+             elif not custom_template_str:
+                 print(f"No custom system prompt template found in settings")
+             custom_template = None
+     except Exception as e:
+         print(f"Error loading settings: {e}")
+         print(f"Falling back to default template")
+         custom_template = None
+
+     # Use custom template if available, otherwise use default
+     template = custom_template if custom_template else DEFAULT_PROMPT_TEMPLATE
+
+     prompt_llama = [template["template"].format(p) for p in prompt]
+     crop_start = template["crop_start"]
+
+     llama_inputs = tokenizer(
+         prompt_llama,
+         padding="max_length",
+         max_length=max_length + crop_start,
+         truncation=True,
+         return_tensors="pt",
+         return_length=False,
+         return_overflowing_tokens=False,
+         return_attention_mask=True,
+     )
+
+     llama_input_ids = llama_inputs.input_ids.to(text_encoder.device)
+     llama_attention_mask = llama_inputs.attention_mask.to(text_encoder.device)
+     llama_attention_length = int(llama_attention_mask.sum())
+
+     llama_outputs = text_encoder(
+         input_ids=llama_input_ids,
+         attention_mask=llama_attention_mask,
+         output_hidden_states=True,
+     )
+
+     llama_vec = llama_outputs.hidden_states[-3][:, crop_start:llama_attention_length]
+     # llama_vec_remaining = llama_outputs.hidden_states[-3][:, llama_attention_length:]
+     llama_attention_mask = llama_attention_mask[:, crop_start:llama_attention_length]
+
+     assert torch.all(llama_attention_mask.bool())
+
+     # CLIP
+
+     clip_l_input_ids = tokenizer_2(
+         prompt,
+         padding="max_length",
+         max_length=77,
+         truncation=True,
+         return_overflowing_tokens=False,
+         return_length=False,
+         return_tensors="pt",
+     ).input_ids
+     clip_l_pooler = text_encoder_2(clip_l_input_ids.to(text_encoder_2.device), output_hidden_states=False).pooler_output
+
+     return llama_vec, clip_l_pooler
+
+
+ @torch.no_grad()
+ def vae_decode_fake(latents):
+     latent_rgb_factors = [
+         [-0.0395, -0.0331, 0.0445],
+         [0.0696, 0.0795, 0.0518],
+         [0.0135, -0.0945, -0.0282],
+         [0.0108, -0.0250, -0.0765],
+         [-0.0209, 0.0032, 0.0224],
+         [-0.0804, -0.0254, -0.0639],
+         [-0.0991, 0.0271, -0.0669],
+         [-0.0646, -0.0422, -0.0400],
+         [-0.0696, -0.0595, -0.0894],
+         [-0.0799, -0.0208, -0.0375],
+         [0.1166, 0.1627, 0.0962],
+         [0.1165, 0.0432, 0.0407],
+         [-0.2315, -0.1920, -0.1355],
+         [-0.0270, 0.0401, -0.0821],
+         [-0.0616, -0.0997, -0.0727],
+         [0.0249, -0.0469, -0.1703]
+     ]  # From comfyui
+
+     latent_rgb_factors_bias = [0.0259, -0.0192, -0.0761]
+
+     weight = torch.tensor(latent_rgb_factors, device=latents.device, dtype=latents.dtype).transpose(0, 1)[:, :, None, None, None]
+     bias = torch.tensor(latent_rgb_factors_bias, device=latents.device, dtype=latents.dtype)
+
+     images = torch.nn.functional.conv3d(latents, weight, bias=bias, stride=1, padding=0, dilation=1, groups=1)
+     images = images.clamp(0.0, 1.0)
+
+     return images
+
+
+ @torch.no_grad()
+ def vae_decode(latents, vae, image_mode=False):
+     latents = latents / vae.config.scaling_factor
+
+     if not image_mode:
+         image = vae.decode(latents.to(device=vae.device, dtype=vae.dtype)).sample
+     else:
+         latents = latents.to(device=vae.device, dtype=vae.dtype).unbind(2)
+         image = [vae.decode(l.unsqueeze(2)).sample for l in latents]
+         image = torch.cat(image, dim=2)
+
+     return image
+
+
+ @torch.no_grad()
+ def vae_encode(image, vae):
+     latents = vae.encode(image.to(device=vae.device, dtype=vae.dtype)).latent_dist.sample()
+     latents = latents * vae.config.scaling_factor
+     return latents
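The regex-based settings parsing in `encode_prompt_conds` expects a dict-like string under the `system_prompt_template` key. A minimal sketch of that shape follows; the template text and `crop_start` value here are placeholders inferred from the parsing code, not a documented schema.

```python
import re

# Hypothetical settings value: only the {'template': ..., 'crop_start': ...} shape is implied by the code.
custom_template_str = "{'template': '<|system|>Describe the video.<|user|>{}<|assistant|>', 'crop_start': 95}"

template_match = re.search(r"['\"]template['\"]\s*:\s*['\"](.+?)['\"](?=\s*,|\s*})", custom_template_str, re.DOTALL)
crop_start_match = re.search(r"['\"]crop_start['\"]\s*:\s*(\d+)", custom_template_str)

print(template_match.group(1))         # the raw template string, with a {} slot for the user prompt
print(int(crop_start_match.group(1)))  # number of leading tokens to crop from the encoder output
```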
diffusers_helper/k_diffusion/uni_pc_fm.py ADDED
@@ -0,0 +1,144 @@
+ # Better Flow Matching UniPC by Lvmin Zhang
+ # (c) 2025
+ # CC BY-SA 4.0
+ # Attribution-ShareAlike 4.0 International Licence
+
+
+ import torch
+
+ from tqdm.auto import trange
+
+
+ def expand_dims(v, dims):
+     return v[(...,) + (None,) * (dims - 1)]
+
+
+ class FlowMatchUniPC:
+     def __init__(self, model, extra_args, variant='bh1'):
+         self.model = model
+         self.variant = variant
+         self.extra_args = extra_args
+
+     def model_fn(self, x, t):
+         return self.model(x, t, **self.extra_args)
+
+     def update_fn(self, x, model_prev_list, t_prev_list, t, order):
+         assert order <= len(model_prev_list)
+         dims = x.dim()
+
+         t_prev_0 = t_prev_list[-1]
+         lambda_prev_0 = - torch.log(t_prev_0)
+         lambda_t = - torch.log(t)
+         model_prev_0 = model_prev_list[-1]
+
+         h = lambda_t - lambda_prev_0
+
+         rks = []
+         D1s = []
+         for i in range(1, order):
+             t_prev_i = t_prev_list[-(i + 1)]
+             model_prev_i = model_prev_list[-(i + 1)]
+             lambda_prev_i = - torch.log(t_prev_i)
+             rk = ((lambda_prev_i - lambda_prev_0) / h)[0]
+             rks.append(rk)
+             D1s.append((model_prev_i - model_prev_0) / rk)
+
+         rks.append(1.)
+         rks = torch.tensor(rks, device=x.device)
+
+         R = []
+         b = []
+
+         hh = -h[0]
+         h_phi_1 = torch.expm1(hh)
+         h_phi_k = h_phi_1 / hh - 1
+
+         factorial_i = 1
+
+         if self.variant == 'bh1':
+             B_h = hh
+         elif self.variant == 'bh2':
+             B_h = torch.expm1(hh)
+         else:
+             raise NotImplementedError('Bad variant!')
+
+         for i in range(1, order + 1):
+             R.append(torch.pow(rks, i - 1))
+             b.append(h_phi_k * factorial_i / B_h)
+             factorial_i *= (i + 1)
+             h_phi_k = h_phi_k / hh - 1 / factorial_i
+
+         R = torch.stack(R)
+         b = torch.tensor(b, device=x.device)
+
+         use_predictor = len(D1s) > 0
+
+         if use_predictor:
+             D1s = torch.stack(D1s, dim=1)
+             if order == 2:
+                 rhos_p = torch.tensor([0.5], device=b.device)
+             else:
+                 rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1])
+         else:
+             D1s = None
+             rhos_p = None
+
+         if order == 1:
+             rhos_c = torch.tensor([0.5], device=b.device)
+         else:
+             rhos_c = torch.linalg.solve(R, b)
+
+         x_t_ = expand_dims(t / t_prev_0, dims) * x - expand_dims(h_phi_1, dims) * model_prev_0
+
+         if use_predictor:
+             pred_res = torch.tensordot(D1s, rhos_p, dims=([1], [0]))
+         else:
+             pred_res = 0
+
+         x_t = x_t_ - expand_dims(B_h, dims) * pred_res
+         model_t = self.model_fn(x_t, t)
+
+         if D1s is not None:
+             corr_res = torch.tensordot(D1s, rhos_c[:-1], dims=([1], [0]))
+         else:
+             corr_res = 0
+
+         D1_t = (model_t - model_prev_0)
+         x_t = x_t_ - expand_dims(B_h, dims) * (corr_res + rhos_c[-1] * D1_t)
+
+         return x_t, model_t
+
+     def sample(self, x, sigmas, callback=None, disable_pbar=False):
+         order = min(3, len(sigmas) - 2)
+         model_prev_list, t_prev_list = [], []
+         for i in trange(len(sigmas) - 1, disable=disable_pbar):
+             vec_t = sigmas[i].expand(x.shape[0])
+
+             if i == 0:
+                 model_prev_list = [self.model_fn(x, vec_t)]
+                 t_prev_list = [vec_t]
+             elif i < order:
+                 init_order = i
+                 x, model_x = self.update_fn(x, model_prev_list, t_prev_list, vec_t, init_order)
+                 model_prev_list.append(model_x)
+                 t_prev_list.append(vec_t)
+             else:
+                 x, model_x = self.update_fn(x, model_prev_list, t_prev_list, vec_t, order)
+                 model_prev_list.append(model_x)
+                 t_prev_list.append(vec_t)
+
+             model_prev_list = model_prev_list[-order:]
+             t_prev_list = t_prev_list[-order:]
+
+             if callback is not None:
+                 callback_result = callback({'x': x, 'i': i, 'denoised': model_prev_list[-1]})
+                 if callback_result == 'cancel':
+                     print("Cancellation signal received in sample_unipc, stopping generation")
+                     return model_prev_list[-1]  # Return current denoised result
+
+         return model_prev_list[-1]
+
+
+ def sample_unipc(model, noise, sigmas, extra_args=None, callback=None, disable=False, variant='bh1'):
+     assert variant in ['bh1', 'bh2']
+     return FlowMatchUniPC(model, extra_args=extra_args, variant=variant).sample(noise, sigmas=sigmas, callback=callback, disable_pbar=disable)
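A toy usage sketch for `sample_unipc`: it only needs a callable `model(x, t, **extra_args)` and a decreasing sigma schedule that ends above zero for the evaluated steps. The "model" below is a stand-in, not the real video transformer.

```python
import torch

from diffusers_helper.k_diffusion.uni_pc_fm import sample_unipc


def toy_model(x, t, **kwargs):
    # Stand-in for the wrapped transformer: pretend the predicted clean sample is zero.
    return torch.zeros_like(x)


noise = torch.randn(1, 4, 8, 8)
sigmas = torch.linspace(1.0, 0.0, steps=26)  # 25 steps from pure noise toward data
result = sample_unipc(toy_model, noise, sigmas, extra_args={}, disable=True)
print(result.shape)  # same shape as the input noise
```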
diffusers_helper/k_diffusion/wrapper.py ADDED
@@ -0,0 +1,51 @@
+ import torch
+
+
+ def append_dims(x, target_dims):
+     return x[(...,) + (None,) * (target_dims - x.ndim)]
+
+
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=1.0):
+     if guidance_rescale == 0:
+         return noise_cfg
+
+     std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+     std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+     noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+     noise_cfg = guidance_rescale * noise_pred_rescaled + (1.0 - guidance_rescale) * noise_cfg
+     return noise_cfg
+
+
+ def fm_wrapper(transformer, t_scale=1000.0):
+     def k_model(x, sigma, **extra_args):
+         dtype = extra_args['dtype']
+         cfg_scale = extra_args['cfg_scale']
+         cfg_rescale = extra_args['cfg_rescale']
+         concat_latent = extra_args['concat_latent']
+
+         original_dtype = x.dtype
+         sigma = sigma.float()
+
+         x = x.to(dtype)
+         timestep = (sigma * t_scale).to(dtype)
+
+         if concat_latent is None:
+             hidden_states = x
+         else:
+             hidden_states = torch.cat([x, concat_latent.to(x)], dim=1)
+
+         pred_positive = transformer(hidden_states=hidden_states, timestep=timestep, return_dict=False, **extra_args['positive'])[0].float()
+
+         if cfg_scale == 1.0:
+             pred_negative = torch.zeros_like(pred_positive)
+         else:
+             pred_negative = transformer(hidden_states=hidden_states, timestep=timestep, return_dict=False, **extra_args['negative'])[0].float()
+
+         pred_cfg = pred_negative + cfg_scale * (pred_positive - pred_negative)
+         pred = rescale_noise_cfg(pred_cfg, pred_positive, guidance_rescale=cfg_rescale)
+
+         x0 = x.float() - pred.float() * append_dims(sigma, x.ndim)
+
+         return x0.to(dtype=original_dtype)
+
+     return k_model
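The guidance math above is the usual classifier-free guidance combination, `pred_cfg = pred_neg + cfg_scale * (pred_pos - pred_neg)`, optionally rescaled toward the text-conditioned prediction's per-sample standard deviation. A tiny numeric sketch of `rescale_noise_cfg` on random tensors (the shapes are arbitrary):

```python
import torch

from diffusers_helper.k_diffusion.wrapper import rescale_noise_cfg

noise_pred_text = torch.randn(2, 4, 8, 8)
noise_pred_uncond = torch.randn(2, 4, 8, 8)
cfg_scale = 6.0

noise_cfg = noise_pred_uncond + cfg_scale * (noise_pred_text - noise_pred_uncond)
# guidance_rescale=0 returns the input unchanged; 1.0 fully matches the text prediction's std.
rescaled = rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.7)
print(noise_cfg.std().item(), rescaled.std().item())
```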
diffusers_helper/lora_utils.py ADDED
@@ -0,0 +1,194 @@
+ from pathlib import Path, PurePath
+ from typing import Dict, List, Optional, Union, Tuple
+ from diffusers.loaders.lora_pipeline import _fetch_state_dict
+ from diffusers.loaders.lora_conversion_utils import _convert_hunyuan_video_lora_to_diffusers
+ from diffusers.utils.peft_utils import set_weights_and_activate_adapters
+ from diffusers.loaders.peft import _SET_ADAPTER_SCALE_FN_MAPPING
+ import torch
+
+ FALLBACK_CLASS_ALIASES = {
+     "HunyuanVideoTransformer3DModelPacked": "HunyuanVideoTransformer3DModel",
+ }
+
+ def load_lora(transformer: torch.nn.Module, lora_path: Path, weight_name: str) -> Tuple[torch.nn.Module, str]:
+     """
+     Load LoRA weights into the transformer model.
+
+     Args:
+         transformer: The transformer model to which LoRA weights will be applied.
+         lora_path: Path to the folder containing the LoRA weights file.
+         weight_name: Filename of the weight to load.
+
+     Returns:
+         A tuple containing the modified transformer and the canonical adapter name.
+     """
+
+     state_dict = _fetch_state_dict(
+         lora_path,
+         weight_name,
+         True,
+         True,
+         None,
+         None,
+         None,
+         None,
+         None,
+         None,
+         None,
+         None)
+
+     state_dict = _convert_hunyuan_video_lora_to_diffusers(state_dict)
+
+     # should weight_name even be Optional[str] or just str?
+     # For now, we assume it is never None
+     # The module name in the state_dict must not include a . in the name
+     # See https://github.com/pytorch/pytorch/pull/6639/files#diff-4be56271f7bfe650e3521c81fd363da58f109cd23ee80d243156d2d6ccda6263R133-R134
+     adapter_name = str(PurePath(weight_name).with_suffix('')).replace('.', '_DOT_')
+     if '_DOT_' in adapter_name:
+         print(
+             f"LoRA file '{weight_name}' contains a '.' in the name. " +
+             'This may cause issues. Consider renaming the file.' +
+             f" Using '{adapter_name}' as the adapter name to be safe."
+         )
+
+     # Check if adapter already exists and delete it if it does
+     if hasattr(transformer, 'peft_config') and adapter_name in transformer.peft_config:
+         print(f"Adapter '{adapter_name}' already exists. Removing it before loading again.")
+         # Use delete_adapters (plural) instead of delete_adapter
+         transformer.delete_adapters([adapter_name])
+
+     # Load the adapter with the original name
+     transformer.load_lora_adapter(state_dict, network_alphas=None, adapter_name=adapter_name)
+     print(f"LoRA weights '{adapter_name}' loaded successfully.")
+
+     return transformer, adapter_name
+
+ def unload_all_loras(transformer: torch.nn.Module) -> torch.nn.Module:
+     """
+     Completely unload all LoRA adapters from the transformer model.
+
+     Args:
+         transformer: The transformer model from which LoRA adapters will be removed.
+
+     Returns:
+         The transformer model after all LoRA adapters have been removed.
+     """
+     if hasattr(transformer, 'peft_config') and transformer.peft_config:
+         # Get all adapter names
+         adapter_names = list(transformer.peft_config.keys())
+
+         if adapter_names:
+             print(f"Removing all LoRA adapters: {', '.join(adapter_names)}")
+             # Delete all adapters
+             transformer.delete_adapters(adapter_names)
+
+             # Force cleanup of any remaining adapter references
+             if hasattr(transformer, 'active_adapter'):
+                 transformer.active_adapter = None
+
+             # Clear any cached states
+             for module in transformer.modules():
+                 if hasattr(module, 'lora_A'):
+                     if isinstance(module.lora_A, dict):
+                         module.lora_A.clear()
+                 if hasattr(module, 'lora_B'):
+                     if isinstance(module.lora_B, dict):
+                         module.lora_B.clear()
+                 if hasattr(module, 'scaling'):
+                     if isinstance(module.scaling, dict):
+                         module.scaling.clear()
+
+             print("All LoRA adapters have been completely removed.")
+         else:
+             print("No LoRA adapters found to remove.")
+     else:
+         print("Model doesn't have any LoRA adapters or peft_config.")
+
+     # Force garbage collection
+     import gc
+     gc.collect()
+     if torch.cuda.is_available():
+         torch.cuda.empty_cache()
+
+     return transformer
+
+ def resolve_expansion_class_name(
+     transformer: torch.nn.Module,
+     fallback_aliases: Dict[str, str],
+     fn_mapping: Dict[str, callable]
+ ) -> Optional[str]:
+     """
+     Resolves the canonical class name for adapter scale expansion functions,
+     considering potential fallback aliases.
+
+     Args:
+         transformer: The transformer model instance.
+         fallback_aliases: A dictionary mapping model class names to fallback class names.
+         fn_mapping: A dictionary mapping class names to their respective scale expansion functions.
+
+     Returns:
+         The resolved class name as a string if a matching scale function is found,
+         otherwise None.
+     """
+     class_name = transformer.__class__.__name__
+
+     if class_name in fn_mapping:
+         return class_name
+
+     fallback_class = fallback_aliases.get(class_name)
+     if fallback_class in fn_mapping:
+         print(f"Warning: No scale function for '{class_name}'. Falling back to '{fallback_class}'")
+         return fallback_class
+
+     return None
+
+ def set_adapters(
+     transformer: torch.nn.Module,
+     adapter_names: Union[List[str], str],
+     weights: Optional[Union[float, List[float]]] = None,
+ ):
+     """
+     Activates and sets the weights for one or more LoRA adapters on the transformer model.
+
+     Args:
+         transformer: The transformer model to which LoRA adapters are applied.
+         adapter_names: A single adapter name (str) or a list of adapter names (List[str]) to activate.
+         weights: Optional. A single float weight or a list of float weights
+             corresponding to each adapter name. If None, defaults to 1.0 for each adapter.
+             If a single float, it will be applied to all adapters.
+     """
+     adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
+
+     # Expand a single weight to apply to all adapters if needed
+     if not isinstance(weights, list):
+         weights = [weights] * len(adapter_names)
+
+     if len(adapter_names) != len(weights):
+         raise ValueError(
+             f"The number of adapter names ({len(adapter_names)}) does not match the number of weights ({len(weights)})."
+         )
+
+     # Replace any None weights with a default value of 1.0
+     sanitized_weights = [w if w is not None else 1.0 for w in weights]
+
+     resolved_class_name = resolve_expansion_class_name(
+         transformer,
+         fallback_aliases=FALLBACK_CLASS_ALIASES,
+         fn_mapping=_SET_ADAPTER_SCALE_FN_MAPPING
+     )
+
+     transformer_class_name = transformer.__class__.__name__
+
+     if resolved_class_name:
+         scale_expansion_fn = _SET_ADAPTER_SCALE_FN_MAPPING[resolved_class_name]
+         print(f"Using scale expansion function for model class '{resolved_class_name}' (original: '{transformer_class_name}')")
+         final_weights = [
+             scale_expansion_fn(transformer, [weight])[0] for weight in sanitized_weights
+         ]
+     else:
+         print(f"Warning: No scale expansion function found for '{transformer_class_name}'. Using raw weights.")
+         final_weights = sanitized_weights
+
+     set_weights_and_activate_adapters(transformer, adapter_names, final_weights)
+
+     print(f"Adapters {adapter_names} activated with weights {final_weights}.")
diffusers_helper/memory.py ADDED
@@ -0,0 +1,134 @@
+ # By lllyasviel
+
+
+ import torch
+
+
+ cpu = torch.device('cpu')
+ gpu = torch.device(f'cuda:{torch.cuda.current_device()}')
+ gpu_complete_modules = []
+
+
+ class DynamicSwapInstaller:
+     @staticmethod
+     def _install_module(module: torch.nn.Module, **kwargs):
+         original_class = module.__class__
+         module.__dict__['forge_backup_original_class'] = original_class
+
+         def hacked_get_attr(self, name: str):
+             if '_parameters' in self.__dict__:
+                 _parameters = self.__dict__['_parameters']
+                 if name in _parameters:
+                     p = _parameters[name]
+                     if p is None:
+                         return None
+                     if p.__class__ == torch.nn.Parameter:
+                         return torch.nn.Parameter(p.to(**kwargs), requires_grad=p.requires_grad)
+                     else:
+                         return p.to(**kwargs)
+             if '_buffers' in self.__dict__:
+                 _buffers = self.__dict__['_buffers']
+                 if name in _buffers:
+                     return _buffers[name].to(**kwargs)
+             return super(original_class, self).__getattr__(name)
+
+         module.__class__ = type('DynamicSwap_' + original_class.__name__, (original_class,), {
+             '__getattr__': hacked_get_attr,
+         })
+
+         return
+
+     @staticmethod
+     def _uninstall_module(module: torch.nn.Module):
+         if 'forge_backup_original_class' in module.__dict__:
+             module.__class__ = module.__dict__.pop('forge_backup_original_class')
+         return
+
+     @staticmethod
+     def install_model(model: torch.nn.Module, **kwargs):
+         for m in model.modules():
+             DynamicSwapInstaller._install_module(m, **kwargs)
+         return
+
+     @staticmethod
+     def uninstall_model(model: torch.nn.Module):
+         for m in model.modules():
+             DynamicSwapInstaller._uninstall_module(m)
+         return
+
+
+ def fake_diffusers_current_device(model: torch.nn.Module, target_device: torch.device):
+     if hasattr(model, 'scale_shift_table'):
+         model.scale_shift_table.data = model.scale_shift_table.data.to(target_device)
+         return
+
+     for k, p in model.named_modules():
+         if hasattr(p, 'weight'):
+             p.to(target_device)
+     return
+
+
+ def get_cuda_free_memory_gb(device=None):
+     if device is None:
+         device = gpu
+
+     memory_stats = torch.cuda.memory_stats(device)
+     bytes_active = memory_stats['active_bytes.all.current']
+     bytes_reserved = memory_stats['reserved_bytes.all.current']
+     bytes_free_cuda, _ = torch.cuda.mem_get_info(device)
+     bytes_inactive_reserved = bytes_reserved - bytes_active
+     bytes_total_available = bytes_free_cuda + bytes_inactive_reserved
+     return bytes_total_available / (1024 ** 3)
+
+
+ def move_model_to_device_with_memory_preservation(model, target_device, preserved_memory_gb=0):
+     print(f'Moving {model.__class__.__name__} to {target_device} with preserved memory: {preserved_memory_gb} GB')
+
+     for m in model.modules():
+         if get_cuda_free_memory_gb(target_device) <= preserved_memory_gb:
+             torch.cuda.empty_cache()
+             return
+
+         if hasattr(m, 'weight'):
+             m.to(device=target_device)
+
+     model.to(device=target_device)
+     torch.cuda.empty_cache()
+     return
+
+
+ def offload_model_from_device_for_memory_preservation(model, target_device, preserved_memory_gb=0):
+     print(f'Offloading {model.__class__.__name__} from {target_device} to preserve memory: {preserved_memory_gb} GB')
+
+     for m in model.modules():
+         if get_cuda_free_memory_gb(target_device) >= preserved_memory_gb:
+             torch.cuda.empty_cache()
+             return
+
+         if hasattr(m, 'weight'):
+             m.to(device=cpu)
+
+     model.to(device=cpu)
+     torch.cuda.empty_cache()
+     return
+
+
+ def unload_complete_models(*args):
+     for m in gpu_complete_modules + list(args):
+         m.to(device=cpu)
+         print(f'Unloaded {m.__class__.__name__} as complete.')
+
+     gpu_complete_modules.clear()
+     torch.cuda.empty_cache()
+     return
+
+
+ def load_model_as_complete(model, target_device, unload=True):
+     if unload:
+         unload_complete_models()
+
+     model.to(device=target_device)
+     print(f'Loaded {model.__class__.__name__} to {target_device} as complete.')
+
+     gpu_complete_modules.append(model)
+     return
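An illustrative VRAM-management flow in the spirit of how FramePack-style workers use these helpers; `text_encoder`, `vae`, and `transformer` are placeholders for already-constructed models, and the 6 GB reserve is an arbitrary example value.

```python
from diffusers_helper.memory import (
    DynamicSwapInstaller,
    get_cuda_free_memory_gb,
    gpu,
    load_model_as_complete,
    move_model_to_device_with_memory_preservation,
)


def prepare_for_sampling(text_encoder, vae, transformer, high_vram: bool):
    if high_vram:
        # Everything fits: keep the full models resident on the GPU.
        load_model_as_complete(text_encoder, gpu)
        load_model_as_complete(vae, gpu)
        transformer.to(gpu)
    else:
        # Stream transformer weights to the GPU on attribute access instead of moving the whole model,
        # then move as many modules as fit while keeping a VRAM reserve.
        DynamicSwapInstaller.install_model(transformer, device=gpu)
        move_model_to_device_with_memory_preservation(transformer, gpu, preserved_memory_gb=6)
    print(f'Free VRAM after placement: {get_cuda_free_memory_gb(gpu):.1f} GB')
```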
diffusers_helper/models/hunyuan_video_packed.py ADDED
@@ -0,0 +1,1062 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, List, Optional, Tuple, Union
2
+
3
+ import torch
4
+ import einops
5
+ import torch.nn as nn
6
+ import numpy as np
7
+
8
+ from diffusers.loaders import FromOriginalModelMixin
9
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
10
+ from diffusers.loaders import PeftAdapterMixin
11
+ from diffusers.utils import logging
12
+ from diffusers.models.attention import FeedForward
13
+ from diffusers.models.attention_processor import Attention
14
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps, PixArtAlphaTextProjection
15
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
16
+ from diffusers.models.modeling_utils import ModelMixin
17
+ from diffusers_helper.dit_common import LayerNorm
18
+ from diffusers_helper.models.mag_cache import MagCache
19
+ from diffusers_helper.utils import zero_module
20
+
21
+
22
+ enabled_backends = []
23
+
24
+ if torch.backends.cuda.flash_sdp_enabled():
25
+ enabled_backends.append("flash")
26
+ if torch.backends.cuda.math_sdp_enabled():
27
+ enabled_backends.append("math")
28
+ if torch.backends.cuda.mem_efficient_sdp_enabled():
29
+ enabled_backends.append("mem_efficient")
30
+ if torch.backends.cuda.cudnn_sdp_enabled():
31
+ enabled_backends.append("cudnn")
32
+
33
+ print("Currently enabled native sdp backends:", enabled_backends)
34
+
35
+ xformers_attn_func = None
36
+ flash_attn_varlen_func = None
37
+ flash_attn_func = None
38
+ sageattn_varlen = None
39
+ sageattn = None
40
+
41
+ try:
42
+ # raise NotImplementedError
43
+ from xformers.ops import memory_efficient_attention as xformers_attn_func
44
+ except:
45
+ pass
46
+
47
+ try:
48
+ # raise NotImplementedError
49
+ from flash_attn import flash_attn_varlen_func, flash_attn_func
50
+ except:
51
+ pass
52
+
53
+ try:
54
+ # raise NotImplementedError
55
+ from sageattention import sageattn_varlen, sageattn
56
+ except:
57
+ pass
58
+
59
+ # --- Attention Summary ---
60
+ print("\n--- Attention Configuration ---")
61
+ has_sage = sageattn is not None and sageattn_varlen is not None
62
+ has_flash = flash_attn_func is not None and flash_attn_varlen_func is not None
63
+ has_xformers = xformers_attn_func is not None
64
+
65
+ if has_sage:
66
+ print("✅ Using SAGE Attention (highest performance).")
67
+ ignored = []
68
+ if has_flash:
69
+ ignored.append("Flash Attention")
70
+ if has_xformers:
71
+ ignored.append("xFormers")
72
+ if ignored:
73
+ print(f" - Ignoring other installed attention libraries: {', '.join(ignored)}")
74
+ elif has_flash:
75
+ print("✅ Using Flash Attention (high performance).")
76
+ if has_xformers:
77
+ print(" - Consider installing SAGE Attention for highest performance.")
78
+ print(" - Ignoring other installed attention library: xFormers")
79
+ elif has_xformers:
80
+ print("✅ Using xFormers.")
81
+ print(" - Consider installing SAGE Attention for highest performance.")
82
+ print(" - or consider installing Flash Attention for high performance.")
83
+ else:
84
+ print("⚠️ No attention library found. Using native PyTorch Scaled Dot Product Attention.")
85
+ print(" - For better performance, consider installing one of:")
86
+ print(" SAGE Attention (highest performance), Flash Attention (high performance), or xFormers.")
87
+ print("-------------------------------\n")
88
+
89
+
90
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
91
+
92
+
93
+ def pad_for_3d_conv(x, kernel_size):
94
+ b, c, t, h, w = x.shape
95
+ pt, ph, pw = kernel_size
96
+ pad_t = (pt - (t % pt)) % pt
97
+ pad_h = (ph - (h % ph)) % ph
98
+ pad_w = (pw - (w % pw)) % pw
99
+ return torch.nn.functional.pad(x, (0, pad_w, 0, pad_h, 0, pad_t), mode='replicate')
100
+
101
+
102
+ def center_down_sample_3d(x, kernel_size):
103
+ # pt, ph, pw = kernel_size
104
+ # cp = (pt * ph * pw) // 2
105
+ # xp = einops.rearrange(x, 'b c (t pt) (h ph) (w pw) -> (pt ph pw) b c t h w', pt=pt, ph=ph, pw=pw)
106
+ # xc = xp[cp]
107
+ # return xc
108
+ return torch.nn.functional.avg_pool3d(x, kernel_size, stride=kernel_size)
109
+
110
+
111
+ def get_cu_seqlens(text_mask, img_len):
112
+ batch_size = text_mask.shape[0]
113
+ text_len = text_mask.sum(dim=1)
114
+ max_len = text_mask.shape[1] + img_len
115
+
116
+ cu_seqlens = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device="cuda")
117
+
118
+ for i in range(batch_size):
119
+ s = text_len[i] + img_len
120
+ s1 = i * max_len + s
121
+ s2 = (i + 1) * max_len
122
+ cu_seqlens[2 * i + 1] = s1
123
+ cu_seqlens[2 * i + 2] = s2
124
+
125
+ return cu_seqlens
126
+
127
+
128
+ def apply_rotary_emb_transposed(x, freqs_cis):
129
+ cos, sin = freqs_cis.unsqueeze(-2).chunk(2, dim=-1)
130
+ x_real, x_imag = x.unflatten(-1, (-1, 2)).unbind(-1)
131
+ x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
132
+ out = x.float() * cos + x_rotated.float() * sin
133
+ out = out.to(x)
134
+ return out
135
+
136
+
137
+ def attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv):
138
+ if cu_seqlens_q is None and cu_seqlens_kv is None and max_seqlen_q is None and max_seqlen_kv is None:
139
+ if sageattn is not None:
140
+ x = sageattn(q, k, v, tensor_layout='NHD')
141
+ return x
142
+
143
+ if flash_attn_func is not None:
144
+ x = flash_attn_func(q, k, v)
145
+ return x
146
+
147
+ if xformers_attn_func is not None:
148
+ x = xformers_attn_func(q, k, v)
149
+ return x
150
+
151
+ x = torch.nn.functional.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)).transpose(1, 2)
152
+ return x
153
+
154
+ batch_size = q.shape[0]
155
+ q = q.view(q.shape[0] * q.shape[1], *q.shape[2:])
156
+ k = k.view(k.shape[0] * k.shape[1], *k.shape[2:])
157
+ v = v.view(v.shape[0] * v.shape[1], *v.shape[2:])
158
+ if sageattn_varlen is not None:
159
+ x = sageattn_varlen(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
160
+ elif flash_attn_varlen_func is not None:
161
+ x = flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
162
+ else:
163
+ raise NotImplementedError('No Attn Installed!')
164
+ x = x.view(batch_size, max_seqlen_q, *x.shape[2:])
165
+ return x
166
+
167
+
168
+ class HunyuanAttnProcessorFlashAttnDouble:
169
+ def __call__(self, attn, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb):
170
+ cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv = attention_mask
171
+
172
+ query = attn.to_q(hidden_states)
173
+ key = attn.to_k(hidden_states)
174
+ value = attn.to_v(hidden_states)
175
+
176
+ query = query.unflatten(2, (attn.heads, -1))
177
+ key = key.unflatten(2, (attn.heads, -1))
178
+ value = value.unflatten(2, (attn.heads, -1))
179
+
180
+ query = attn.norm_q(query)
181
+ key = attn.norm_k(key)
182
+
183
+ query = apply_rotary_emb_transposed(query, image_rotary_emb)
184
+ key = apply_rotary_emb_transposed(key, image_rotary_emb)
185
+
186
+ encoder_query = attn.add_q_proj(encoder_hidden_states)
187
+ encoder_key = attn.add_k_proj(encoder_hidden_states)
188
+ encoder_value = attn.add_v_proj(encoder_hidden_states)
189
+
190
+ encoder_query = encoder_query.unflatten(2, (attn.heads, -1))
191
+ encoder_key = encoder_key.unflatten(2, (attn.heads, -1))
192
+ encoder_value = encoder_value.unflatten(2, (attn.heads, -1))
193
+
194
+ encoder_query = attn.norm_added_q(encoder_query)
195
+ encoder_key = attn.norm_added_k(encoder_key)
196
+
197
+ query = torch.cat([query, encoder_query], dim=1)
198
+ key = torch.cat([key, encoder_key], dim=1)
199
+ value = torch.cat([value, encoder_value], dim=1)
200
+
201
+ hidden_states = attn_varlen_func(query, key, value, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
202
+ hidden_states = hidden_states.flatten(-2)
203
+
204
+ txt_length = encoder_hidden_states.shape[1]
205
+ hidden_states, encoder_hidden_states = hidden_states[:, :-txt_length], hidden_states[:, -txt_length:]
206
+
207
+ hidden_states = attn.to_out[0](hidden_states)
208
+ hidden_states = attn.to_out[1](hidden_states)
209
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
210
+
211
+ return hidden_states, encoder_hidden_states
212
+
213
+
214
+ class HunyuanAttnProcessorFlashAttnSingle:
215
+ def __call__(self, attn, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb):
216
+ cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv = attention_mask
217
+
218
+ hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
219
+
220
+ query = attn.to_q(hidden_states)
221
+ key = attn.to_k(hidden_states)
222
+ value = attn.to_v(hidden_states)
223
+
224
+ query = query.unflatten(2, (attn.heads, -1))
225
+ key = key.unflatten(2, (attn.heads, -1))
226
+ value = value.unflatten(2, (attn.heads, -1))
227
+
228
+ query = attn.norm_q(query)
229
+ key = attn.norm_k(key)
230
+
231
+ txt_length = encoder_hidden_states.shape[1]
232
+
233
+ query = torch.cat([apply_rotary_emb_transposed(query[:, :-txt_length], image_rotary_emb), query[:, -txt_length:]], dim=1)
234
+ key = torch.cat([apply_rotary_emb_transposed(key[:, :-txt_length], image_rotary_emb), key[:, -txt_length:]], dim=1)
235
+
236
+ hidden_states = attn_varlen_func(query, key, value, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
237
+ hidden_states = hidden_states.flatten(-2)
238
+
239
+ hidden_states, encoder_hidden_states = hidden_states[:, :-txt_length], hidden_states[:, -txt_length:]
240
+
241
+ return hidden_states, encoder_hidden_states
242
+
243
+
244
+ class CombinedTimestepGuidanceTextProjEmbeddings(nn.Module):
245
+ def __init__(self, embedding_dim, pooled_projection_dim):
246
+ super().__init__()
247
+
248
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
249
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
250
+ self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
251
+ self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
252
+
253
+ def forward(self, timestep, guidance, pooled_projection):
254
+ timesteps_proj = self.time_proj(timestep)
255
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype))
256
+
257
+ guidance_proj = self.time_proj(guidance)
258
+ guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=pooled_projection.dtype))
259
+
260
+ time_guidance_emb = timesteps_emb + guidance_emb
261
+
262
+ pooled_projections = self.text_embedder(pooled_projection)
263
+ conditioning = time_guidance_emb + pooled_projections
264
+
265
+ return conditioning
266
+
267
+
268
+ class CombinedTimestepTextProjEmbeddings(nn.Module):
269
+ def __init__(self, embedding_dim, pooled_projection_dim):
270
+ super().__init__()
271
+
272
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
273
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
274
+ self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
275
+
276
+ def forward(self, timestep, pooled_projection):
277
+ timesteps_proj = self.time_proj(timestep)
278
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype))
279
+
280
+ pooled_projections = self.text_embedder(pooled_projection)
281
+
282
+ conditioning = timesteps_emb + pooled_projections
283
+
284
+ return conditioning
285
+
286
+
287
+ class HunyuanVideoAdaNorm(nn.Module):
288
+ def __init__(self, in_features: int, out_features: Optional[int] = None) -> None:
289
+ super().__init__()
290
+
291
+ out_features = out_features or 2 * in_features
292
+ self.linear = nn.Linear(in_features, out_features)
293
+ self.nonlinearity = nn.SiLU()
294
+
295
+ def forward(
296
+ self, temb: torch.Tensor
297
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
298
+ temb = self.linear(self.nonlinearity(temb))
299
+ gate_msa, gate_mlp = temb.chunk(2, dim=-1)
300
+ gate_msa, gate_mlp = gate_msa.unsqueeze(1), gate_mlp.unsqueeze(1)
301
+ return gate_msa, gate_mlp
302
+
303
+
304
+ class HunyuanVideoIndividualTokenRefinerBlock(nn.Module):
305
+ def __init__(
306
+ self,
307
+ num_attention_heads: int,
308
+ attention_head_dim: int,
309
+ mlp_width_ratio: str = 4.0,
310
+ mlp_drop_rate: float = 0.0,
311
+ attention_bias: bool = True,
312
+ ) -> None:
313
+ super().__init__()
314
+
315
+ hidden_size = num_attention_heads * attention_head_dim
316
+
317
+ self.norm1 = LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
318
+ self.attn = Attention(
319
+ query_dim=hidden_size,
320
+ cross_attention_dim=None,
321
+ heads=num_attention_heads,
322
+ dim_head=attention_head_dim,
323
+ bias=attention_bias,
324
+ )
325
+
326
+ self.norm2 = LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
327
+ self.ff = FeedForward(hidden_size, mult=mlp_width_ratio, activation_fn="linear-silu", dropout=mlp_drop_rate)
328
+
329
+ self.norm_out = HunyuanVideoAdaNorm(hidden_size, 2 * hidden_size)
330
+
331
+ def forward(
332
+ self,
333
+ hidden_states: torch.Tensor,
334
+ temb: torch.Tensor,
335
+ attention_mask: Optional[torch.Tensor] = None,
336
+ ) -> torch.Tensor:
337
+ norm_hidden_states = self.norm1(hidden_states)
338
+
339
+ attn_output = self.attn(
340
+ hidden_states=norm_hidden_states,
341
+ encoder_hidden_states=None,
342
+ attention_mask=attention_mask,
343
+ )
344
+
345
+ gate_msa, gate_mlp = self.norm_out(temb)
346
+ hidden_states = hidden_states + attn_output * gate_msa
347
+
348
+ ff_output = self.ff(self.norm2(hidden_states))
349
+ hidden_states = hidden_states + ff_output * gate_mlp
350
+
351
+ return hidden_states
352
+
353
+
354
+ class HunyuanVideoIndividualTokenRefiner(nn.Module):
355
+ def __init__(
356
+ self,
357
+ num_attention_heads: int,
358
+ attention_head_dim: int,
359
+ num_layers: int,
360
+ mlp_width_ratio: float = 4.0,
361
+ mlp_drop_rate: float = 0.0,
362
+ attention_bias: bool = True,
363
+ ) -> None:
364
+ super().__init__()
365
+
366
+ self.refiner_blocks = nn.ModuleList(
367
+ [
368
+ HunyuanVideoIndividualTokenRefinerBlock(
369
+ num_attention_heads=num_attention_heads,
370
+ attention_head_dim=attention_head_dim,
371
+ mlp_width_ratio=mlp_width_ratio,
372
+ mlp_drop_rate=mlp_drop_rate,
373
+ attention_bias=attention_bias,
374
+ )
375
+ for _ in range(num_layers)
376
+ ]
377
+ )
378
+
379
+ def forward(
380
+ self,
381
+ hidden_states: torch.Tensor,
382
+ temb: torch.Tensor,
383
+ attention_mask: Optional[torch.Tensor] = None,
384
+ ) -> None:
385
+ self_attn_mask = None
386
+ if attention_mask is not None:
387
+ batch_size = attention_mask.shape[0]
388
+ seq_len = attention_mask.shape[1]
389
+ attention_mask = attention_mask.to(hidden_states.device).bool()
390
+ self_attn_mask_1 = attention_mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1)
391
+ self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
392
+ self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
393
+ self_attn_mask[:, :, :, 0] = True
394
+
395
+ for block in self.refiner_blocks:
396
+ hidden_states = block(hidden_states, temb, self_attn_mask)
397
+
398
+ return hidden_states
399
+
400
+
401
+ class HunyuanVideoTokenRefiner(nn.Module):
402
+ def __init__(
403
+ self,
404
+ in_channels: int,
405
+ num_attention_heads: int,
406
+ attention_head_dim: int,
407
+ num_layers: int,
408
+ mlp_ratio: float = 4.0,
409
+ mlp_drop_rate: float = 0.0,
410
+ attention_bias: bool = True,
411
+ ) -> None:
412
+ super().__init__()
413
+
414
+ hidden_size = num_attention_heads * attention_head_dim
415
+
416
+ self.time_text_embed = CombinedTimestepTextProjEmbeddings(
417
+ embedding_dim=hidden_size, pooled_projection_dim=in_channels
418
+ )
419
+ self.proj_in = nn.Linear(in_channels, hidden_size, bias=True)
420
+ self.token_refiner = HunyuanVideoIndividualTokenRefiner(
421
+ num_attention_heads=num_attention_heads,
422
+ attention_head_dim=attention_head_dim,
423
+ num_layers=num_layers,
424
+ mlp_width_ratio=mlp_ratio,
425
+ mlp_drop_rate=mlp_drop_rate,
426
+ attention_bias=attention_bias,
427
+ )
428
+
429
+ def forward(
430
+ self,
431
+ hidden_states: torch.Tensor,
432
+ timestep: torch.LongTensor,
433
+ attention_mask: Optional[torch.LongTensor] = None,
434
+ ) -> torch.Tensor:
435
+ if attention_mask is None:
436
+ pooled_projections = hidden_states.mean(dim=1)
437
+ else:
438
+ original_dtype = hidden_states.dtype
439
+ mask_float = attention_mask.float().unsqueeze(-1)
440
+ pooled_projections = (hidden_states * mask_float).sum(dim=1) / mask_float.sum(dim=1)
441
+ pooled_projections = pooled_projections.to(original_dtype)
442
+
443
+ temb = self.time_text_embed(timestep, pooled_projections)
444
+ hidden_states = self.proj_in(hidden_states)
445
+ hidden_states = self.token_refiner(hidden_states, temb, attention_mask)
446
+
447
+ return hidden_states
448
+
449
+
450
+ class HunyuanVideoRotaryPosEmbed(nn.Module):
451
+ def __init__(self, rope_dim, theta):
452
+ super().__init__()
453
+ self.DT, self.DY, self.DX = rope_dim
454
+ self.theta = theta
455
+
456
+ @torch.no_grad()
457
+ def get_frequency(self, dim, pos):
458
+ T, H, W = pos.shape
459
+ freqs = 1.0 / (self.theta ** (torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device)[: (dim // 2)] / dim))
460
+ freqs = torch.outer(freqs, pos.reshape(-1)).unflatten(-1, (T, H, W)).repeat_interleave(2, dim=0)
461
+ return freqs.cos(), freqs.sin()
462
+
463
+ @torch.no_grad()
464
+ def forward_inner(self, frame_indices, height, width, device):
465
+ GT, GY, GX = torch.meshgrid(
466
+ frame_indices.to(device=device, dtype=torch.float32),
467
+ torch.arange(0, height, device=device, dtype=torch.float32),
468
+ torch.arange(0, width, device=device, dtype=torch.float32),
469
+ indexing="ij"
470
+ )
471
+
472
+ FCT, FST = self.get_frequency(self.DT, GT)
473
+ FCY, FSY = self.get_frequency(self.DY, GY)
474
+ FCX, FSX = self.get_frequency(self.DX, GX)
475
+
476
+ result = torch.cat([FCT, FCY, FCX, FST, FSY, FSX], dim=0)
477
+
478
+ return result.to(device)
479
+
480
+ @torch.no_grad()
481
+ def forward(self, frame_indices, height, width, device):
482
+ frame_indices = frame_indices.unbind(0)
483
+ results = [self.forward_inner(f, height, width, device) for f in frame_indices]
484
+ results = torch.stack(results, dim=0)
485
+ return results
486
+
487
+
488
+ class AdaLayerNormZero(nn.Module):
489
+ def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True):
490
+ super().__init__()
491
+ self.silu = nn.SiLU()
492
+ self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=bias)
493
+ if norm_type == "layer_norm":
494
+ self.norm = LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
495
+ else:
496
+ raise ValueError(f"unknown norm_type {norm_type}")
497
+
498
+ def forward(
499
+ self,
500
+ x: torch.Tensor,
501
+ emb: Optional[torch.Tensor] = None,
502
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
503
+ emb = emb.unsqueeze(-2)
504
+ emb = self.linear(self.silu(emb))
505
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=-1)
506
+ x = self.norm(x) * (1 + scale_msa) + shift_msa
507
+ return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
508
+
509
+
510
+ class AdaLayerNormZeroSingle(nn.Module):
511
+ def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True):
512
+ super().__init__()
513
+
514
+ self.silu = nn.SiLU()
515
+ self.linear = nn.Linear(embedding_dim, 3 * embedding_dim, bias=bias)
516
+ if norm_type == "layer_norm":
517
+ self.norm = LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
518
+ else:
519
+ raise ValueError(f"unknown norm_type {norm_type}")
520
+
521
+ def forward(
522
+ self,
523
+ x: torch.Tensor,
524
+ emb: Optional[torch.Tensor] = None,
525
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
526
+ emb = emb.unsqueeze(-2)
527
+ emb = self.linear(self.silu(emb))
528
+ shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=-1)
529
+ x = self.norm(x) * (1 + scale_msa) + shift_msa
530
+ return x, gate_msa
531
+
532
+
533
+ class AdaLayerNormContinuous(nn.Module):
534
+ def __init__(
535
+ self,
536
+ embedding_dim: int,
537
+ conditioning_embedding_dim: int,
538
+ elementwise_affine=True,
539
+ eps=1e-5,
540
+ bias=True,
541
+ norm_type="layer_norm",
542
+ ):
543
+ super().__init__()
544
+ self.silu = nn.SiLU()
545
+ self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
546
+ if norm_type == "layer_norm":
547
+ self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
548
+ else:
549
+ raise ValueError(f"unknown norm_type {norm_type}")
550
+
551
+ def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
552
+ emb = emb.unsqueeze(-2)
553
+ emb = self.linear(self.silu(emb))
554
+ scale, shift = emb.chunk(2, dim=-1)
555
+ x = self.norm(x) * (1 + scale) + shift
556
+ return x
557
+
558
+
559
+ class HunyuanVideoSingleTransformerBlock(nn.Module):
560
+ def __init__(
561
+ self,
562
+ num_attention_heads: int,
563
+ attention_head_dim: int,
564
+ mlp_ratio: float = 4.0,
565
+ qk_norm: str = "rms_norm",
566
+ ) -> None:
567
+ super().__init__()
568
+
569
+ hidden_size = num_attention_heads * attention_head_dim
570
+ mlp_dim = int(hidden_size * mlp_ratio)
571
+
572
+ self.attn = Attention(
573
+ query_dim=hidden_size,
574
+ cross_attention_dim=None,
575
+ dim_head=attention_head_dim,
576
+ heads=num_attention_heads,
577
+ out_dim=hidden_size,
578
+ bias=True,
579
+ processor=HunyuanAttnProcessorFlashAttnSingle(),
580
+ qk_norm=qk_norm,
581
+ eps=1e-6,
582
+ pre_only=True,
583
+ )
584
+
585
+ self.norm = AdaLayerNormZeroSingle(hidden_size, norm_type="layer_norm")
586
+ self.proj_mlp = nn.Linear(hidden_size, mlp_dim)
587
+ self.act_mlp = nn.GELU(approximate="tanh")
588
+ self.proj_out = nn.Linear(hidden_size + mlp_dim, hidden_size)
589
+
590
+ def forward(
591
+ self,
592
+ hidden_states: torch.Tensor,
593
+ encoder_hidden_states: torch.Tensor,
594
+ temb: torch.Tensor,
595
+ attention_mask: Optional[torch.Tensor] = None,
596
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
597
+ ) -> torch.Tensor:
598
+ text_seq_length = encoder_hidden_states.shape[1]
599
+ hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
600
+
601
+ residual = hidden_states
602
+
603
+ # 1. Input normalization
604
+ norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
605
+ mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
606
+
607
+ norm_hidden_states, norm_encoder_hidden_states = (
608
+ norm_hidden_states[:, :-text_seq_length, :],
609
+ norm_hidden_states[:, -text_seq_length:, :],
610
+ )
611
+
612
+ # 2. Attention
613
+ attn_output, context_attn_output = self.attn(
614
+ hidden_states=norm_hidden_states,
615
+ encoder_hidden_states=norm_encoder_hidden_states,
616
+ attention_mask=attention_mask,
617
+ image_rotary_emb=image_rotary_emb,
618
+ )
619
+ attn_output = torch.cat([attn_output, context_attn_output], dim=1)
620
+
621
+ # 3. Modulation and residual connection
622
+ hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
623
+ hidden_states = gate * self.proj_out(hidden_states)
624
+ hidden_states = hidden_states + residual
625
+
626
+ hidden_states, encoder_hidden_states = (
627
+ hidden_states[:, :-text_seq_length, :],
628
+ hidden_states[:, -text_seq_length:, :],
629
+ )
630
+ return hidden_states, encoder_hidden_states
631
+
632
+
633
+ class HunyuanVideoTransformerBlock(nn.Module):
634
+ def __init__(
635
+ self,
636
+ num_attention_heads: int,
637
+ attention_head_dim: int,
638
+ mlp_ratio: float,
639
+ qk_norm: str = "rms_norm",
640
+ ) -> None:
641
+ super().__init__()
642
+
643
+ hidden_size = num_attention_heads * attention_head_dim
644
+
645
+ self.norm1 = AdaLayerNormZero(hidden_size, norm_type="layer_norm")
646
+ self.norm1_context = AdaLayerNormZero(hidden_size, norm_type="layer_norm")
647
+
648
+ self.attn = Attention(
649
+ query_dim=hidden_size,
650
+ cross_attention_dim=None,
651
+ added_kv_proj_dim=hidden_size,
652
+ dim_head=attention_head_dim,
653
+ heads=num_attention_heads,
654
+ out_dim=hidden_size,
655
+ context_pre_only=False,
656
+ bias=True,
657
+ processor=HunyuanAttnProcessorFlashAttnDouble(),
658
+ qk_norm=qk_norm,
659
+ eps=1e-6,
660
+ )
661
+
662
+ self.norm2 = LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
663
+ self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
664
+
665
+ self.norm2_context = LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
666
+ self.ff_context = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
667
+
668
+ def forward(
669
+ self,
670
+ hidden_states: torch.Tensor,
671
+ encoder_hidden_states: torch.Tensor,
672
+ temb: torch.Tensor,
673
+ attention_mask: Optional[torch.Tensor] = None,
674
+ freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
675
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
676
+ # 1. Input normalization
677
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
678
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(encoder_hidden_states, emb=temb)
679
+
680
+ # 2. Joint attention
681
+ attn_output, context_attn_output = self.attn(
682
+ hidden_states=norm_hidden_states,
683
+ encoder_hidden_states=norm_encoder_hidden_states,
684
+ attention_mask=attention_mask,
685
+ image_rotary_emb=freqs_cis,
686
+ )
687
+
688
+ # 3. Modulation and residual connection
689
+ hidden_states = hidden_states + attn_output * gate_msa
690
+ encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa
691
+
692
+ norm_hidden_states = self.norm2(hidden_states)
693
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
694
+
695
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
696
+ norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp) + c_shift_mlp
697
+
698
+ # 4. Feed-forward
699
+ ff_output = self.ff(norm_hidden_states)
700
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
701
+
702
+ hidden_states = hidden_states + gate_mlp * ff_output
703
+ encoder_hidden_states = encoder_hidden_states + c_gate_mlp * context_ff_output
704
+
705
+ return hidden_states, encoder_hidden_states
706
+
707
+
708
+ class ClipVisionProjection(nn.Module):
709
+ def __init__(self, in_channels, out_channels):
710
+ super().__init__()
711
+ self.up = nn.Linear(in_channels, out_channels * 3)
712
+ self.down = nn.Linear(out_channels * 3, out_channels)
713
+
714
+ def forward(self, x):
715
+ projected_x = self.down(nn.functional.silu(self.up(x)))
716
+ return projected_x
717
+
718
+
719
+ class HunyuanVideoPatchEmbed(nn.Module):
720
+ def __init__(self, patch_size, in_chans, embed_dim):
721
+ super().__init__()
722
+ self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
723
+
724
+
725
+ class HunyuanVideoPatchEmbedForCleanLatents(nn.Module):
726
+ def __init__(self, inner_dim):
727
+ super().__init__()
728
+ self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2))
729
+ self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4))
730
+ self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8))
731
+
732
+ @torch.no_grad()
733
+ def initialize_weight_from_another_conv3d(self, another_layer):
734
+ weight = another_layer.weight.detach().clone()
735
+ bias = another_layer.bias.detach().clone()
736
+
737
+ sd = {
738
+ 'proj.weight': weight.clone(),
739
+ 'proj.bias': bias.clone(),
740
+ 'proj_2x.weight': einops.repeat(weight, 'b c t h w -> b c (t tk) (h hk) (w wk)', tk=2, hk=2, wk=2) / 8.0,
741
+ 'proj_2x.bias': bias.clone(),
742
+ 'proj_4x.weight': einops.repeat(weight, 'b c t h w -> b c (t tk) (h hk) (w wk)', tk=4, hk=4, wk=4) / 64.0,
743
+ 'proj_4x.bias': bias.clone(),
744
+ }
745
+
746
+ sd = {k: v.clone() for k, v in sd.items()}
747
+
748
+ self.load_state_dict(sd)
749
+ return
750
+
751
+
752
+ class HunyuanVideoTransformer3DModelPacked(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
753
+ @register_to_config
754
+ def __init__(
755
+ self,
756
+ in_channels: int = 16,
757
+ out_channels: int = 16,
758
+ num_attention_heads: int = 24,
759
+ attention_head_dim: int = 128,
760
+ num_layers: int = 20,
761
+ num_single_layers: int = 40,
762
+ num_refiner_layers: int = 2,
763
+ mlp_ratio: float = 4.0,
764
+ patch_size: int = 2,
765
+ patch_size_t: int = 1,
766
+ qk_norm: str = "rms_norm",
767
+ guidance_embeds: bool = True,
768
+ text_embed_dim: int = 4096,
769
+ pooled_projection_dim: int = 768,
770
+ rope_theta: float = 256.0,
771
+ rope_axes_dim: Tuple[int] = (16, 56, 56),
772
+ has_image_proj=False,
773
+ image_proj_dim=1152,
774
+ has_clean_x_embedder=False,
775
+ ) -> None:
776
+ super().__init__()
777
+
778
+ inner_dim = num_attention_heads * attention_head_dim
779
+ out_channels = out_channels or in_channels
780
+
781
+ # 1. Latent and condition embedders
782
+ self.x_embedder = HunyuanVideoPatchEmbed((patch_size_t, patch_size, patch_size), in_channels, inner_dim)
783
+ self.context_embedder = HunyuanVideoTokenRefiner(
784
+ text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers
785
+ )
786
+ self.time_text_embed = CombinedTimestepGuidanceTextProjEmbeddings(inner_dim, pooled_projection_dim)
787
+
788
+ self.clean_x_embedder = None
789
+ self.image_projection = None
790
+
791
+ # 2. RoPE
792
+ self.rope = HunyuanVideoRotaryPosEmbed(rope_axes_dim, rope_theta)
793
+
794
+ # 3. Dual stream transformer blocks
795
+ self.transformer_blocks = nn.ModuleList(
796
+ [
797
+ HunyuanVideoTransformerBlock(
798
+ num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
799
+ )
800
+ for _ in range(num_layers)
801
+ ]
802
+ )
803
+
804
+ # 4. Single stream transformer blocks
805
+ self.single_transformer_blocks = nn.ModuleList(
806
+ [
807
+ HunyuanVideoSingleTransformerBlock(
808
+ num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
809
+ )
810
+ for _ in range(num_single_layers)
811
+ ]
812
+ )
813
+
814
+ # 5. Output projection
815
+ self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6)
816
+ self.proj_out = nn.Linear(inner_dim, patch_size_t * patch_size * patch_size * out_channels)
817
+
818
+ self.inner_dim = inner_dim
819
+ self.use_gradient_checkpointing = False
820
+ self.enable_teacache = False
821
+ self.magcache: MagCache = None
822
+
823
+ if has_image_proj:
824
+ self.install_image_projection(image_proj_dim)
825
+
826
+ if has_clean_x_embedder:
827
+ self.install_clean_x_embedder()
828
+
829
+ self.high_quality_fp32_output_for_inference = False
830
+
831
+ def install_image_projection(self, in_channels):
832
+ self.image_projection = ClipVisionProjection(in_channels=in_channels, out_channels=self.inner_dim)
833
+ self.config['has_image_proj'] = True
834
+ self.config['image_proj_dim'] = in_channels
835
+
836
+ def install_clean_x_embedder(self):
837
+ self.clean_x_embedder = HunyuanVideoPatchEmbedForCleanLatents(self.inner_dim)
838
+ self.config['has_clean_x_embedder'] = True
839
+
840
+ def enable_gradient_checkpointing(self):
841
+ self.use_gradient_checkpointing = True
842
+ print('self.use_gradient_checkpointing = True')
843
+
844
+ def disable_gradient_checkpointing(self):
845
+ self.use_gradient_checkpointing = False
846
+ print('self.use_gradient_checkpointing = False')
847
+
848
+ def initialize_teacache(self, enable_teacache=True, num_steps=25, rel_l1_thresh=0.15):
849
+ self.enable_teacache = enable_teacache
850
+ self.cnt = 0
851
+ self.num_steps = num_steps
852
+ self.rel_l1_thresh = rel_l1_thresh # 0.1 for 1.6x speedup, 0.15 for 2.1x speedup
853
+ self.accumulated_rel_l1_distance = 0
854
+ self.previous_modulated_input = None
855
+ self.previous_residual = None
856
+ self.teacache_rescale_func = np.poly1d([7.33226126e+02, -4.01131952e+02, 6.75869174e+01, -3.14987800e+00, 9.61237896e-02])
857
+
858
+ def install_magcache(self, magcache: MagCache):
859
+ self.magcache = magcache
860
+
861
+ def uninstall_magcache(self):
862
+ self.magcache = None
863
+
864
+ def gradient_checkpointing_method(self, block, *args):
865
+ if self.use_gradient_checkpointing:
866
+ result = torch.utils.checkpoint.checkpoint(block, *args, use_reentrant=False)
867
+ else:
868
+ result = block(*args)
869
+ return result
870
+
871
+ def process_input_hidden_states(
872
+ self,
873
+ latents, latent_indices=None,
874
+ clean_latents=None, clean_latent_indices=None,
875
+ clean_latents_2x=None, clean_latent_2x_indices=None,
876
+ clean_latents_4x=None, clean_latent_4x_indices=None
877
+ ):
878
+ hidden_states = self.gradient_checkpointing_method(self.x_embedder.proj, latents)
879
+ B, C, T, H, W = hidden_states.shape
880
+
881
+ if latent_indices is None:
882
+ latent_indices = torch.arange(0, T).unsqueeze(0).expand(B, -1)
883
+
884
+ hidden_states = hidden_states.flatten(2).transpose(1, 2)
885
+
886
+ rope_freqs = self.rope(frame_indices=latent_indices, height=H, width=W, device=hidden_states.device)
887
+ rope_freqs = rope_freqs.flatten(2).transpose(1, 2)
888
+
889
+ if clean_latents is not None and clean_latent_indices is not None:
890
+ clean_latents = clean_latents.to(hidden_states)
891
+ clean_latents = self.gradient_checkpointing_method(self.clean_x_embedder.proj, clean_latents)
892
+ clean_latents = clean_latents.flatten(2).transpose(1, 2)
893
+
894
+ clean_latent_rope_freqs = self.rope(frame_indices=clean_latent_indices, height=H, width=W, device=clean_latents.device)
895
+ clean_latent_rope_freqs = clean_latent_rope_freqs.flatten(2).transpose(1, 2)
896
+
897
+ hidden_states = torch.cat([clean_latents, hidden_states], dim=1)
898
+ rope_freqs = torch.cat([clean_latent_rope_freqs, rope_freqs], dim=1)
899
+
900
+ if clean_latents_2x is not None and clean_latent_2x_indices is not None:
901
+ clean_latents_2x = clean_latents_2x.to(hidden_states)
902
+ clean_latents_2x = pad_for_3d_conv(clean_latents_2x, (2, 4, 4))
903
+ clean_latents_2x = self.gradient_checkpointing_method(self.clean_x_embedder.proj_2x, clean_latents_2x)
904
+ clean_latents_2x = clean_latents_2x.flatten(2).transpose(1, 2)
905
+
906
+ clean_latent_2x_rope_freqs = self.rope(frame_indices=clean_latent_2x_indices, height=H, width=W, device=clean_latents_2x.device)
907
+ clean_latent_2x_rope_freqs = pad_for_3d_conv(clean_latent_2x_rope_freqs, (2, 2, 2))
908
+ clean_latent_2x_rope_freqs = center_down_sample_3d(clean_latent_2x_rope_freqs, (2, 2, 2))
909
+ clean_latent_2x_rope_freqs = clean_latent_2x_rope_freqs.flatten(2).transpose(1, 2)
910
+
911
+ hidden_states = torch.cat([clean_latents_2x, hidden_states], dim=1)
912
+ rope_freqs = torch.cat([clean_latent_2x_rope_freqs, rope_freqs], dim=1)
913
+
914
+ if clean_latents_4x is not None and clean_latent_4x_indices is not None:
915
+ clean_latents_4x = clean_latents_4x.to(hidden_states)
916
+ clean_latents_4x = pad_for_3d_conv(clean_latents_4x, (4, 8, 8))
917
+ clean_latents_4x = self.gradient_checkpointing_method(self.clean_x_embedder.proj_4x, clean_latents_4x)
918
+ clean_latents_4x = clean_latents_4x.flatten(2).transpose(1, 2)
919
+
920
+ clean_latent_4x_rope_freqs = self.rope(frame_indices=clean_latent_4x_indices, height=H, width=W, device=clean_latents_4x.device)
921
+ clean_latent_4x_rope_freqs = pad_for_3d_conv(clean_latent_4x_rope_freqs, (4, 4, 4))
922
+ clean_latent_4x_rope_freqs = center_down_sample_3d(clean_latent_4x_rope_freqs, (4, 4, 4))
923
+ clean_latent_4x_rope_freqs = clean_latent_4x_rope_freqs.flatten(2).transpose(1, 2)
924
+
925
+ hidden_states = torch.cat([clean_latents_4x, hidden_states], dim=1)
926
+ rope_freqs = torch.cat([clean_latent_4x_rope_freqs, rope_freqs], dim=1)
927
+
928
+ return hidden_states, rope_freqs
929
+
930
+ def forward(
931
+ self,
932
+ hidden_states, timestep, encoder_hidden_states, encoder_attention_mask, pooled_projections, guidance,
933
+ latent_indices=None,
934
+ clean_latents=None, clean_latent_indices=None,
935
+ clean_latents_2x=None, clean_latent_2x_indices=None,
936
+ clean_latents_4x=None, clean_latent_4x_indices=None,
937
+ image_embeddings=None,
938
+ attention_kwargs=None, return_dict=True
939
+ ):
940
+
941
+ if attention_kwargs is None:
942
+ attention_kwargs = {}
943
+
944
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
945
+ p, p_t = self.config['patch_size'], self.config['patch_size_t']
946
+ post_patch_num_frames = num_frames // p_t
947
+ post_patch_height = height // p
948
+ post_patch_width = width // p
949
+ original_context_length = post_patch_num_frames * post_patch_height * post_patch_width
950
+
951
+ hidden_states, rope_freqs = self.process_input_hidden_states(hidden_states, latent_indices, clean_latents, clean_latent_indices, clean_latents_2x, clean_latent_2x_indices, clean_latents_4x, clean_latent_4x_indices)
952
+
953
+ temb = self.gradient_checkpointing_method(self.time_text_embed, timestep, guidance, pooled_projections)
954
+ encoder_hidden_states = self.gradient_checkpointing_method(self.context_embedder, encoder_hidden_states, timestep, encoder_attention_mask)
955
+
956
+ if self.image_projection is not None:
957
+ assert image_embeddings is not None, 'You must use image embeddings!'
958
+ extra_encoder_hidden_states = self.gradient_checkpointing_method(self.image_projection, image_embeddings)
959
+ extra_attention_mask = torch.ones((batch_size, extra_encoder_hidden_states.shape[1]), dtype=encoder_attention_mask.dtype, device=encoder_attention_mask.device)
960
+
961
+ # must cat before (not after) encoder_hidden_states, due to attn masking
962
+ encoder_hidden_states = torch.cat([extra_encoder_hidden_states, encoder_hidden_states], dim=1)
963
+ encoder_attention_mask = torch.cat([extra_attention_mask, encoder_attention_mask], dim=1)
964
+
965
+ with torch.no_grad():
966
+ if batch_size == 1:
967
+ # When batch size is 1, we do not need any masks or var-len funcs since cropping is mathematically the same as what we want
968
+ # If they are not the same, then their impls are wrong. Ours are always the correct one.
969
+ text_len = encoder_attention_mask.sum().item()
970
+ encoder_hidden_states = encoder_hidden_states[:, :text_len]
971
+ attention_mask = None, None, None, None
972
+ else:
973
+ img_seq_len = hidden_states.shape[1]
974
+ txt_seq_len = encoder_hidden_states.shape[1]
975
+
976
+ cu_seqlens_q = get_cu_seqlens(encoder_attention_mask, img_seq_len)
977
+ cu_seqlens_kv = cu_seqlens_q
978
+ max_seqlen_q = img_seq_len + txt_seq_len
979
+ max_seqlen_kv = max_seqlen_q
980
+
981
+ attention_mask = cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv
982
+
983
+ if self.enable_teacache:
984
+ modulated_inp = self.transformer_blocks[0].norm1(hidden_states, emb=temb)[0]
985
+
986
+ if self.cnt == 0 or self.cnt == self.num_steps-1:
987
+ should_calc = True
988
+ self.accumulated_rel_l1_distance = 0
989
+ else:
990
+ curr_rel_l1 = ((modulated_inp - self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item()
991
+ self.accumulated_rel_l1_distance += self.teacache_rescale_func(curr_rel_l1)
992
+ should_calc = self.accumulated_rel_l1_distance >= self.rel_l1_thresh
993
+
994
+ if should_calc:
995
+ self.accumulated_rel_l1_distance = 0
996
+
997
+ self.previous_modulated_input = modulated_inp
998
+ self.cnt += 1
999
+
1000
+ if self.cnt == self.num_steps:
1001
+ self.cnt = 0
1002
+
1003
+ if not should_calc:
1004
+ hidden_states = hidden_states + self.previous_residual
1005
+ else:
1006
+ ori_hidden_states = hidden_states.clone()
1007
+
1008
+ hidden_states, encoder_hidden_states = self._run_denoising_layers(hidden_states, encoder_hidden_states, temb, attention_mask, rope_freqs)
1009
+
1010
+ self.previous_residual = hidden_states - ori_hidden_states
1011
+
1012
+ elif self.magcache and self.magcache.is_enabled:
1013
+ if self.magcache.should_skip(hidden_states):
1014
+ hidden_states = self.magcache.estimate_predicted_hidden_states()
1015
+ else:
1016
+ hidden_states, encoder_hidden_states = self._run_denoising_layers(hidden_states, encoder_hidden_states, temb, attention_mask, rope_freqs)
1017
+ self.magcache.update_hidden_states(model_prediction_hidden_states=hidden_states)
1018
+
1019
+ else:
1020
+ hidden_states, encoder_hidden_states = self._run_denoising_layers(hidden_states, encoder_hidden_states, temb, attention_mask, rope_freqs)
1021
+
1022
+ hidden_states = self.gradient_checkpointing_method(self.norm_out, hidden_states, temb)
1023
+
1024
+ hidden_states = hidden_states[:, -original_context_length:, :]
1025
+
1026
+ if self.high_quality_fp32_output_for_inference:
1027
+ hidden_states = hidden_states.to(dtype=torch.float32)
1028
+ if self.proj_out.weight.dtype != torch.float32:
1029
+ self.proj_out.to(dtype=torch.float32)
1030
+
1031
+ hidden_states = self.gradient_checkpointing_method(self.proj_out, hidden_states)
1032
+
1033
+ hidden_states = einops.rearrange(hidden_states, 'b (t h w) (c pt ph pw) -> b c (t pt) (h ph) (w pw)',
1034
+ t=post_patch_num_frames, h=post_patch_height, w=post_patch_width,
1035
+ pt=p_t, ph=p, pw=p)
1036
+
1037
+ if return_dict:
1038
+ return Transformer2DModelOutput(sample=hidden_states)
1039
+
1040
+ return hidden_states,
1041
+
1042
+ def _run_denoising_layers(
1043
+ self,
1044
+ hidden_states: torch.Tensor,
1045
+ encoder_hidden_states: torch.Tensor,
1046
+ temb: torch.Tensor,
1047
+ attention_mask: Optional[Tuple],
1048
+ rope_freqs: Optional[torch.Tensor]
1049
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
1050
+ """
1051
+ Applies the dual-stream and single-stream transformer blocks.
1052
+ """
1053
+ for block_id, block in enumerate(self.transformer_blocks):
1054
+ hidden_states, encoder_hidden_states = self.gradient_checkpointing_method(
1055
+ block, hidden_states, encoder_hidden_states, temb, attention_mask, rope_freqs
1056
+ )
1057
+
1058
+ for block_id, block in enumerate(self.single_transformer_blocks):
1059
+ hidden_states, encoder_hidden_states = self.gradient_checkpointing_method(
1060
+ block, hidden_states, encoder_hidden_states, temb, attention_mask, rope_freqs
1061
+ )
1062
+ return hidden_states, encoder_hidden_states
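For context, a minimal sketch of how the caching hooks defined above might be enabled before sampling; the arguments mirror the signatures in this file and in diffusers_helper/models/mag_cache.py below, while transformer and the resolution/step values are placeholders:

    from diffusers_helper.models.mag_cache import MagCache

    num_steps = 25

    # Option A: TeaCache-style residual reuse (rel_l1_thresh trades quality for speed).
    transformer.initialize_teacache(enable_teacache=True, num_steps=num_steps, rel_l1_thresh=0.15)

    # Option B: MagCache. forward() only consults it when TeaCache is disabled,
    # so turn TeaCache off if both were configured.
    transformer.initialize_teacache(enable_teacache=False)
    magcache = MagCache(model_family="Original", height=640, width=640, num_steps=num_steps,
                        threshold=0.1, max_consectutive_skips=2, retention_ratio=0.25)
    transformer.install_magcache(magcache)

    # ... run the sampler for num_steps diffusion steps ...

    transformer.uninstall_magcache()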
diffusers_helper/models/mag_cache.py ADDED
@@ -0,0 +1,219 @@
1
+ import numpy as np
2
+ import torch
3
+ import os
4
+
5
+ from diffusers_helper.models.mag_cache_ratios import MAG_RATIOS_DB
6
+
7
+
8
+ class MagCache:
9
+ """
10
+ Implements the MagCache algorithm for skipping transformer steps during video generation.
11
+ MagCache: Fast Video Generation with Magnitude-Aware Cache
12
+ Zehong Ma, Longhui Wei, Feng Wang, Shiliang Zhang, Qi Tian
13
+ https://arxiv.org/abs/2506.09045
14
+ https://github.com/Zehong-Ma/MagCache
15
+ PR Demo defaults were threshold=0.1, max_consectutive_skips=3, retention_ratio=0.2
16
+ Changing defaults to threshold=0.1, max_consectutive_skips=2, retention_ratio=0.25 for quality vs speed tradeoff.
17
+ """
18
+
19
+ def __init__(self, model_family, height, width, num_steps, is_enabled=True, is_calibrating = False, threshold=0.1, max_consectutive_skips=2, retention_ratio=0.25):
20
+ self.model_family = model_family
21
+ self.height = height
22
+ self.width = width
23
+ self.num_steps = num_steps
24
+
25
+ self.is_enabled = is_enabled
26
+ self.is_calibrating = is_calibrating
27
+
28
+ self.threshold = threshold
29
+ self.max_consectutive_skips = max_consectutive_skips
30
+ self.retention_ratio = retention_ratio
31
+
32
+ # total cache statistics for all sections in the entire generation
33
+ self.total_cache_requests = 0
34
+ self.total_cache_hits = 0
35
+
36
+ self.mag_ratios = self._determine_mag_ratios()
37
+
38
+ self._init_for_every_section()
39
+
40
+
41
+ def _init_for_every_section(self):
42
+ self.step_index = 0
43
+ self.steps_skipped_list = []
44
+ #Error accumulation state
45
+ self.accumulated_ratio = 1.0
46
+ self.accumulated_steps = 0
47
+ self.accumulated_err = 0
48
+ # Statistics for calibration
49
+ self.norm_ratio, self.norm_std, self.cos_dis = [], [], []
50
+
51
+ self.hidden_states = None
52
+ self.previous_residual = None
53
+
54
+ if self.is_calibrating and self.total_cache_requests > 0:
55
+ print('WARNING: Resetting MagCache calibration stats for a new section. Typically you only want one section per calibration job. Discarding calibration from the previous section.')
56
+
57
+ def should_skip(self, hidden_states):
58
+ """
59
+ Expected to be called once per step during the forward pass, for the number of initialized steps.
60
+ Determines if the current step should be skipped based on estimated accumulated error.
61
+ If the step is skipped, the hidden_states should be replaced with the output of estimate_predicted_hidden_states().
62
+
63
+ Args:
64
+ hidden_states: The current hidden states tensor from the transformer model.
65
+ Returns:
66
+ True if the step should be skipped, False otherwise
67
+ """
68
+ if self.step_index == 0 or self.step_index >= self.num_steps:
69
+ self._init_for_every_section()
70
+ self.total_cache_requests += 1
71
+ self.hidden_states = hidden_states.clone() # Is clone needed?
72
+
73
+ if self.is_calibrating:
74
+ print('######################### Calibrating MagCache #########################')
75
+ return False
76
+
77
+ should_skip_forward = False
78
+ if self.step_index>=int(self.retention_ratio*self.num_steps) and self.step_index>=1: # keep first retention_ratio steps
79
+ cur_mag_ratio = self.mag_ratios[self.step_index]
80
+ self.accumulated_ratio = self.accumulated_ratio*cur_mag_ratio
81
+ cur_skip_err = np.abs(1-self.accumulated_ratio)
82
+ self.accumulated_err += cur_skip_err
83
+ self.accumulated_steps += 1
84
+ # RT_BORG: Per my conversation with Zehong Ma, this 0.06 could potentially be exposed as another tunable param.
85
+ if self.accumulated_err<=self.threshold and self.accumulated_steps<=self.max_consectutive_skips and np.abs(1-cur_mag_ratio)<=0.06:
86
+ should_skip_forward = True
87
+ else:
88
+ self.accumulated_ratio = 1.0
89
+ self.accumulated_steps = 0
90
+ self.accumulated_err = 0
91
+
92
+ if should_skip_forward:
93
+ self.total_cache_hits += 1
94
+ self.steps_skipped_list.append(self.step_index)
95
+ # Increment for next step
96
+ self.step_index += 1
97
+ if self.step_index == self.num_steps:
98
+ self.step_index = 0
99
+
100
+ return should_skip_forward
101
+
102
+ def estimate_predicted_hidden_states(self):
103
+ """
104
+ Should be called if and only if should_skip() returned True for the current step.
105
+ Estimates the hidden states for the current step based on the previous hidden states and residual.
106
+
107
+ Returns:
108
+ The estimated hidden states tensor.
109
+ """
110
+ return self.hidden_states + self.previous_residual
111
+
112
+ def update_hidden_states(self, model_prediction_hidden_states):
113
+ """
114
+ If and only if should_skip() returned False for the current step, the denoising layers should have been run,
115
+ and this function should be called to compute and store the residual for future steps.
116
+
117
+ Args:
118
+ model_prediction_hidden_states: The hidden states tensor output from running the denoising layers.
119
+ """
120
+
121
+ current_residual = model_prediction_hidden_states - self.hidden_states
122
+ if self.is_calibrating:
123
+ self._update_calibration_stats(current_residual)
124
+
125
+ self.previous_residual = current_residual
126
+
127
+ def _update_calibration_stats(self, current_residual):
128
+ if self.step_index >= 1:
129
+ norm_ratio = ((current_residual.norm(dim=-1)/self.previous_residual.norm(dim=-1)).mean()).item()
130
+ norm_std = (current_residual.norm(dim=-1)/self.previous_residual.norm(dim=-1)).std().item()
131
+ cos_dis = (1-torch.nn.functional.cosine_similarity(current_residual, self.previous_residual, dim=-1, eps=1e-8)).mean().item()
132
+ self.norm_ratio.append(round(norm_ratio, 5))
133
+ self.norm_std.append(round(norm_std, 5))
134
+ self.cos_dis.append(round(cos_dis, 5))
135
+ # print(f"time: {self.step_index}, norm_ratio: {norm_ratio}, norm_std: {norm_std}, cos_dis: {cos_dis}")
136
+
137
+ self.step_index += 1
138
+ if self.step_index == self.num_steps:
139
+ print("norm ratio")
140
+ print(self.norm_ratio)
141
+ print("norm std")
142
+ print(self.norm_std)
143
+ print("cos_dis")
144
+ print(self.cos_dis)
145
+ self.step_index = 0
146
+
147
+ def _determine_mag_ratios(self):
148
+ """
149
+ Determines the magnitude ratios by finding the closest resolution and step count
150
+ in the pre-calibrated database.
151
+
152
+ Returns:
153
+ A numpy array of magnitude ratios for the specified configuration, or None if not found.
154
+ """
155
+ if self.is_calibrating:
156
+ return None
157
+ try:
158
+ # Find the closest available resolution group for the given model family
159
+ resolution_groups = MAG_RATIOS_DB[self.model_family]
160
+ available_resolutions = list(resolution_groups.keys())
161
+ if not available_resolutions:
162
+ raise ValueError("No resolutions defined for this model family.")
163
+
164
+ avg_resolution = (self.height + self.width) / 2.0
165
+ closest_resolution_key = min(available_resolutions, key=lambda r: abs(r - avg_resolution))
166
+
167
+ # Find the closest available step count for the given model/resolution
168
+ steps_group = resolution_groups[closest_resolution_key]
169
+ available_steps = list(steps_group.keys())
170
+ if not available_steps:
171
+ raise ValueError(f"No step counts defined for resolution {closest_resolution_key}.")
172
+ closest_steps = min(available_steps, key=lambda x: abs(x - self.num_steps))
173
+ base_ratios = steps_group[closest_steps]
174
+ if closest_steps == self.num_steps:
175
+ print(f"MagCache: Found ratios for {self.model_family}, resolution group {closest_resolution_key} ({self.width}x{self.height}), {self.num_steps} steps.")
176
+ return base_ratios
177
+ print(f"MagCache: Using ratios from {self.model_family}, resolution group {closest_resolution_key} ({self.width}x{self.height}), {closest_steps} steps and interpolating to {self.num_steps} steps.")
178
+ return self._nearest_step_interpolation(base_ratios, self.num_steps)
179
+ except KeyError:
180
+ # This will catch if model_family is not in MAG_RATIOS_DB
181
+ print(f"Warning: MagCache not calibrated for model family '{self.model_family}'. MagCache will not be used.")
182
+ self.is_enabled = False
183
+ except (ValueError, TypeError) as e:
184
+ # This will catch errors if resolution keys or step keys are not numbers, or if groups are empty.
185
+ print(f"Warning: Error processing MagCache DB for model family '{self.model_family}': {e}. MagCache will not be used.")
186
+ self.is_enabled = False
187
+ return None
188
+
189
+ # Nearest interpolation function for MagCache mag_ratios
190
+ @staticmethod
191
+ def _nearest_step_interpolation(src_array, target_length):
192
+ src_length = len(src_array)
193
+ if target_length == 1:
194
+ return np.array([src_array[-1]])
195
+
196
+ scale = (src_length - 1) / (target_length - 1)
197
+ mapped_indices = np.round(np.arange(target_length) * scale).astype(int)
198
+ return src_array[mapped_indices]
199
+
200
+ def append_calibration_to_file(self, output_file):
201
+ """
202
+ Appends tab delimited calibration data (model_family,width,height,norm_ratio) to output_file.
203
+ """
204
+ if not self.is_calibrating or not self.norm_ratio:
205
+ print("Calibration data can only be appended after calibration.")
206
+ return False
207
+ try:
208
+ with open(output_file, "a") as f:
209
+ # Format the data as a string
210
+ calibration_set = f"{self.model_family}\t{self.width}\t{self.height}\t{self.num_steps}"
211
+ # data_string = f"{calibration_set}\t{self.norm_ratio}"
212
+ entry_string = f"{calibration_set}\t{self.num_steps}: np.array([1.0] + {self.norm_ratio}),"
213
+ # Append the data to the file
214
+ f.write(entry_string + "\n")
215
+ print(f"Calibration data appended to {output_file}")
216
+ return True
217
+ except Exception as e:
218
+ print(f"Error appending calibration data: {e}")
219
+ return False
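Pulling the docstring contract above together, a minimal per-step sketch of how a caller (such as the packed transformer's forward pass) is expected to drive MagCache; run_denoising_layers is a placeholder for the model's transformer-block loop:

    # Hypothetical per-step driver following the MagCache contract described above.
    def denoise_step(magcache, hidden_states, run_denoising_layers):
        if magcache.should_skip(hidden_states):
            # Skipped step: reuse the cached residual instead of running the blocks.
            return magcache.estimate_predicted_hidden_states()
        # Full step: run the blocks, then store the new residual for future skips.
        hidden_states = run_denoising_layers(hidden_states)
        magcache.update_hidden_states(model_prediction_hidden_states=hidden_states)
        return hidden_states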
diffusers_helper/models/mag_cache_ratios.py ADDED
@@ -0,0 +1,71 @@
1
+ import numpy as np
2
+
3
+ # Pre-calibrated magnitude ratios for different model families, resolutions, and step counts
4
+ # Format: MAG_RATIOS_DB[model_family][resolution][step_count] = np.array([...])
5
+ # All calibrations performed with FramePackStudio v0.4 with default settings and seed 31337
6
+ MAG_RATIOS_DB = {
7
+ "Original": {
8
+ 768: {
9
+ 25: np.array([1.0] + [1.30469, 1.22656, 1.03906, 1.02344, 1.03906, 1.01562, 1.03906, 1.05469, 1.02344, 1.03906, 0.99609, 1.02344, 1.03125, 1.01562, 1.02344, 1.00781, 1.04688, 0.99219, 1.00781, 1.00781, 0.98828, 0.94141, 0.93359, 0.78906]),
10
+ 50: np.array([1.0] + [1.30469, 0.99609, 1.16406, 1.0625, 1.01562, 1.02344, 1.01562, 1.0, 1.04688, 0.99219, 1.00781, 1.00781, 1.00781, 1.03125, 1.02344, 1.03125, 1.00781, 1.01562, 0.99219, 1.04688, 1.00781, 0.98828, 1.01562, 1.00781, 1.01562, 1.01562, 1.00781, 1.00781, 1.00781, 1.01562, 1.00781, 1.0, 1.03125, 1.01562, 0.97266, 1.01562, 1.01562, 0.98828, 1.01562, 0.98828, 0.98828, 1.0, 0.98438, 0.95703, 0.95312, 0.98047, 0.91016, 0.85938, 0.78125]),
11
+ 75: np.array([1.0] + [1.01562, 1.27344, 1.0, 1.15625, 1.0625, 1.00781, 1.02344, 1.0, 1.02344, 1.0, 1.02344, 1.0, 1.0, 1.04688, 0.99609, 1.0, 1.00781, 1.01562, 1.00781, 1.0, 1.03125, 1.02344, 1.00781, 1.02344, 1.00781, 1.00781, 1.00781, 0.99219, 1.05469, 0.99609, 1.00781, 0.99219, 0.99609, 1.00781, 1.03125, 0.98438, 1.01562, 1.02344, 0.98828, 1.03906, 0.98828, 0.99219, 1.02344, 0.99219, 1.00781, 1.01562, 1.01562, 0.98438, 1.03125, 0.99609, 1.01562, 0.98828, 1.01562, 0.98828, 1.04688, 0.96484, 0.99219, 1.03125, 0.98047, 0.99219, 1.0, 0.99609, 0.98828, 0.98438, 0.98828, 0.97266, 0.99609, 0.96484, 0.97266, 0.94531, 0.91406, 0.90234, 0.85938, 0.76172]),
12
+ },
13
+ 640: {
14
+ 25: np.array([1.0] + [1.30469, 1.22656, 1.05469, 1.02344, 1.03906, 1.02344, 1.03906, 1.05469, 1.02344, 1.03906, 0.99609, 1.02344, 1.03125, 1.01562, 1.02344, 1.00781, 1.03906, 0.98828, 1.0, 1.0, 0.98828, 0.94531, 0.93359, 0.78516]),
15
+ 50: np.array([1.0] + [1.28906, 1.0, 1.17188, 1.0625, 1.02344, 1.02344, 1.02344, 1.0, 1.04688, 0.99219, 1.00781, 1.01562, 1.00781, 1.03125, 1.02344, 1.03125, 1.00781, 1.01562, 0.99219, 1.04688, 1.00781, 0.98828, 1.01562, 1.00781, 1.01562, 1.01562, 1.00781, 1.00781, 1.00781, 1.01562, 1.00781, 1.0, 1.03125, 1.01562, 0.97266, 1.01562, 1.01562, 0.98828, 1.01562, 0.98828, 0.98828, 1.0, 0.98438, 0.95703, 0.95312, 0.97656, 0.91016, 0.85547, 0.76953]),
16
+ 75: np.array([1.0] + [1.00781, 1.30469, 1.0, 1.15625, 1.05469, 1.01562, 1.01562, 1.0, 1.01562, 1.0, 1.02344, 0.99609, 1.0, 1.04688, 0.99219, 1.0, 1.00781, 1.01562, 1.00781, 1.0, 1.03125, 1.02344, 1.00781, 1.02344, 1.00781, 1.00781, 1.00781, 0.99219, 1.05469, 0.99219, 1.00781, 0.99219, 0.99609, 1.00781, 1.03906, 0.98828, 1.01562, 1.02344, 0.98828, 1.03906, 0.98828, 0.99219, 1.02344, 0.99219, 1.00781, 1.01562, 1.01562, 0.98828, 1.03906, 0.99609, 1.01562, 0.98828, 1.01562, 0.98828, 1.04688, 0.96484, 0.99219, 1.03125, 0.98047, 0.99219, 1.00781, 0.99609, 0.98828, 0.98438, 0.98828, 0.97266, 0.99609, 0.96484, 0.97266, 0.94922, 0.91797, 0.90625, 0.86328, 0.75781]),
17
+ },
18
+ 512: {
19
+ 25: np.array([1.0] + [1.32031, 1.21875, 1.03906, 1.02344, 1.03906, 1.01562, 1.03906, 1.05469, 1.02344, 1.03906, 0.99609, 1.02344, 1.03125, 1.01562, 1.02344, 1.00781, 1.04688, 0.98828, 1.0, 1.00781, 0.98828, 0.94141, 0.9375, 0.78516]),
20
+ 50: np.array([1.0] + [1.32031, 0.99609, 1.15625, 1.0625, 1.01562, 1.02344, 1.02344, 1.0, 1.04688, 0.99219, 1.00781, 1.00781, 1.00781, 1.03125, 1.02344, 1.03125, 1.00781, 1.01562, 0.98828, 1.04688, 1.00781, 0.98828, 1.01562, 1.00781, 1.01562, 1.01562, 1.00781, 1.00781, 1.00781, 1.01562, 1.00781, 1.0, 1.03125, 1.01562, 0.97266, 1.01562, 1.01562, 0.98828, 1.01562, 0.98828, 0.98828, 1.0, 0.98438, 0.95703, 0.95312, 0.98047, 0.91016, 0.85938, 0.77734]),
21
+ 75: np.array([1.0] + [1.02344, 1.28906, 1.0, 1.15625, 1.0625, 1.01562, 1.01562, 1.0, 1.02344, 1.0, 1.02344, 1.0, 1.0, 1.04688, 0.99609, 1.0, 1.00781, 1.01562, 1.00781, 1.0, 1.03125, 1.02344, 1.00781, 1.02344, 1.00781, 1.00781, 1.00781, 0.99219, 1.05469, 0.99609, 1.00781, 0.99219, 0.99609, 1.00781, 1.03125, 0.98828, 1.01562, 1.02344, 0.99219, 1.03906, 0.98828, 0.99219, 1.02344, 0.99219, 1.00781, 1.01562, 1.01562, 0.98828, 1.03125, 0.99609, 1.01562, 0.98828, 1.01562, 0.98828, 1.04688, 0.96484, 0.99219, 1.03125, 0.98047, 0.99219, 1.00781, 0.99609, 0.98828, 0.98438, 0.98828, 0.97266, 1.0, 0.96484, 0.97266, 0.94922, 0.91797, 0.90625, 0.86328, 0.75781]),
22
+ },
23
+ 384: {
24
+ 25: np.array([1.0] + [1.58594, 1.0625, 1.03906, 1.02344, 1.0625, 1.04688, 1.04688, 1.03125, 1.02344, 1.01562, 1.00781, 1.01562, 1.01562, 0.99219, 1.07031, 0.96094, 0.96484, 1.03125, 0.96875, 0.94141, 0.97266, 0.92188, 0.88672, 0.75]),
25
+ 50: np.array([1.0] + [1.29688, 1.21875, 1.02344, 1.03906, 1.0, 1.03906, 1.00781, 1.02344, 1.05469, 1.00781, 1.03125, 1.01562, 1.04688, 1.00781, 0.98828, 1.03906, 0.99609, 1.03125, 1.03125, 0.98438, 1.01562, 0.99609, 1.01562, 1.0, 1.01562, 1.0, 1.01562, 0.98047, 1.02344, 1.04688, 0.97266, 0.98828, 0.97656, 0.98828, 1.05469, 0.97656, 0.98828, 0.98047, 0.98438, 0.95703, 1.00781, 0.96484, 0.97656, 0.94531, 0.94141, 0.94531, 0.875, 0.85547, 0.79688]),
26
+ 75: np.array([1.0] + [1.29688, 1.14844, 1.07031, 1.01562, 1.02344, 1.02344, 0.99609, 1.04688, 0.99219, 1.00781, 1.00781, 1.00781, 1.03125, 1.02344, 1.00781, 1.03125, 1.00781, 1.00781, 1.04688, 0.99219, 1.00781, 0.99219, 1.00781, 1.03125, 0.98438, 1.01562, 1.02344, 0.98828, 1.01562, 1.01562, 1.0, 1.0, 1.00781, 0.99219, 1.01562, 1.0, 0.99609, 1.03125, 0.98828, 1.02344, 0.99609, 0.97656, 1.01562, 1.00781, 1.03906, 0.96484, 1.03125, 0.96484, 1.02344, 0.98438, 0.96094, 1.03125, 0.98828, 1.01562, 0.97266, 1.02344, 0.97656, 1.0, 0.98438, 0.95703, 1.02344, 0.96094, 0.99609, 0.99609, 0.9375, 0.98438, 0.94141, 0.97266, 0.96875, 0.89844, 0.95703, 0.87109, 0.86328, 0.85547]),
27
+ },
28
+ 256: {
29
+ 25: np.array([1.0] + [1.59375, 1.10156, 1.08594, 1.05469, 1.03906, 1.03125, 1.03125, 1.02344, 1.01562, 1.02344, 0.98438, 1.0625, 0.96875, 1.00781, 0.98438, 1.00781, 0.92969, 0.97656, 0.99609, 0.91406, 0.94922, 0.88672, 0.86328, 0.75391]),
30
+ 50: np.array([1.0] + [1.46875, 1.10156, 1.04688, 1.03906, 1.02344, 1.0625, 1.03125, 1.02344, 1.03906, 1.0, 1.01562, 1.01562, 1.03125, 0.99609, 1.01562, 1.00781, 0.99609, 1.02344, 1.01562, 1.00781, 1.00781, 0.98047, 1.02344, 1.04688, 0.97266, 0.99609, 1.0, 1.00781, 0.98047, 1.00781, 0.98047, 1.02344, 0.96094, 0.96875, 1.03125, 0.94531, 0.98047, 1.01562, 0.96484, 0.94531, 0.99609, 0.95312, 0.96484, 0.91406, 0.92969, 0.92969, 0.88672, 0.85156, 0.89062]),
31
+ 75: np.array([1.0] + [1.25781, 1.23438, 1.04688, 1.03906, 1.0, 1.03906, 1.02344, 1.00781, 1.05469, 1.00781, 1.03125, 1.01562, 1.04688, 0.99219, 1.00781, 1.0, 1.03906, 0.99609, 1.03125, 0.98828, 1.00781, 1.00781, 1.02344, 0.99219, 1.00781, 1.00781, 1.0, 0.99609, 1.03125, 0.99609, 1.01562, 0.99609, 0.97656, 1.03906, 0.98438, 1.03906, 0.96484, 1.01562, 0.98438, 1.02344, 0.95312, 1.03906, 0.98047, 1.02344, 0.98047, 1.0, 0.97656, 1.03906, 0.94922, 1.01562, 0.97266, 1.01562, 0.98828, 0.97266, 1.01562, 0.97656, 1.00781, 0.9375, 1.0, 0.97266, 1.00781, 0.96875, 0.97266, 0.96875, 0.93359, 0.98047, 0.95703, 0.94922, 0.94922, 0.92578, 0.94531, 0.85938, 0.91797, 0.94531]),
32
+ },
33
+ 128: {
34
+ 25: np.array([1.0] + [1.63281, 1.0625, 1.14062, 1.04688, 1.03906, 1.03125, 1.03125, 1.02344, 0.99219, 1.03125, 0.96484, 1.02344, 0.97266, 0.98438, 0.97656, 0.96875, 0.95312, 0.95312, 0.97656, 0.92188, 0.9375, 0.87891, 0.85156, 0.82812]),
35
+ 50: np.array([1.0] + [1.5625, 1.05469, 1.03906, 1.03125, 1.07031, 1.03906, 1.05469, 0.99609, 1.03906, 1.0, 1.05469, 0.97656, 1.01562, 1.01562, 1.0, 1.03125, 1.01562, 0.97656, 0.99609, 1.03906, 0.98828, 0.98047, 1.03125, 0.99609, 0.97266, 1.0, 0.99609, 0.98438, 0.97266, 1.00781, 0.98828, 0.98438, 0.97656, 0.98047, 0.99609, 0.95703, 1.00781, 0.96484, 0.99219, 0.92969, 0.98828, 0.94922, 0.94531, 0.92969, 0.92578, 0.92188, 0.91797, 0.90625, 1.00781]),
36
+ 75: np.array([1.0] + [1.47656, 1.08594, 1.02344, 1.0625, 1.0, 1.02344, 1.0625, 1.03906, 1.01562, 1.0, 1.04688, 1.0, 1.01562, 1.00781, 1.01562, 1.01562, 1.01562, 1.00781, 1.00781, 1.00781, 1.01562, 1.01562, 0.98438, 1.03125, 0.99609, 1.01562, 0.98828, 1.01562, 1.00781, 1.01562, 0.97656, 1.00781, 0.98438, 1.03906, 0.97656, 1.01562, 0.94531, 1.03125, 1.00781, 0.98438, 1.02344, 0.97656, 1.00781, 0.95703, 1.01562, 0.97656, 1.02344, 0.97266, 0.96484, 1.02344, 0.96094, 0.99609, 0.99609, 0.96094, 1.0, 1.00781, 0.97266, 0.98828, 0.96875, 0.96484, 0.98828, 0.95703, 0.99219, 0.97266, 0.89844, 1.0, 0.96094, 0.92578, 0.95703, 0.9375, 0.91016, 0.97266, 0.96875, 1.07812]),
37
+ },
38
+ },
39
+ "F1": {
40
+ 768: {
41
+ 25: np.array([1.0] + [1.27344, 1.08594, 1.03125, 1.00781, 1.00781, 1.00781, 1.03125, 1.03906, 1.00781, 1.03125, 0.98828, 1.01562, 1.00781, 1.01562, 1.00781, 0.98438, 1.04688, 0.98438, 0.96875, 1.03125, 0.97266, 0.92188, 0.95703, 0.77734]),
42
+ 50: np.array([1.0] + [1.27344, 1.0, 1.07031, 1.01562, 1.0, 1.02344, 1.0, 1.01562, 1.02344, 0.98828, 1.00781, 1.00781, 1.0, 1.02344, 1.00781, 1.03125, 1.0, 1.00781, 0.97656, 1.05469, 1.00781, 0.98047, 1.01562, 1.0, 1.00781, 1.00781, 0.99609, 1.02344, 1.00781, 0.99609, 1.00781, 0.98047, 1.04688, 1.00781, 0.95312, 1.03125, 0.99609, 0.97266, 1.02344, 1.00781, 0.95312, 1.02344, 0.98828, 0.93359, 0.97656, 0.97656, 0.92188, 0.84375, 0.76562]),
43
+ 75: np.array([1.0] + [1.0, 1.26562, 1.00781, 1.07812, 1.0, 1.00781, 1.01562, 1.0, 1.00781, 1.0, 1.00781, 1.00781, 1.0, 1.02344, 0.98438, 1.0, 1.00781, 1.00781, 1.00781, 1.0, 1.03125, 1.01562, 1.00781, 1.02344, 1.0, 0.99219, 1.01562, 0.98047, 1.05469, 1.0, 1.01562, 0.99219, 0.98828, 1.01562, 1.03125, 0.97266, 1.00781, 1.03125, 0.97266, 1.04688, 0.97656, 0.99609, 1.03125, 0.98047, 0.99609, 1.02344, 1.0, 0.96875, 1.0625, 0.98828, 1.00781, 0.99609, 1.0, 0.98828, 1.04688, 0.94922, 0.97266, 1.0625, 0.96094, 1.00781, 0.96875, 1.01562, 0.98828, 0.99609, 0.95703, 0.96875, 1.02344, 0.96875, 0.96484, 0.95312, 0.89844, 0.90234, 0.86719, 0.76562]),
44
+ },
45
+ 640: {
46
+ 25: np.array([1.0] + [1.27344, 1.07031, 1.03906, 1.00781, 1.00781, 1.00781, 1.03125, 1.04688, 1.00781, 1.03125, 0.99219, 1.01562, 1.01562, 1.01562, 1.00781, 0.98438, 1.05469, 0.98438, 0.96875, 1.03125, 0.97266, 0.92578, 0.95703, 0.77734]),
47
+ 50: np.array([1.0] + [1.27344, 1.0, 1.07812, 1.01562, 1.0, 1.01562, 1.00781, 1.00781, 1.02344, 0.98828, 1.00781, 1.00781, 1.00781, 1.02344, 1.01562, 1.03125, 1.0, 1.00781, 0.98047, 1.05469, 1.00781, 0.98047, 1.01562, 1.0, 1.00781, 1.00781, 0.99609, 1.02344, 1.01562, 0.99609, 1.00781, 0.98047, 1.04688, 1.00781, 0.95312, 1.03125, 0.99609, 0.97266, 1.02344, 1.00781, 0.95312, 1.02344, 0.98828, 0.93359, 0.97656, 0.97656, 0.92188, 0.84375, 0.76953]),
48
+ 75: np.array([1.0] + [1.0, 1.27344, 1.0, 1.07031, 1.01562, 0.99609, 1.00781, 1.0, 1.00781, 1.0, 1.00781, 1.00781, 1.0, 1.02344, 0.98438, 1.0, 1.00781, 1.00781, 1.00781, 1.0, 1.02344, 1.01562, 1.00781, 1.02344, 1.0, 0.99219, 1.01562, 0.98047, 1.05469, 1.0, 1.01562, 0.99219, 0.98438, 1.01562, 1.03125, 0.96875, 1.00781, 1.03125, 0.97266, 1.04688, 0.97656, 0.99609, 1.03125, 0.98047, 0.99609, 1.02344, 1.0, 0.96875, 1.0625, 0.98828, 1.00781, 1.0, 1.0, 0.98828, 1.04688, 0.94922, 0.97266, 1.0625, 0.96094, 1.00781, 0.96875, 1.01562, 0.98828, 0.99609, 0.95703, 0.96875, 1.02344, 0.96484, 0.96484, 0.95312, 0.89844, 0.90234, 0.87109, 0.76953]),
49
+ },
50
+ 512: {
51
+ 25: np.array([1.0] + [1.28125, 1.08594, 1.02344, 1.01562, 1.00781, 1.00781, 1.03125, 1.03906, 1.00781, 1.03125, 0.98828, 1.01562, 1.00781, 1.01562, 1.00781, 0.98438, 1.04688, 0.98438, 0.96875, 1.03125, 0.97656, 0.92188, 0.96094, 0.77734]),
52
+ 50: np.array([1.0] + [1.28125, 1.00781, 1.08594, 1.0, 1.01562, 1.01562, 1.00781, 1.00781, 1.02344, 0.98438, 1.00781, 1.00781, 1.00781, 1.02344, 1.01562, 1.03125, 1.0, 1.00781, 0.97656, 1.05469, 1.00781, 0.98047, 1.01562, 1.0, 1.00781, 1.00781, 0.99609, 1.02344, 1.00781, 0.99609, 1.00781, 0.98047, 1.04688, 1.00781, 0.95312, 1.03125, 0.99609, 0.97266, 1.02344, 1.00781, 0.94922, 1.02344, 0.98828, 0.93359, 0.97656, 0.97656, 0.92188, 0.83984, 0.76953]),
53
+ 75: np.array([1.0] + [1.00781, 1.27344, 1.00781, 1.07812, 1.0, 1.00781, 1.00781, 0.99609, 1.01562, 0.99609, 1.00781, 1.00781, 1.0, 1.02344, 0.98438, 1.0, 1.00781, 1.00781, 1.00781, 1.0, 1.02344, 1.01562, 1.00781, 1.02344, 1.0, 0.98828, 1.01562, 0.98047, 1.05469, 1.0, 1.01562, 0.99219, 0.98438, 1.00781, 1.03125, 0.96875, 1.00781, 1.03125, 0.97266, 1.04688, 0.97656, 0.99609, 1.03125, 0.98047, 0.99609, 1.02344, 1.0, 0.96484, 1.0625, 0.98828, 1.0, 0.99609, 1.0, 0.98828, 1.04688, 0.94922, 0.97266, 1.0625, 0.95703, 1.00781, 0.96484, 1.02344, 0.98828, 0.99609, 0.95703, 0.96875, 1.03125, 0.96875, 0.96484, 0.95703, 0.89844, 0.90234, 0.87109, 0.76953]),
54
+ },
55
+ 384: {
56
+ 25: np.array([1.0] + [1.36719, 1.03125, 1.02344, 1.01562, 1.04688, 1.03125, 1.04688, 1.02344, 1.00781, 0.99609, 1.01562, 0.99219, 1.00781, 0.97266, 1.07812, 0.95703, 0.9375, 1.04688, 0.98828, 0.89844, 1.00781, 0.92188, 0.89844, 0.72656]),
57
+ 50: np.array([1.0] + [1.27344, 1.08594, 1.01562, 1.01562, 1.00781, 1.00781, 1.00781, 1.00781, 1.03906, 1.00781, 1.02344, 1.00781, 1.03125, 1.01562, 0.98047, 1.04688, 0.98438, 1.02344, 1.02344, 0.97656, 1.03125, 0.98828, 1.00781, 0.98828, 1.00781, 1.0, 0.99609, 0.97656, 1.03125, 1.04688, 0.97656, 0.98047, 0.97266, 0.96094, 1.09375, 0.95703, 0.98438, 1.0, 0.96094, 0.93359, 1.03125, 0.97266, 1.0, 0.92188, 0.93359, 0.96484, 0.85156, 0.84375, 0.79688]),
58
+ 75: np.array([1.0] + [1.26562, 1.08594, 1.00781, 1.00781, 1.01562, 1.00781, 1.00781, 1.03125, 0.98438, 1.00781, 1.00781, 1.00781, 1.02344, 1.01562, 1.00781, 1.02344, 0.98828, 1.01562, 1.03125, 1.0, 1.01562, 0.98047, 1.01562, 1.03125, 0.96875, 1.00781, 1.03125, 0.97266, 1.0, 1.02344, 1.00781, 1.0, 1.00781, 0.97266, 1.01562, 1.00781, 0.97656, 1.0625, 0.97656, 1.02344, 0.98828, 0.96484, 1.02344, 1.01562, 1.04688, 0.95312, 1.03125, 0.98047, 1.01562, 0.97266, 0.94922, 1.0625, 0.96484, 1.02344, 0.98438, 1.02344, 0.98047, 0.97266, 0.99219, 0.92969, 1.07031, 0.96094, 0.98047, 0.98438, 0.94531, 0.98828, 0.9375, 0.97266, 0.98828, 0.86719, 0.98047, 0.84766, 0.86328, 0.86719]),
59
+ },
60
+ 256: {
61
+ 25: np.array([1.0] + [1.38281, 1.04688, 1.05469, 1.03906, 1.03906, 1.01562, 1.0, 1.03906, 1.00781, 1.03125, 0.96094, 1.08594, 0.96094, 1.00781, 0.98438, 1.02344, 0.91016, 0.99609, 1.0, 0.90234, 0.97266, 0.87109, 0.85547, 0.71875]),
62
+ 50: np.array([1.0] + [1.375, 1.02344, 1.02344, 1.01562, 1.00781, 1.04688, 1.03125, 1.00781, 1.03125, 1.00781, 1.0, 1.01562, 1.01562, 0.98438, 1.03125, 1.00781, 0.98438, 1.02344, 1.01562, 1.02344, 0.98047, 0.97656, 1.03125, 1.04688, 0.97656, 0.98047, 0.99219, 1.01562, 0.99609, 0.98828, 0.99219, 1.03125, 0.96875, 0.94141, 1.05469, 0.94531, 0.95312, 1.04688, 0.94141, 0.96094, 1.00781, 0.96094, 0.95312, 0.91016, 0.91797, 0.93359, 0.89062, 0.8125, 0.83984]),
63
+ 75: np.array([1.0] + [1.27344, 1.09375, 1.00781, 1.01562, 1.00781, 1.00781, 1.00781, 1.00781, 1.03906, 1.00781, 1.02344, 1.00781, 1.03125, 1.0, 1.00781, 1.0, 1.03125, 0.98828, 1.02344, 0.97266, 1.0, 1.02344, 1.03125, 0.98047, 0.99609, 1.02344, 0.99219, 0.98047, 1.0625, 0.99219, 1.00781, 0.98828, 0.96875, 1.0625, 0.98047, 1.04688, 0.95312, 1.03125, 0.98047, 1.03125, 0.94922, 1.03125, 0.99609, 1.03125, 0.95703, 1.0, 0.98438, 1.03906, 0.9375, 1.01562, 0.96094, 1.03125, 0.98828, 0.98047, 1.00781, 0.99219, 1.0, 0.91016, 1.01562, 0.97266, 1.00781, 0.98047, 0.98438, 0.97266, 0.89062, 1.00781, 0.95703, 0.95312, 0.94141, 0.9375, 0.93359, 0.82031, 0.91016, 0.87891]),
64
+ },
65
+ 128: {
66
+ 25: np.array([1.0] + [1.42188, 1.03125, 1.0625, 1.04688, 1.02344, 1.03125, 1.03125, 1.02344, 0.99219, 1.02344, 0.94531, 1.04688, 0.94922, 1.01562, 0.98047, 0.96484, 0.93359, 0.96484, 0.98438, 0.91406, 0.97266, 0.87891, 0.85547, 0.83203]),
67
+ 50: np.array([1.0] + [1.375, 1.03125, 1.02344, 1.01562, 1.04688, 1.01562, 1.04688, 1.0, 1.04688, 0.98438, 1.05469, 0.97266, 1.01562, 1.01562, 0.98828, 1.03906, 1.01562, 0.97656, 0.98047, 1.03906, 0.97266, 0.97266, 1.0625, 0.99219, 0.97656, 0.97266, 1.02344, 0.99609, 0.94141, 1.03906, 0.97266, 0.99219, 0.96875, 0.96484, 1.00781, 0.95312, 1.03906, 0.94922, 0.99609, 0.91797, 1.00781, 0.96484, 0.9375, 0.93359, 0.91797, 0.92969, 0.91406, 0.90625, 0.92188]),
68
+ 75: np.array([1.0] + [1.36719, 1.02344, 1.01562, 1.03906, 0.99219, 1.00781, 1.03906, 1.03125, 0.99219, 0.99609, 1.05469, 0.99609, 1.01562, 1.00781, 1.00781, 1.00781, 1.0, 1.02344, 1.01562, 1.01562, 1.00781, 1.0, 0.96875, 1.05469, 0.99219, 1.00781, 1.0, 1.0, 1.01562, 1.00781, 0.96875, 1.02344, 0.94922, 1.07812, 0.94922, 1.03125, 0.92578, 1.05469, 0.97266, 0.98438, 1.04688, 0.98438, 1.00781, 0.95312, 1.02344, 0.94922, 1.04688, 0.96875, 0.96484, 1.03906, 0.93359, 1.00781, 1.00781, 0.95312, 1.00781, 1.0, 0.97656, 1.0, 0.94922, 0.96094, 1.02344, 0.92969, 1.02344, 0.96094, 0.88281, 1.03125, 0.94141, 0.91797, 0.98438, 0.92578, 0.90234, 0.99219, 0.92188, 0.98438]),
69
+ },
70
+ }
71
+ }
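These nested tables appear to map a model family ("Original" / "F1") to a resolution bucket (768 down to 128), then to a step count (25 / 50 / 75), and finally to one magnitude ratio per denoising step, which the MagCache logic in mag_cache.py uses to estimate when a transformer step can be skipped. A minimal lookup sketch under those assumptions; get_mag_ratios and its nearest-bucket / resampling fallback are illustrative, not the repo's actual consumer:

import numpy as np

def get_mag_ratios(mag_ratios, model_family, resolution, num_steps):
    # Snap the requested resolution to the closest calibrated bucket.
    buckets = sorted(mag_ratios[model_family].keys())
    bucket = min(buckets, key=lambda b: abs(b - resolution))
    per_steps = mag_ratios[model_family][bucket]
    if num_steps in per_steps:
        return per_steps[num_steps]
    # Otherwise stretch the closest calibrated curve onto num_steps points.
    base = per_steps[min(per_steps, key=lambda s: abs(s - num_steps))]
    xs = np.linspace(0.0, 1.0, num=len(base))
    return np.interp(np.linspace(0.0, 1.0, num=num_steps), xs, base)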
diffusers_helper/pipelines/k_diffusion_hunyuan.py ADDED
@@ -0,0 +1,120 @@
1
+ import torch
2
+ import math
3
+
4
+ from diffusers_helper.k_diffusion.uni_pc_fm import sample_unipc
5
+ from diffusers_helper.k_diffusion.wrapper import fm_wrapper
6
+ from diffusers_helper.utils import repeat_to_batch_size
7
+
8
+
9
+ def flux_time_shift(t, mu=1.15, sigma=1.0):
10
+ return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
11
+
12
+
13
+ def calculate_flux_mu(context_length, x1=256, y1=0.5, x2=4096, y2=1.15, exp_max=7.0):
14
+ k = (y2 - y1) / (x2 - x1)
15
+ b = y1 - k * x1
16
+ mu = k * context_length + b
17
+ mu = min(mu, math.log(exp_max))
18
+ return mu
19
+
20
+
21
+ def get_flux_sigmas_from_mu(n, mu):
22
+ sigmas = torch.linspace(1, 0, steps=n + 1)
23
+ sigmas = flux_time_shift(sigmas, mu=mu)
24
+ return sigmas
25
+
26
+
27
+ @torch.inference_mode()
28
+ def sample_hunyuan(
29
+ transformer,
30
+ sampler='unipc',
31
+ initial_latent=None,
32
+ concat_latent=None,
33
+ strength=1.0,
34
+ width=512,
35
+ height=512,
36
+ frames=16,
37
+ real_guidance_scale=1.0,
38
+ distilled_guidance_scale=6.0,
39
+ guidance_rescale=0.0,
40
+ shift=None,
41
+ num_inference_steps=25,
42
+ batch_size=None,
43
+ generator=None,
44
+ prompt_embeds=None,
45
+ prompt_embeds_mask=None,
46
+ prompt_poolers=None,
47
+ negative_prompt_embeds=None,
48
+ negative_prompt_embeds_mask=None,
49
+ negative_prompt_poolers=None,
50
+ dtype=torch.bfloat16,
51
+ device=None,
52
+ negative_kwargs=None,
53
+ callback=None,
54
+ **kwargs,
55
+ ):
56
+ device = device or transformer.device
57
+
58
+ if batch_size is None:
59
+ batch_size = int(prompt_embeds.shape[0])
60
+
61
+ latents = torch.randn((batch_size, 16, (frames + 3) // 4, height // 8, width // 8), generator=generator, device=generator.device).to(device=device, dtype=torch.float32)
62
+
63
+ B, C, T, H, W = latents.shape
64
+ seq_length = T * H * W // 4
65
+
66
+ if shift is None:
67
+ mu = calculate_flux_mu(seq_length, exp_max=7.0)
68
+ else:
69
+ mu = math.log(shift)
70
+
71
+ sigmas = get_flux_sigmas_from_mu(num_inference_steps, mu).to(device)
72
+
73
+ k_model = fm_wrapper(transformer)
74
+
75
+ if initial_latent is not None:
76
+ sigmas = sigmas * strength
77
+ first_sigma = sigmas[0].to(device=device, dtype=torch.float32)
78
+ initial_latent = initial_latent.to(device=device, dtype=torch.float32)
79
+ latents = initial_latent.float() * (1.0 - first_sigma) + latents.float() * first_sigma
80
+
81
+ if concat_latent is not None:
82
+ concat_latent = concat_latent.to(latents)
83
+
84
+ distilled_guidance = torch.tensor([distilled_guidance_scale * 1000.0] * batch_size).to(device=device, dtype=dtype)
85
+
86
+ prompt_embeds = repeat_to_batch_size(prompt_embeds, batch_size)
87
+ prompt_embeds_mask = repeat_to_batch_size(prompt_embeds_mask, batch_size)
88
+ prompt_poolers = repeat_to_batch_size(prompt_poolers, batch_size)
89
+ negative_prompt_embeds = repeat_to_batch_size(negative_prompt_embeds, batch_size)
90
+ negative_prompt_embeds_mask = repeat_to_batch_size(negative_prompt_embeds_mask, batch_size)
91
+ negative_prompt_poolers = repeat_to_batch_size(negative_prompt_poolers, batch_size)
92
+ concat_latent = repeat_to_batch_size(concat_latent, batch_size)
93
+
94
+ sampler_kwargs = dict(
95
+ dtype=dtype,
96
+ cfg_scale=real_guidance_scale,
97
+ cfg_rescale=guidance_rescale,
98
+ concat_latent=concat_latent,
99
+ positive=dict(
100
+ pooled_projections=prompt_poolers,
101
+ encoder_hidden_states=prompt_embeds,
102
+ encoder_attention_mask=prompt_embeds_mask,
103
+ guidance=distilled_guidance,
104
+ **kwargs,
105
+ ),
106
+ negative=dict(
107
+ pooled_projections=negative_prompt_poolers,
108
+ encoder_hidden_states=negative_prompt_embeds,
109
+ encoder_attention_mask=negative_prompt_embeds_mask,
110
+ guidance=distilled_guidance,
111
+ **(kwargs if negative_kwargs is None else {**kwargs, **negative_kwargs}),
112
+ )
113
+ )
114
+
115
+ if sampler == 'unipc':
116
+ results = sample_unipc(k_model, latents, sigmas, extra_args=sampler_kwargs, disable=False, callback=callback)
117
+ else:
118
+ raise NotImplementedError(f'Sampler {sampler} is not supported.')
119
+
120
+ return results
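sample_hunyuan builds its noise schedule with a flux-style time shift: calculate_flux_mu interpolates mu linearly from the token count (256 tokens -> 0.5, 4096 tokens -> 1.15) and clamps it at log(exp_max), then get_flux_sigmas_from_mu warps a linear 1 -> 0 grid through flux_time_shift. A small standalone check of those helpers; the shapes mirror a 512x512, 9-frame request at 25 steps, and the numbers are only illustrative:

from diffusers_helper.pipelines.k_diffusion_hunyuan import calculate_flux_mu, get_flux_sigmas_from_mu

# Token count as computed inside sample_hunyuan: T * H * W // 4
seq_length = ((9 + 3) // 4) * (512 // 8) * (512 // 8) // 4
mu = calculate_flux_mu(seq_length, exp_max=7.0)
sigmas = get_flux_sigmas_from_mu(25, mu)   # 26 values running from 1.0 down to 0.0
print(round(mu, 3), float(sigmas[0]), float(sigmas[-1]))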
diffusers_helper/thread_utils.py ADDED
@@ -0,0 +1,76 @@
1
+ import time
2
+
3
+ from threading import Thread, Lock
4
+
5
+
6
+ class Listener:
7
+ task_queue = []
8
+ lock = Lock()
9
+ thread = None
10
+
11
+ @classmethod
12
+ def _process_tasks(cls):
13
+ while True:
14
+ task = None
15
+ with cls.lock:
16
+ if cls.task_queue:
17
+ task = cls.task_queue.pop(0)
18
+
19
+ if task is None:
20
+ time.sleep(0.001)
21
+ continue
22
+
23
+ func, args, kwargs = task
24
+ try:
25
+ func(*args, **kwargs)
26
+ except Exception as e:
27
+ print(f"Error in listener thread: {e}")
28
+
29
+ @classmethod
30
+ def add_task(cls, func, *args, **kwargs):
31
+ with cls.lock:
32
+ cls.task_queue.append((func, args, kwargs))
33
+
34
+ if cls.thread is None:
35
+ cls.thread = Thread(target=cls._process_tasks, daemon=True)
36
+ cls.thread.start()
37
+
38
+
39
+ def async_run(func, *args, **kwargs):
40
+ Listener.add_task(func, *args, **kwargs)
41
+
42
+
43
+ class FIFOQueue:
44
+ def __init__(self):
45
+ self.queue = []
46
+ self.lock = Lock()
47
+
48
+ def push(self, item):
49
+ with self.lock:
50
+ self.queue.append(item)
51
+
52
+ def pop(self):
53
+ with self.lock:
54
+ if self.queue:
55
+ return self.queue.pop(0)
56
+ return None
57
+
58
+ def top(self):
59
+ with self.lock:
60
+ if self.queue:
61
+ return self.queue[0]
62
+ return None
63
+
64
+ def next(self):
65
+ while True:
66
+ with self.lock:
67
+ if self.queue:
68
+ return self.queue.pop(0)
69
+
70
+ time.sleep(0.001)
71
+
72
+
73
+ class AsyncStream:
74
+ def __init__(self):
75
+ self.input_queue = FIFOQueue()
76
+ self.output_queue = FIFOQueue()
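AsyncStream simply pairs two polling FIFOQueues so the UI side and the generation worker can exchange messages without blocking each other, and async_run hands the callable to the single daemon Listener thread, which executes queued tasks one at a time. A minimal usage sketch; the ('progress', ...) / ('end', None) message shapes are illustrative, not the app's actual protocol:

from diffusers_helper.thread_utils import AsyncStream, async_run

stream = AsyncStream()

def worker(stream):
    for step in range(3):
        if stream.input_queue.top() == 'end':   # cooperative cancellation
            break
        stream.output_queue.push(('progress', step))
    stream.output_queue.push(('end', None))

async_run(worker, stream)

while True:
    flag, data = stream.output_queue.next()      # blocks until a message arrives
    if flag == 'end':
        break
    print('progress:', data)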
diffusers_helper/utils.py ADDED
@@ -0,0 +1,613 @@
1
+ import os
2
+ import cv2
3
+ import json
4
+ import random
5
+ import glob
6
+ import torch
7
+ import einops
8
+ import numpy as np
9
+ import datetime
10
+ import torchvision
11
+
12
+ import safetensors.torch as sf
13
+ from PIL import Image
14
+
15
+
16
+ def min_resize(x, m):
17
+ if x.shape[0] < x.shape[1]:
18
+ s0 = m
19
+ s1 = int(float(m) / float(x.shape[0]) * float(x.shape[1]))
20
+ else:
21
+ s0 = int(float(m) / float(x.shape[1]) * float(x.shape[0]))
22
+ s1 = m
23
+ new_max = max(s1, s0)
24
+ raw_max = max(x.shape[0], x.shape[1])
25
+ if new_max < raw_max:
26
+ interpolation = cv2.INTER_AREA
27
+ else:
28
+ interpolation = cv2.INTER_LANCZOS4
29
+ y = cv2.resize(x, (s1, s0), interpolation=interpolation)
30
+ return y
31
+
32
+
33
+ def d_resize(x, y):
34
+ H, W, C = y.shape
35
+ new_min = min(H, W)
36
+ raw_min = min(x.shape[0], x.shape[1])
37
+ if new_min < raw_min:
38
+ interpolation = cv2.INTER_AREA
39
+ else:
40
+ interpolation = cv2.INTER_LANCZOS4
41
+ y = cv2.resize(x, (W, H), interpolation=interpolation)
42
+ return y
43
+
44
+
45
+ def resize_and_center_crop(image, target_width, target_height):
46
+ if target_height == image.shape[0] and target_width == image.shape[1]:
47
+ return image
48
+
49
+ pil_image = Image.fromarray(image)
50
+ original_width, original_height = pil_image.size
51
+ scale_factor = max(target_width / original_width, target_height / original_height)
52
+ resized_width = int(round(original_width * scale_factor))
53
+ resized_height = int(round(original_height * scale_factor))
54
+ resized_image = pil_image.resize((resized_width, resized_height), Image.LANCZOS)
55
+ left = (resized_width - target_width) / 2
56
+ top = (resized_height - target_height) / 2
57
+ right = (resized_width + target_width) / 2
58
+ bottom = (resized_height + target_height) / 2
59
+ cropped_image = resized_image.crop((left, top, right, bottom))
60
+ return np.array(cropped_image)
61
+
62
+
63
+ def resize_and_center_crop_pytorch(image, target_width, target_height):
64
+ B, C, H, W = image.shape
65
+
66
+ if H == target_height and W == target_width:
67
+ return image
68
+
69
+ scale_factor = max(target_width / W, target_height / H)
70
+ resized_width = int(round(W * scale_factor))
71
+ resized_height = int(round(H * scale_factor))
72
+
73
+ resized = torch.nn.functional.interpolate(image, size=(resized_height, resized_width), mode='bilinear', align_corners=False)
74
+
75
+ top = (resized_height - target_height) // 2
76
+ left = (resized_width - target_width) // 2
77
+ cropped = resized[:, :, top:top + target_height, left:left + target_width]
78
+
79
+ return cropped
80
+
81
+
82
+ def resize_without_crop(image, target_width, target_height):
83
+ if target_height == image.shape[0] and target_width == image.shape[1]:
84
+ return image
85
+
86
+ pil_image = Image.fromarray(image)
87
+ resized_image = pil_image.resize((target_width, target_height), Image.LANCZOS)
88
+ return np.array(resized_image)
89
+
90
+
91
+ def just_crop(image, w, h):
92
+ if h == image.shape[0] and w == image.shape[1]:
93
+ return image
94
+
95
+ original_height, original_width = image.shape[:2]
96
+ k = min(original_height / h, original_width / w)
97
+ new_width = int(round(w * k))
98
+ new_height = int(round(h * k))
99
+ x_start = (original_width - new_width) // 2
100
+ y_start = (original_height - new_height) // 2
101
+ cropped_image = image[y_start:y_start + new_height, x_start:x_start + new_width]
102
+ return cropped_image
103
+
104
+
105
+ def write_to_json(data, file_path):
106
+ temp_file_path = file_path + ".tmp"
107
+ with open(temp_file_path, 'wt', encoding='utf-8') as temp_file:
108
+ json.dump(data, temp_file, indent=4)
109
+ os.replace(temp_file_path, file_path)
110
+ return
111
+
112
+
113
+ def read_from_json(file_path):
114
+ with open(file_path, 'rt', encoding='utf-8') as file:
115
+ data = json.load(file)
116
+ return data
117
+
118
+
119
+ def get_active_parameters(m):
120
+ return {k: v for k, v in m.named_parameters() if v.requires_grad}
121
+
122
+
123
+ def cast_training_params(m, dtype=torch.float32):
124
+ result = {}
125
+ for n, param in m.named_parameters():
126
+ if param.requires_grad:
127
+ param.data = param.to(dtype)
128
+ result[n] = param
129
+ return result
130
+
131
+
132
+ def separate_lora_AB(parameters, B_patterns=None):
133
+ parameters_normal = {}
134
+ parameters_B = {}
135
+
136
+ if B_patterns is None:
137
+ B_patterns = ['.lora_B.', '__zero__']
138
+
139
+ for k, v in parameters.items():
140
+ if any(B_pattern in k for B_pattern in B_patterns):
141
+ parameters_B[k] = v
142
+ else:
143
+ parameters_normal[k] = v
144
+
145
+ return parameters_normal, parameters_B
146
+
147
+
148
+ def set_attr_recursive(obj, attr, value):
149
+ attrs = attr.split(".")
150
+ for name in attrs[:-1]:
151
+ obj = getattr(obj, name)
152
+ setattr(obj, attrs[-1], value)
153
+ return
154
+
155
+
156
+ def print_tensor_list_size(tensors):
157
+ total_size = 0
158
+ total_elements = 0
159
+
160
+ if isinstance(tensors, dict):
161
+ tensors = tensors.values()
162
+
163
+ for tensor in tensors:
164
+ total_size += tensor.nelement() * tensor.element_size()
165
+ total_elements += tensor.nelement()
166
+
167
+ total_size_MB = total_size / (1024 ** 2)
168
+ total_elements_B = total_elements / 1e9
169
+
170
+ print(f"Total number of tensors: {len(tensors)}")
171
+ print(f"Total size of tensors: {total_size_MB:.2f} MB")
172
+ print(f"Total number of parameters: {total_elements_B:.3f} billion")
173
+ return
174
+
175
+
176
+ @torch.no_grad()
177
+ def batch_mixture(a, b=None, probability_a=0.5, mask_a=None):
178
+ batch_size = a.size(0)
179
+
180
+ if b is None:
181
+ b = torch.zeros_like(a)
182
+
183
+ if mask_a is None:
184
+ mask_a = torch.rand(batch_size) < probability_a
185
+
186
+ mask_a = mask_a.to(a.device)
187
+ mask_a = mask_a.reshape((batch_size,) + (1,) * (a.dim() - 1))
188
+ result = torch.where(mask_a, a, b)
189
+ return result
190
+
191
+
192
+ @torch.no_grad()
193
+ def zero_module(module):
194
+ for p in module.parameters():
195
+ p.detach().zero_()
196
+ return module
197
+
198
+
199
+ @torch.no_grad()
200
+ def supress_lower_channels(m, k, alpha=0.01):
201
+ data = m.weight.data.clone()
202
+
203
+ assert int(data.shape[1]) >= k
204
+
205
+ data[:, :k] = data[:, :k] * alpha
206
+ m.weight.data = data.contiguous().clone()
207
+ return m
208
+
209
+
210
+ def freeze_module(m):
211
+ if not hasattr(m, '_forward_inside_frozen_module'):
212
+ m._forward_inside_frozen_module = m.forward
213
+ m.requires_grad_(False)
214
+ m.forward = torch.no_grad()(m.forward)
215
+ return m
216
+
217
+
218
+ def get_latest_safetensors(folder_path):
219
+ safetensors_files = glob.glob(os.path.join(folder_path, '*.safetensors'))
220
+
221
+ if not safetensors_files:
222
+ raise ValueError('No file to resume!')
223
+
224
+ latest_file = max(safetensors_files, key=os.path.getmtime)
225
+ latest_file = os.path.abspath(os.path.realpath(latest_file))
226
+ return latest_file
227
+
228
+
229
+ def generate_random_prompt_from_tags(tags_str, min_length=3, max_length=32):
230
+ tags = tags_str.split(', ')
231
+ tags = random.sample(tags, k=min(random.randint(min_length, max_length), len(tags)))
232
+ prompt = ', '.join(tags)
233
+ return prompt
234
+
235
+
236
+ def interpolate_numbers(a, b, n, round_to_int=False, gamma=1.0):
237
+ numbers = a + (b - a) * (np.linspace(0, 1, n) ** gamma)
238
+ if round_to_int:
239
+ numbers = np.round(numbers).astype(int)
240
+ return numbers.tolist()
241
+
242
+
243
+ def uniform_random_by_intervals(inclusive, exclusive, n, round_to_int=False):
244
+ edges = np.linspace(0, 1, n + 1)
245
+ points = np.random.uniform(edges[:-1], edges[1:])
246
+ numbers = inclusive + (exclusive - inclusive) * points
247
+ if round_to_int:
248
+ numbers = np.round(numbers).astype(int)
249
+ return numbers.tolist()
250
+
251
+
252
+ def soft_append_bcthw(history, current, overlap=0):
253
+ if overlap <= 0:
254
+ return torch.cat([history, current], dim=2)
255
+
256
+ assert history.shape[2] >= overlap, f"History length ({history.shape[2]}) must be >= overlap ({overlap})"
257
+ assert current.shape[2] >= overlap, f"Current length ({current.shape[2]}) must be >= overlap ({overlap})"
258
+
259
+ weights = torch.linspace(1, 0, overlap, dtype=history.dtype, device=history.device).view(1, 1, -1, 1, 1)
260
+ blended = weights * history[:, :, -overlap:] + (1 - weights) * current[:, :, :overlap]
261
+ output = torch.cat([history[:, :, :-overlap], blended, current[:, :, overlap:]], dim=2)
262
+
263
+ return output.to(history)
264
+
265
+
266
+ def save_bcthw_as_mp4(x, output_filename, fps=10, crf=0):
267
+ b, c, t, h, w = x.shape
268
+
269
+ per_row = b
270
+ for p in [6, 5, 4, 3, 2]:
271
+ if b % p == 0:
272
+ per_row = p
273
+ break
274
+
275
+ os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True)
276
+ x = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5
277
+ x = x.detach().cpu().to(torch.uint8)
278
+ x = einops.rearrange(x, '(m n) c t h w -> t (m h) (n w) c', n=per_row)
279
+ torchvision.io.write_video(output_filename, x, fps=fps, video_codec='libx264', options={'crf': str(int(crf))})
280
+ return x
281
+
282
+
283
+ def save_bcthw_as_png(x, output_filename):
284
+ os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True)
285
+ x = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5
286
+ x = x.detach().cpu().to(torch.uint8)
287
+ x = einops.rearrange(x, 'b c t h w -> c (b h) (t w)')
288
+ torchvision.io.write_png(x, output_filename)
289
+ return output_filename
290
+
291
+
292
+ def save_bchw_as_png(x, output_filename):
293
+ os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True)
294
+ x = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5
295
+ x = x.detach().cpu().to(torch.uint8)
296
+ x = einops.rearrange(x, 'b c h w -> c h (b w)')
297
+ torchvision.io.write_png(x, output_filename)
298
+ return output_filename
299
+
300
+
301
+ def add_tensors_with_padding(tensor1, tensor2):
302
+ if tensor1.shape == tensor2.shape:
303
+ return tensor1 + tensor2
304
+
305
+ shape1 = tensor1.shape
306
+ shape2 = tensor2.shape
307
+
308
+ new_shape = tuple(max(s1, s2) for s1, s2 in zip(shape1, shape2))
309
+
310
+ padded_tensor1 = torch.zeros(new_shape)
311
+ padded_tensor2 = torch.zeros(new_shape)
312
+
313
+ padded_tensor1[tuple(slice(0, s) for s in shape1)] = tensor1
314
+ padded_tensor2[tuple(slice(0, s) for s in shape2)] = tensor2
315
+
316
+ result = padded_tensor1 + padded_tensor2
317
+ return result
318
+
319
+
320
+ def print_free_mem():
321
+ torch.cuda.empty_cache()
322
+ free_mem, total_mem = torch.cuda.mem_get_info(0)
323
+ free_mem_mb = free_mem / (1024 ** 2)
324
+ total_mem_mb = total_mem / (1024 ** 2)
325
+ print(f"Free memory: {free_mem_mb:.2f} MB")
326
+ print(f"Total memory: {total_mem_mb:.2f} MB")
327
+ return
328
+
329
+
330
+ def print_gpu_parameters(device, state_dict, log_count=1):
331
+ summary = {"device": device, "keys_count": len(state_dict)}
332
+
333
+ logged_params = {}
334
+ for i, (key, tensor) in enumerate(state_dict.items()):
335
+ if i >= log_count:
336
+ break
337
+ logged_params[key] = tensor.flatten()[:3].tolist()
338
+
339
+ summary["params"] = logged_params
340
+
341
+ print(str(summary))
342
+ return
343
+
344
+
345
+ def visualize_txt_as_img(width, height, text, font_path='font/DejaVuSans.ttf', size=18):
346
+ from PIL import Image, ImageDraw, ImageFont
347
+
348
+ txt = Image.new("RGB", (width, height), color="white")
349
+ draw = ImageDraw.Draw(txt)
350
+ font = ImageFont.truetype(font_path, size=size)
351
+
352
+ if text == '':
353
+ return np.array(txt)
354
+
355
+ # Split text into lines that fit within the image width
356
+ lines = []
357
+ words = text.split()
358
+ current_line = words[0]
359
+
360
+ for word in words[1:]:
361
+ line_with_word = f"{current_line} {word}"
362
+ if draw.textbbox((0, 0), line_with_word, font=font)[2] <= width:
363
+ current_line = line_with_word
364
+ else:
365
+ lines.append(current_line)
366
+ current_line = word
367
+
368
+ lines.append(current_line)
369
+
370
+ # Draw the text line by line
371
+ y = 0
372
+ line_height = draw.textbbox((0, 0), "A", font=font)[3]
373
+
374
+ for line in lines:
375
+ if y + line_height > height:
376
+ break # stop drawing if the next line will be outside the image
377
+ draw.text((0, y), line, fill="black", font=font)
378
+ y += line_height
379
+
380
+ return np.array(txt)
381
+
382
+
383
+ def blue_mark(x):
384
+ x = x.copy()
385
+ c = x[:, :, 2]
386
+ b = cv2.blur(c, (9, 9))
387
+ x[:, :, 2] = ((c - b) * 16.0 + b).clip(-1, 1)
388
+ return x
389
+
390
+
391
+ def green_mark(x):
392
+ x = x.copy()
393
+ x[:, :, 2] = -1
394
+ x[:, :, 0] = -1
395
+ return x
396
+
397
+
398
+ def frame_mark(x):
399
+ x = x.copy()
400
+ x[:64] = -1
401
+ x[-64:] = -1
402
+ x[:, :8] = 1
403
+ x[:, -8:] = 1
404
+ return x
405
+
406
+
407
+ @torch.inference_mode()
408
+ def pytorch2numpy(imgs):
409
+ results = []
410
+ for x in imgs:
411
+ y = x.movedim(0, -1)
412
+ y = y * 127.5 + 127.5
413
+ y = y.detach().float().cpu().numpy().clip(0, 255).astype(np.uint8)
414
+ results.append(y)
415
+ return results
416
+
417
+
418
+ @torch.inference_mode()
419
+ def numpy2pytorch(imgs):
420
+ h = torch.from_numpy(np.stack(imgs, axis=0)).float() / 127.5 - 1.0
421
+ h = h.movedim(-1, 1)
422
+ return h
423
+
424
+
425
+ @torch.no_grad()
426
+ def duplicate_prefix_to_suffix(x, count, zero_out=False):
427
+ if zero_out:
428
+ return torch.cat([x, torch.zeros_like(x[:count])], dim=0)
429
+ else:
430
+ return torch.cat([x, x[:count]], dim=0)
431
+
432
+
433
+ def weighted_mse(a, b, weight):
434
+ return torch.mean(weight.float() * (a.float() - b.float()) ** 2)
435
+
436
+
437
+ def clamped_linear_interpolation(x, x_min, y_min, x_max, y_max, sigma=1.0):
438
+ x = (x - x_min) / (x_max - x_min)
439
+ x = max(0.0, min(x, 1.0))
440
+ x = x ** sigma
441
+ return y_min + x * (y_max - y_min)
442
+
443
+
444
+ def expand_to_dims(x, target_dims):
445
+ return x.view(*x.shape, *([1] * max(0, target_dims - x.dim())))
446
+
447
+
448
+ def repeat_to_batch_size(tensor: torch.Tensor, batch_size: int):
449
+ if tensor is None:
450
+ return None
451
+
452
+ first_dim = tensor.shape[0]
453
+
454
+ if first_dim == batch_size:
455
+ return tensor
456
+
457
+ if batch_size % first_dim != 0:
458
+ raise ValueError(f"Cannot evenly repeat first dim {first_dim} to match batch_size {batch_size}.")
459
+
460
+ repeat_times = batch_size // first_dim
461
+
462
+ return tensor.repeat(repeat_times, *[1] * (tensor.dim() - 1))
463
+
464
+
465
+ def dim5(x):
466
+ return expand_to_dims(x, 5)
467
+
468
+
469
+ def dim4(x):
470
+ return expand_to_dims(x, 4)
471
+
472
+
473
+ def dim3(x):
474
+ return expand_to_dims(x, 3)
475
+
476
+
477
+ def crop_or_pad_yield_mask(x, length):
478
+ B, F, C = x.shape
479
+ device = x.device
480
+ dtype = x.dtype
481
+
482
+ if F < length:
483
+ y = torch.zeros((B, length, C), dtype=dtype, device=device)
484
+ mask = torch.zeros((B, length), dtype=torch.bool, device=device)
485
+ y[:, :F, :] = x
486
+ mask[:, :F] = True
487
+ return y, mask
488
+
489
+ return x[:, :length, :], torch.ones((B, length), dtype=torch.bool, device=device)
490
+
491
+
492
+ def extend_dim(x, dim, minimal_length, zero_pad=False):
493
+ original_length = int(x.shape[dim])
494
+
495
+ if original_length >= minimal_length:
496
+ return x
497
+
498
+ if zero_pad:
499
+ padding_shape = list(x.shape)
500
+ padding_shape[dim] = minimal_length - original_length
501
+ padding = torch.zeros(padding_shape, dtype=x.dtype, device=x.device)
502
+ else:
503
+ idx = (slice(None),) * dim + (slice(-1, None),) + (slice(None),) * (len(x.shape) - dim - 1)
504
+ last_element = x[idx]
505
+ padding = last_element.repeat_interleave(minimal_length - original_length, dim=dim)
506
+
507
+ return torch.cat([x, padding], dim=dim)
508
+
509
+
510
+ def lazy_positional_encoding(t, repeats=None):
511
+ if not isinstance(t, list):
512
+ t = [t]
513
+
514
+ from diffusers.models.embeddings import get_timestep_embedding
515
+
516
+ te = torch.tensor(t)
517
+ te = get_timestep_embedding(timesteps=te, embedding_dim=256, flip_sin_to_cos=True, downscale_freq_shift=0.0, scale=1.0)
518
+
519
+ if repeats is None:
520
+ return te
521
+
522
+ te = te[:, None, :].expand(-1, repeats, -1)
523
+
524
+ return te
525
+
526
+
527
+ def state_dict_offset_merge(A, B, C=None):
528
+ result = {}
529
+ keys = A.keys()
530
+
531
+ for key in keys:
532
+ A_value = A[key]
533
+ B_value = B[key].to(A_value)
534
+
535
+ if C is None:
536
+ result[key] = A_value + B_value
537
+ else:
538
+ C_value = C[key].to(A_value)
539
+ result[key] = A_value + B_value - C_value
540
+
541
+ return result
542
+
543
+
544
+ def state_dict_weighted_merge(state_dicts, weights):
545
+ if len(state_dicts) != len(weights):
546
+ raise ValueError("Number of state dictionaries must match number of weights")
547
+
548
+ if not state_dicts:
549
+ return {}
550
+
551
+ total_weight = sum(weights)
552
+
553
+ if total_weight == 0:
554
+ raise ValueError("Sum of weights cannot be zero")
555
+
556
+ normalized_weights = [w / total_weight for w in weights]
557
+
558
+ keys = state_dicts[0].keys()
559
+ result = {}
560
+
561
+ for key in keys:
562
+ result[key] = state_dicts[0][key] * normalized_weights[0]
563
+
564
+ for i in range(1, len(state_dicts)):
565
+ state_dict_value = state_dicts[i][key].to(result[key])
566
+ result[key] += state_dict_value * normalized_weights[i]
567
+
568
+ return result
569
+
570
+
571
+ def group_files_by_folder(all_files):
572
+ grouped_files = {}
573
+
574
+ for file in all_files:
575
+ folder_name = os.path.basename(os.path.dirname(file))
576
+ if folder_name not in grouped_files:
577
+ grouped_files[folder_name] = []
578
+ grouped_files[folder_name].append(file)
579
+
580
+ list_of_lists = list(grouped_files.values())
581
+ return list_of_lists
582
+
583
+
584
+ def generate_timestamp():
585
+ now = datetime.datetime.now()
586
+ timestamp = now.strftime('%y%m%d_%H%M%S')
587
+ milliseconds = f"{int(now.microsecond / 1000):03d}"
588
+ random_number = random.randint(0, 9999)
589
+ return f"{timestamp}_{milliseconds}_{random_number}"
590
+
591
+
592
+ def write_PIL_image_with_png_info(image, metadata, path):
593
+ from PIL.PngImagePlugin import PngInfo
594
+
595
+ png_info = PngInfo()
596
+ for key, value in metadata.items():
597
+ png_info.add_text(key, value)
598
+
599
+ image.save(path, "PNG", pnginfo=png_info)
600
+ return image
601
+
602
+
603
+ def torch_safe_save(content, path):
604
+ torch.save(content, path + '_tmp')
605
+ os.replace(path + '_tmp', path)
606
+ return path
607
+
608
+
609
+ def move_optimizer_to_device(optimizer, device):
610
+ for state in optimizer.state.values():
611
+ for k, v in state.items():
612
+ if isinstance(v, torch.Tensor):
613
+ state[k] = v.to(device)
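Among the helpers above, soft_append_bcthw is the one that stitches successively generated chunks: the last overlap frames of the history are linearly cross-faded into the first overlap frames of the new chunk, so the joined clip has len(history) + len(current) - overlap frames. A small shape check on dummy B, C, T, H, W tensors:

import torch
from diffusers_helper.utils import soft_append_bcthw

history = torch.zeros(1, 3, 16, 64, 64)
current = torch.ones(1, 3, 16, 64, 64)
merged = soft_append_bcthw(history, current, overlap=8)
print(merged.shape)              # torch.Size([1, 3, 24, 64, 64])
print(merged[0, 0, 12, 0, 0])    # a frame inside the cross-fade, strictly between 0 and 1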
docker-compose.yml ADDED
@@ -0,0 +1,25 @@
1
+ services:
2
+ studio:
3
+ build:
4
+ # modify this if you are building the image locally and need a different CUDA version
5
+ args:
6
+ - CUDA_VERSION=12.4.1
7
+ # modify the tag here if you need a different CUDA version or branch
8
+ image: colinurbs/fp-studio:cuda12.4-latest-develop
9
+ restart: unless-stopped
10
+ ports:
11
+ - "7860:7860"
12
+ volumes:
13
+ - "./loras:/app/loras"
14
+ - "./outputs:/app/outputs"
15
+ - "./.framepack:/app/.framepack"
16
+ - "./modules/toolbox/model_esrgan:/app/modules/toolbox/model_esrgan"
17
+ - "./modules/toolbox/model_rife:/app/modules/toolbox/model_rife"
18
+ - "$HOME/.cache/huggingface:/app/hf_download"
19
+ deploy:
20
+ resources:
21
+ reservations:
22
+ devices:
23
+ - driver: nvidia
24
+ count: 1
25
+ capabilities: [gpu]
install.bat ADDED
@@ -0,0 +1,208 @@
1
+ @echo off
2
+ echo FramePack-Studio Setup Script
3
+ setlocal enabledelayedexpansion
4
+
5
+ REM Check if Python is installed (basic check)
6
+ where python >nul 2>&1
7
+ if %errorlevel% neq 0 (
8
+ echo Error: Python is not installed or not in your PATH. Please install Python and try again.
9
+ goto end
10
+ )
11
+
12
+ if exist "%cd%/venv" (
13
+ echo Virtual Environment already exists.
14
+ set /p choice= "Do you want to reinstall packages?[Y/N]: "
15
+
16
+ if "!choice!" == "y" (goto checkgpu)
17
+ if "!choice!"=="Y" (goto checkgpu)
18
+
19
+ goto end
20
+ )
21
+
22
+ REM Check the python version
23
+ echo Python versions 3.10-3.12 have been confirmed to work. Other versions are currently not supported. You currently have:
24
+ python -V
25
+ set choice=
26
+ set /p choice= "Do you want to continue?[Y/N]: "
27
+
28
+
29
+ if "!choice!" == "y" (goto makevenv)
30
+ if "!choice!"=="Y" (goto makevenv)
31
+
32
+ goto end
33
+
34
+ :makevenv
35
+ REM This creates a virtual environment in the folder
36
+ echo Creating a Virtual Environment...
37
+ python -m venv venv
38
+ echo Upgrading pip in Virtual Environment to lower chance of error...
39
+ "%cd%/venv/Scripts/python.exe" -m pip install --upgrade pip
40
+
41
+ :checkgpu
42
+ REM ask Windows for GPU
43
+ where nvidia-smi >nul 2>&1
44
+ if %errorlevel% neq 0 (
45
+ echo Error: No NVIDIA GPU was detected, or the drivers are installed incorrectly. Please confirm your drivers are installed.
46
+ goto end
47
+ )
48
+
49
+ echo Checking your GPU...
50
+
51
+ for /F "tokens=* skip=1" %%n in ('nvidia-smi --query-gpu=name --format=csv') do set GPU_NAME=%%n && goto gpuchecked
52
+
53
+ :gpuchecked
54
+ echo Detected %GPU_NAME%
55
+ set "GPU_SERIES=%GPU_NAME:*RTX =%"
56
+ set "GPU_SERIES=%GPU_SERIES:~0,2%00"
57
+
58
+ REM This gets the shortened Python version for later use. e.g. 3.10.13 becomes 310.
59
+ for /f "delims=" %%A in ('python -V') do set "pyv=%%A"
60
+ for /f "tokens=2 delims= " %%A in ("%pyv%") do (
61
+ set pyv=%%A
62
+ )
63
+ set pyv=%pyv:.=%
64
+ set pyv=%pyv:~0,3%
65
+
66
+ echo Installing torch...
67
+
68
+ if !GPU_SERIES! geq 5000 (
69
+ goto torch270
70
+ ) else (
71
+ goto torch260
72
+ )
73
+
74
+ REM RTX 5000 Series
75
+ :torch270
76
+ "%cd%/venv/Scripts/pip.exe" install torch==2.7.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128 --force-reinstall
77
+ REM Check if pip installation was successful
78
+ if %errorlevel% neq 0 (
79
+ echo Warning: Failed to install dependencies. You may need to install them manually.
80
+ goto end
81
+ )
82
+
83
+ REM Ask if user wants Sage Attention
84
+ set choice=
85
+ echo Do you want to install any of the following? They speed up generation.
86
+ echo 1) Sage Attention
87
+ echo 2) Flash Attention
88
+ echo 3) BOTH!
89
+ echo 4) No
90
+ set /p choice= "Input Selection: "
91
+
92
+ set both="N"
93
+
94
+ if "!choice!" == "1" (goto triton270)
95
+ if "!choice!"== "2" (goto flash270)
96
+ if "!choice!"== "3" (set both="Y"
97
+ goto triton270
98
+ )
99
+
100
+ goto requirements
101
+
102
+ :triton270
103
+ REM Sage Attention and Triton for Torch 2.7.0
104
+ "%cd%/venv/Scripts/pip.exe" install "triton-windows<3.4" --force-reinstall
105
+ "%cd%/venv/Scripts/pip.exe" install "https://github.com/woct0rdho/SageAttention/releases/download/v2.1.1-windows/sageattention-2.1.1+cu128torch2.7.0-cp%pyv%-cp%pyv%-win_amd64.whl" --force-reinstall
106
+ echo Finishing up installing triton-windows. This requires extraction of libraries into Python Folder...
107
+
108
+ REM Check for python version and download the triton-windows required libs accordingly
109
+ if %pyv% == 310 (
110
+ powershell -Command "(New-Object System.Net.WebClient).DownloadFile('https://github.com/woct0rdho/triton-windows/releases/download/v3.0.0-windows.post1/python_3.10.11_include_libs.zip', 'triton-lib.zip')"
111
+ )
112
+
113
+ if %pyv% == 311 (
114
+ powershell -Command "(New-Object System.Net.WebClient).DownloadFile('https://github.com/woct0rdho/triton-windows/releases/download/v3.0.0-windows.post1/python_3.11.9_include_libs.zip', 'triton-lib.zip')"
115
+ )
116
+
117
+ if %pyv% == 312 (
118
+ powershell -Command "(New-Object System.Net.WebClient).DownloadFile('https://github.com/woct0rdho/triton-windows/releases/download/v3.0.0-windows.post1/python_3.12.7_include_libs.zip', 'triton-lib.zip')"
119
+ )
120
+
121
+ REM Extract the zip into the Python Folder and Delete zip
122
+ powershell Expand-Archive -Path '%cd%\triton-lib.zip' -DestinationPath '%cd%\venv\Scripts\' -force
123
+ del triton-lib.zip
124
+ if %both% == "Y" (goto flash270)
125
+
126
+ goto requirements
127
+
128
+ :flash270
129
+ REM Install flash-attn.
130
+ "%cd%/venv/Scripts/pip.exe" install "https://huggingface.co/lldacing/flash-attention-windows-wheel/resolve/main/flash_attn-2.7.4.post1%%2Bcu128torch2.7.0cxx11abiFALSE-cp%pyv%-cp%pyv%-win_amd64.whl?download=true"
131
+ goto requirements
132
+
133
+
134
+ REM RTX 4000 Series and below
135
+ :torch260
136
+ "%cd%/venv/Scripts/pip.exe" install torch==2.6.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126 --force-reinstall
137
+ REM Check if pip installation was successful
138
+ if %errorlevel% neq 0 (
139
+ echo Warning: Failed to install dependencies. You may need to install them manually.
140
+ goto end
141
+ )
142
+
143
+ REM Ask if user wants Sage Attention
144
+ set choice=
145
+ echo Do you want to install any of the following? They speed up generation.
146
+ echo 1) Sage Attention
147
+ echo 2) Flash Attention
148
+ echo 3) BOTH!
149
+ echo 4) No
150
+ set /p choice= "Input Selection: "
151
+
152
+ set both="N"
153
+
154
+ if "!choice!" == "1" (goto triton260)
155
+ if "!choice!"== "2" (goto flash260)
156
+ if "!choice!"== "3" (set both="Y"
157
+ goto triton260)
158
+
159
+ goto requirements
160
+
161
+ :triton260
162
+ REM Sage Attention and Triton for Torch 2.6.0
163
+ "%cd%/venv/Scripts/pip.exe" install "triton-windows<3.3.0" --force-reinstall
164
+ "%cd%/venv/Scripts/pip.exe" install https://github.com/woct0rdho/SageAttention/releases/download/v2.1.1-windows/sageattention-2.1.1+cu126torch2.6.0-cp%pyv%-cp%pyv%-win_amd64.whl --force-reinstall
165
+
166
+ echo Finishing up installing triton-windows. This requires extraction of libraries into Python Folder...
167
+
168
+ REM Check for python version and download the triton-windows required libs accordingly
169
+ if %pyv% == 310 (
170
+ powershell -Command "(New-Object System.Net.WebClient).DownloadFile('https://github.com/woct0rdho/triton-windows/releases/download/v3.0.0-windows.post1/python_3.10.11_include_libs.zip', 'triton-lib.zip')"
171
+ )
172
+
173
+ if %pyv% == 311 (
174
+ powershell -Command "(New-Object System.Net.WebClient).DownloadFile('https://github.com/woct0rdho/triton-windows/releases/download/v3.0.0-windows.post1/python_3.11.9_include_libs.zip', 'triton-lib.zip')"
175
+ )
176
+
177
+ if %pyv% == 312 (
178
+ powershell -Command "(New-Object System.Net.WebClient).DownloadFile('https://github.com/woct0rdho/triton-windows/releases/download/v3.0.0-windows.post1/python_3.12.7_include_libs.zip', 'triton-lib.zip')"
179
+ )
180
+
181
+ REM Extract the zip into the Python Folder and Delete zip
182
+ powershell Expand-Archive -Path '%cd%\triton-lib.zip' -DestinationPath '%cd%\venv\Scripts\' -force
183
+ del triton-lib.zip
184
+
185
+ if %both% == "Y" (goto flash260)
186
+
187
+ goto requirements
188
+
189
+ :flash260
190
+ REM Install flash-attn.
191
+ "%cd%/venv/Scripts/pip.exe" install "https://huggingface.co/lldacing/flash-attention-windows-wheel/resolve/main/flash_attn-2.7.4%%2Bcu126torch2.6.0cxx11abiFALSE-cp%pyv%-cp%pyv%-win_amd64.whl?download=true"
192
+
193
+ :requirements
194
+ echo Installing remaining required packages through pip...
195
+ REM This assumes there's a requirements.txt file in the root
196
+ "%cd%/venv/Scripts/pip.exe" install -r requirements.txt
197
+
198
+ REM Check if pip installation was successful
199
+ if %errorlevel% neq 0 (
200
+ echo Warning: Failed to install dependencies. You may need to install them manually.
201
+ goto end
202
+ )
203
+
204
+ echo Setup complete.
205
+
206
+ :end
207
+ echo Exiting setup script.
208
+ pause
modules/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ # modules/__init__.py
2
+
3
+ # Workaround for the single lora bug. Must not be an empty string.
4
+ DUMMY_LORA_NAME = " "
modules/generators/__init__.py ADDED
@@ -0,0 +1,32 @@
1
+ from .original_generator import OriginalModelGenerator
2
+ from .f1_generator import F1ModelGenerator
3
+ from .video_generator import VideoModelGenerator
4
+ from .video_f1_generator import VideoF1ModelGenerator
5
+ from .original_with_endframe_generator import OriginalWithEndframeModelGenerator
6
+
7
+ def create_model_generator(model_type, **kwargs):
8
+ """
9
+ Create a model generator based on the model type.
10
+
11
+ Args:
12
+ model_type: The type of model to create ("Original", "Original with Endframe", "F1", "Video", or "Video F1")
13
+ **kwargs: Additional arguments to pass to the model generator constructor
14
+
15
+ Returns:
16
+ A model generator instance
17
+
18
+ Raises:
19
+ ValueError: If the model type is not supported
20
+ """
21
+ if model_type == "Original":
22
+ return OriginalModelGenerator(**kwargs)
23
+ elif model_type == "Original with Endframe":
24
+ return OriginalWithEndframeModelGenerator(**kwargs)
25
+ elif model_type == "F1":
26
+ return F1ModelGenerator(**kwargs)
27
+ elif model_type == "Video":
28
+ return VideoModelGenerator(**kwargs)
29
+ elif model_type == "Video F1":
30
+ return VideoF1ModelGenerator(**kwargs)
31
+ else:
32
+ raise ValueError(f"Unsupported model type: {model_type}")
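create_model_generator just dispatches on the model_type string and forwards every other keyword argument to the chosen subclass. A hedged usage sketch, assuming the shared components (text_encoder, vae, and so on) were already loaded elsewhere, as the worker pipeline does, and using the keyword names from BaseModelGenerator.__init__:

from modules.generators import create_model_generator

generator = create_model_generator(
    "F1",                          # or "Original", "Original with Endframe", "Video", "Video F1"
    text_encoder=text_encoder,     # assumed already loaded elsewhere
    text_encoder_2=text_encoder_2,
    tokenizer=tokenizer,
    tokenizer_2=tokenizer_2,
    vae=vae,
    image_encoder=image_encoder,
    feature_extractor=feature_extractor,
    high_vram=False,
)
generator.load_model()             # subclass loads its matching transformer
print(generator.get_model_name())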
modules/generators/base_generator.py ADDED
@@ -0,0 +1,281 @@
1
+ import torch
2
+ import os # required for os.path
3
+ from abc import ABC, abstractmethod
4
+ from diffusers_helper import lora_utils
5
+ from typing import List, Optional
6
+ from pathlib import Path
7
+
8
+ class BaseModelGenerator(ABC):
9
+ """
10
+ Base class for model generators.
11
+ This defines the common interface that all model generators must implement.
12
+ """
13
+
14
+ def __init__(self,
15
+ text_encoder,
16
+ text_encoder_2,
17
+ tokenizer,
18
+ tokenizer_2,
19
+ vae,
20
+ image_encoder,
21
+ feature_extractor,
22
+ high_vram=False,
23
+ prompt_embedding_cache=None,
24
+ settings=None,
25
+ offline=False): # NEW: offline flag
26
+ """
27
+ Initialize the base model generator.
28
+
29
+ Args:
30
+ text_encoder: The text encoder model
31
+ text_encoder_2: The second text encoder model
32
+ tokenizer: The tokenizer for the first text encoder
33
+ tokenizer_2: The tokenizer for the second text encoder
34
+ vae: The VAE model
35
+ image_encoder: The image encoder model
36
+ feature_extractor: The feature extractor
37
+ high_vram: Whether high VRAM mode is enabled
38
+ prompt_embedding_cache: Cache for prompt embeddings
39
+ settings: Application settings
40
+ offline: Whether to run in offline mode for model loading
41
+ """
42
+ self.text_encoder = text_encoder
43
+ self.text_encoder_2 = text_encoder_2
44
+ self.tokenizer = tokenizer
45
+ self.tokenizer_2 = tokenizer_2
46
+ self.vae = vae
47
+ self.image_encoder = image_encoder
48
+ self.feature_extractor = feature_extractor
49
+ self.high_vram = high_vram
50
+ self.prompt_embedding_cache = prompt_embedding_cache or {}
51
+ self.settings = settings
52
+ self.offline = offline
53
+ self.transformer = None
54
+ self.gpu = torch.device("cuda" if torch.cuda.is_available() else "cpu")
55
+ self.cpu = torch.device("cpu")
56
+
57
+
58
+ @abstractmethod
59
+ def load_model(self):
60
+ """
61
+ Load the transformer model.
62
+ This method should be implemented by each specific model generator.
63
+ """
64
+ pass
65
+
66
+ @abstractmethod
67
+ def get_model_name(self):
68
+ """
69
+ Get the name of the model.
70
+ This method should be implemented by each specific model generator.
71
+ """
72
+ pass
73
+
74
+ @staticmethod
75
+ def _get_snapshot_hash_from_refs(model_repo_id_for_cache: str) -> str | None:
76
+ """
77
+ Reads the commit hash from the refs/main file for a given model in the HF cache.
78
+ Args:
79
+ model_repo_id_for_cache (str): The model ID formatted for cache directory names
80
+ (e.g., "models--lllyasviel--FramePackI2V_HY").
81
+ Returns:
82
+ str: The commit hash if found, otherwise None.
83
+ """
84
+ hf_home_dir = os.environ.get('HF_HOME')
85
+ if not hf_home_dir:
86
+ print("Warning: HF_HOME environment variable not set. Cannot determine snapshot hash.")
87
+ return None
88
+
89
+ refs_main_path = os.path.join(hf_home_dir, 'hub', model_repo_id_for_cache, 'refs', 'main')
90
+ if os.path.exists(refs_main_path):
91
+ try:
92
+ with open(refs_main_path, 'r') as f:
93
+ print(f"Offline mode: Reading snapshot hash from: {refs_main_path}")
94
+ return f.read().strip()
95
+ except Exception as e:
96
+ print(f"Warning: Could not read snapshot hash from {refs_main_path}: {e}")
97
+ return None
98
+ else:
99
+ print(f"Warning: refs/main file not found at {refs_main_path}. Cannot determine snapshot hash.")
100
+ return None
101
+
102
+ def _get_offline_load_path(self) -> str:
103
+ """
104
+ Returns the local snapshot path for offline loading if available.
105
+ Falls back to the default self.model_path if local snapshot can't be found.
106
+ Relies on self.model_repo_id_for_cache and self.model_path being set by subclasses.
107
+ """
108
+ # Ensure necessary attributes are set by the subclass
109
+ if not hasattr(self, 'model_repo_id_for_cache') or not self.model_repo_id_for_cache:
110
+ print(f"Warning: model_repo_id_for_cache not set in {self.__class__.__name__}. Cannot determine offline path.")
111
+ # Fallback to model_path if it exists, otherwise None
112
+ return getattr(self, 'model_path', None)
113
+
114
+ if not hasattr(self, 'model_path') or not self.model_path:
115
+ print(f"Warning: model_path not set in {self.__class__.__name__}. Cannot determine fallback for offline path.")
116
+ return None
117
+
118
+ snapshot_hash = self._get_snapshot_hash_from_refs(self.model_repo_id_for_cache)
119
+ hf_home = os.environ.get('HF_HOME')
120
+
121
+ if snapshot_hash and hf_home:
122
+ specific_snapshot_path = os.path.join(
123
+ hf_home, 'hub', self.model_repo_id_for_cache, 'snapshots', snapshot_hash
124
+ )
125
+ if os.path.isdir(specific_snapshot_path):
126
+ return specific_snapshot_path
127
+
128
+ # If snapshot logic fails or path is not a dir, fallback to the default model path
129
+ return self.model_path
130
+
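For a concrete picture of what the two helpers above walk through: refs/main stores the commit hash of the downloaded revision and the weights live under snapshots/<hash>. A sketch of the standard Hugging Face cache layout; the hash is a made-up placeholder and the repo id is the example from the docstring:

# $HF_HOME/hub/models--lllyasviel--FramePackI2V_HY/
#     refs/main                  -> text file containing e.g. "0123abcd..."
#     snapshots/0123abcd.../     -> the config and safetensors that actually get loaded
import os
repo_cache_name = 'models--lllyasviel--FramePackI2V_HY'
base = os.path.join(os.environ['HF_HOME'], 'hub', repo_cache_name)
commit = open(os.path.join(base, 'refs', 'main'), 'r').read().strip()
snapshot_dir = os.path.join(base, 'snapshots', commit)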
131
+ def unload_loras(self):
132
+ """
133
+ Unload all LoRAs from the transformer model.
134
+ """
135
+ if self.transformer is not None:
136
+ print(f"Unloading all LoRAs from {self.get_model_name()} model")
137
+ self.transformer = lora_utils.unload_all_loras(self.transformer)
138
+ self.verify_lora_state("After unloading LoRAs")
139
+ import gc
140
+ gc.collect()
141
+ if torch.cuda.is_available():
142
+ torch.cuda.empty_cache()
143
+
144
+ def verify_lora_state(self, label=""):
145
+ """
146
+ Debug function to verify the state of LoRAs in the transformer model.
147
+ """
148
+ if self.transformer is None:
149
+ print(f"[{label}] Transformer is None, cannot verify LoRA state")
150
+ return
151
+
152
+ has_loras = False
153
+ if hasattr(self.transformer, 'peft_config'):
154
+ adapter_names = list(self.transformer.peft_config.keys()) if self.transformer.peft_config else []
155
+ if adapter_names:
156
+ has_loras = True
157
+ print(f"[{label}] Transformer has LoRAs: {', '.join(adapter_names)}")
158
+ else:
159
+ print(f"[{label}] Transformer has no LoRAs in peft_config")
160
+ else:
161
+ print(f"[{label}] Transformer has no peft_config attribute")
162
+
163
+ # Check for any LoRA modules
164
+ for name, module in self.transformer.named_modules():
165
+ if hasattr(module, 'lora_A') and module.lora_A:
166
+ has_loras = True
167
+ # print(f"[{label}] Found lora_A in module {name}")
168
+ if hasattr(module, 'lora_B') and module.lora_B:
169
+ has_loras = True
170
+ # print(f"[{label}] Found lora_B in module {name}")
171
+
172
+ if not has_loras:
173
+ print(f"[{label}] No LoRA components found in transformer")
174
+
175
+ def move_lora_adapters_to_device(self, target_device):
176
+ """
177
+ Move all LoRA adapters in the transformer model to the specified device.
178
+ This handles the PEFT implementation of LoRA.
179
+ """
180
+ if self.transformer is None:
181
+ return
182
+
183
+ print(f"Moving all LoRA adapters to {target_device}")
184
+
185
+ # First, find all modules with LoRA adapters
186
+ lora_modules = []
187
+ for name, module in self.transformer.named_modules():
188
+ if hasattr(module, 'active_adapter') and hasattr(module, 'lora_A') and hasattr(module, 'lora_B'):
189
+ lora_modules.append((name, module))
190
+
191
+ # Now move all LoRA components to the target device
192
+ for name, module in lora_modules:
193
+ # Get the active adapter name
194
+ active_adapter = module.active_adapter
195
+
196
+ # Move the LoRA layers to the target device
197
+ if active_adapter is not None:
198
+ if isinstance(module.lora_A, torch.nn.ModuleDict):
199
+ # Handle ModuleDict case (PEFT implementation)
200
+ for adapter_name in list(module.lora_A.keys()):
201
+ # Move lora_A
202
+ if adapter_name in module.lora_A:
203
+ module.lora_A[adapter_name] = module.lora_A[adapter_name].to(target_device)
204
+
205
+ # Move lora_B
206
+ if adapter_name in module.lora_B:
207
+ module.lora_B[adapter_name] = module.lora_B[adapter_name].to(target_device)
208
+
209
+ # Move scaling
210
+ if hasattr(module, 'scaling') and isinstance(module.scaling, dict) and adapter_name in module.scaling:
211
+ if isinstance(module.scaling[adapter_name], torch.Tensor):
212
+ module.scaling[adapter_name] = module.scaling[adapter_name].to(target_device)
213
+ else:
214
+ # Handle direct attribute case
215
+ if hasattr(module, 'lora_A') and module.lora_A is not None:
216
+ module.lora_A = module.lora_A.to(target_device)
217
+ if hasattr(module, 'lora_B') and module.lora_B is not None:
218
+ module.lora_B = module.lora_B.to(target_device)
219
+ if hasattr(module, 'scaling') and module.scaling is not None:
220
+ if isinstance(module.scaling, torch.Tensor):
221
+ module.scaling = module.scaling.to(target_device)
222
+
223
+ print(f"Moved all LoRA adapters to {target_device}")
224
+
225
+ def load_loras(self, selected_loras: List[str], lora_folder: str, lora_loaded_names: List[str], lora_values: Optional[List[float]] = None):
226
+ """
227
+ Load LoRAs into the transformer model and apply their weights.
228
+
229
+ Args:
230
+ selected_loras: List of LoRA base names to load (e.g., ["lora_A", "lora_B"]).
231
+ lora_folder: Path to the folder containing the LoRA files.
232
+ lora_loaded_names: The master list of ALL available LoRA names, used for correct weight indexing.
233
+ lora_values: A list of strength values corresponding to lora_loaded_names.
234
+ """
235
+ self.unload_loras()
236
+
237
+ if not selected_loras:
238
+ print("No LoRAs selected, skipping loading.")
239
+ return
240
+
241
+ lora_dir = Path(lora_folder)
242
+
243
+ adapter_names = []
244
+ strengths = []
245
+
246
+ for idx, lora_base_name in enumerate(selected_loras):
247
+ lora_file = None
248
+ for ext in (".safetensors", ".pt"):
249
+ candidate_path_relative = f"{lora_base_name}{ext}"
250
+ candidate_path_full = lora_dir / candidate_path_relative
251
+ if candidate_path_full.is_file():
252
+ lora_file = candidate_path_relative
253
+ break
254
+
255
+ if not lora_file:
256
+ print(f"Warning: LoRA file for base name '{lora_base_name}' not found; skipping.")
257
+ continue
258
+
259
+ print(f"Loading LoRA from '{lora_file}'...")
260
+
261
+ self.transformer, adapter_name = lora_utils.load_lora(self.transformer, lora_dir, lora_file)
262
+ adapter_names.append(adapter_name)
263
+
264
+ weight = 1.0
265
+ if lora_values:
266
+ try:
267
+ master_list_idx = lora_loaded_names.index(lora_base_name)
268
+ if master_list_idx < len(lora_values):
269
+ weight = float(lora_values[master_list_idx])
270
+ else:
271
+ print(f"Warning: Index mismatch for '{lora_base_name}'. Defaulting to 1.0.")
272
+ except ValueError:
273
+ print(f"Warning: LoRA '{lora_base_name}' not found in master list. Defaulting to 1.0.")
274
+
275
+ strengths.append(weight)
276
+
277
+ if adapter_names:
278
+ print(f"Activating adapters: {adapter_names} with strengths: {strengths}")
279
+ lora_utils.set_adapters(self.transformer, adapter_names, strengths)
280
+
281
+ self.verify_lora_state("After completing load_loras")
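
Note: the offline-loading helpers above assume the standard Hugging Face hub cache layout under HF_HOME. A minimal standalone sketch of the same path resolution (the cache directory name is illustrative, taken from one of the generator subclasses below; the snapshot hash is read from the cache rather than hard-coded):

import os

hf_home = os.environ.get('HF_HOME', os.path.expanduser('~/.cache/huggingface'))
repo_cache_dir = 'models--lllyasviel--FramePackI2V_HY'  # example model_repo_id_for_cache

# Same lookup as _get_snapshot_hash_from_refs: refs/main stores the snapshot hash
refs_main = os.path.join(hf_home, 'hub', repo_cache_dir, 'refs', 'main')
if os.path.exists(refs_main):
    with open(refs_main, 'r') as f:
        snapshot_hash = f.read().strip()
    # Same composition as _get_offline_load_path
    snapshot_dir = os.path.join(hf_home, 'hub', repo_cache_dir, 'snapshots', snapshot_hash)
    print('Offline load path:', snapshot_dir)
else:
    print('No cached snapshot found; a loader would fall back to the hub repo id.')
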
modules/generators/f1_generator.py ADDED
@@ -0,0 +1,235 @@
1
+ import torch
2
+ import os # for offline loading path
3
+ from diffusers_helper.models.hunyuan_video_packed import HunyuanVideoTransformer3DModelPacked
4
+ from diffusers_helper.memory import DynamicSwapInstaller
5
+ from .base_generator import BaseModelGenerator
6
+
7
+ class F1ModelGenerator(BaseModelGenerator):
8
+ """
9
+ Model generator for the F1 HunyuanVideo model.
10
+ """
11
+
12
+ def __init__(self, **kwargs):
13
+ """
14
+ Initialize the F1 model generator.
15
+ """
16
+ super().__init__(**kwargs)
17
+ self.model_name = "F1"
18
+ self.model_path = 'lllyasviel/FramePack_F1_I2V_HY_20250503'
19
+ self.model_repo_id_for_cache = "models--lllyasviel--FramePack_F1_I2V_HY_20250503"
20
+
21
+ def get_model_name(self):
22
+ """
23
+ Get the name of the model.
24
+ """
25
+ return self.model_name
26
+
27
+ def load_model(self):
28
+ """
29
+ Load the F1 transformer model.
30
+ If offline mode is True, attempts to load from a local snapshot.
31
+ """
32
+ print(f"Loading {self.model_name} Transformer...")
33
+
34
+ path_to_load = self.model_path # Initialize with the default path
35
+
36
+ if self.offline:
37
+ path_to_load = self._get_offline_load_path() # Calls the method in BaseModelGenerator
38
+
39
+ # Create the transformer model
40
+ self.transformer = HunyuanVideoTransformer3DModelPacked.from_pretrained(
41
+ path_to_load,
42
+ torch_dtype=torch.bfloat16
43
+ ).cpu()
44
+
45
+ # Configure the model
46
+ self.transformer.eval()
47
+ self.transformer.to(dtype=torch.bfloat16)
48
+ self.transformer.requires_grad_(False)
49
+
50
+ # Set up dynamic swap if not in high VRAM mode
51
+ if not self.high_vram:
52
+ DynamicSwapInstaller.install_model(self.transformer, device=self.gpu)
53
+ else:
54
+ # In high VRAM mode, move the entire model to GPU
55
+ self.transformer.to(device=self.gpu)
56
+
57
+ print(f"{self.model_name} Transformer Loaded from {path_to_load}.")
58
+ return self.transformer
59
+
60
+ def prepare_history_latents(self, height, width):
61
+ """
62
+ Prepare the history latents tensor for the F1 model.
63
+
64
+ Args:
65
+ height: The height of the image
66
+ width: The width of the image
67
+
68
+ Returns:
69
+ The initialized history latents tensor
70
+ """
71
+ return torch.zeros(
72
+ size=(1, 16, 16 + 2 + 1, height // 8, width // 8),
73
+ dtype=torch.float32
74
+ ).cpu()
75
+
76
+ def initialize_with_start_latent(self, history_latents, start_latent, is_real_image_latent):
77
+ """
78
+ Initialize the history latents with the start latent for the F1 model.
79
+
80
+ Args:
81
+ history_latents: The history latents
82
+ start_latent: The start latent
83
+ is_real_image_latent: Whether the start latent came from a real input image or is synthetic noise
84
+
85
+ Returns:
86
+ The initialized history latents
87
+ """
88
+ # Add the start frame to history_latents
89
+ if is_real_image_latent:
90
+ return torch.cat([history_latents, start_latent.to(history_latents)], dim=2)
91
+ # After prepare_history_latents, history_latents (initialized with zeros)
92
+ # already has the required 19 entries for initial clean latents
93
+ return history_latents
94
+
95
+ def get_latent_paddings(self, total_latent_sections):
96
+ """
97
+ Get the latent paddings for the F1 model.
98
+
99
+ Args:
100
+ total_latent_sections: The total number of latent sections
101
+
102
+ Returns:
103
+ A list of latent paddings
104
+ """
105
+ # F1 model uses a fixed approach with just 0 for last section and 1 for others
106
+ return [1] * (total_latent_sections - 1) + [0]
107
+
108
+ def prepare_indices(self, latent_padding_size, latent_window_size):
109
+ """
110
+ Prepare the indices for the F1 model.
111
+
112
+ Args:
113
+ latent_padding_size: The size of the latent padding
114
+ latent_window_size: The size of the latent window
115
+
116
+ Returns:
117
+ A tuple of (clean_latent_indices, latent_indices, clean_latent_2x_indices, clean_latent_4x_indices)
118
+ """
119
+ # F1 model uses a different indices approach
120
+ # If latent_window_size is 4.5, use 5 as a special case
122
+ effective_window_size = 5 if latent_window_size == 4.5 else int(latent_window_size)
123
+ indices = torch.arange(0, sum([1, 16, 2, 1, effective_window_size])).unsqueeze(0)
124
+ clean_latent_indices_start, clean_latent_4x_indices, clean_latent_2x_indices, clean_latent_1x_indices, latent_indices = indices.split([1, 16, 2, 1, effective_window_size], dim=1)
124
+ clean_latent_indices = torch.cat([clean_latent_indices_start, clean_latent_1x_indices], dim=1)
125
+
126
+ return clean_latent_indices, latent_indices, clean_latent_2x_indices, clean_latent_4x_indices
127
+
128
+ def prepare_clean_latents(self, start_latent, history_latents):
129
+ """
130
+ Prepare the clean latents for the F1 model.
131
+
132
+ Args:
133
+ start_latent: The start latent
134
+ history_latents: The history latents
135
+
136
+ Returns:
137
+ A tuple of (clean_latents, clean_latents_2x, clean_latents_4x)
138
+ """
139
+ # For F1, we take the last frames for clean latents
140
+ clean_latents_4x, clean_latents_2x, clean_latents_1x = history_latents[:, :, -sum([16, 2, 1]):, :, :].split([16, 2, 1], dim=2)
141
+ # For F1, we prepend the start latent to clean_latents_1x
142
+ clean_latents = torch.cat([start_latent.to(history_latents), clean_latents_1x], dim=2)
143
+
144
+ return clean_latents, clean_latents_2x, clean_latents_4x
145
+
146
+ def update_history_latents(self, history_latents, generated_latents):
147
+ """
148
+ Update the history latents with the generated latents for the F1 model.
149
+
150
+ Args:
151
+ history_latents: The history latents
152
+ generated_latents: The generated latents
153
+
154
+ Returns:
155
+ The updated history latents
156
+ """
157
+ # For F1, we append new frames to the end
158
+ return torch.cat([history_latents, generated_latents.to(history_latents)], dim=2)
159
+
160
+ def get_real_history_latents(self, history_latents, total_generated_latent_frames):
161
+ """
162
+ Get the real history latents for the F1 model.
163
+
164
+ Args:
165
+ history_latents: The history latents
166
+ total_generated_latent_frames: The total number of generated latent frames
167
+
168
+ Returns:
169
+ The real history latents
170
+ """
171
+ # For F1, we take frames from the end
172
+ return history_latents[:, :, -total_generated_latent_frames:, :, :]
173
+
174
+ def update_history_pixels(self, history_pixels, current_pixels, overlapped_frames):
175
+ """
176
+ Update the history pixels with the current pixels for the F1 model.
177
+
178
+ Args:
179
+ history_pixels: The history pixels
180
+ current_pixels: The current pixels
181
+ overlapped_frames: The number of overlapped frames
182
+
183
+ Returns:
184
+ The updated history pixels
185
+ """
186
+ from diffusers_helper.utils import soft_append_bcthw
187
+ # For F1 model, history_pixels is first, current_pixels is second
188
+ return soft_append_bcthw(history_pixels, current_pixels, overlapped_frames)
189
+
190
+ def get_section_latent_frames(self, latent_window_size, is_last_section):
191
+ """
192
+ Get the number of section latent frames for the F1 model.
193
+
194
+ Args:
195
+ latent_window_size: The size of the latent window
196
+ is_last_section: Whether this is the last section
197
+
198
+ Returns:
199
+ The number of section latent frames
200
+ """
201
+ return latent_window_size * 2
202
+
203
+ def get_current_pixels(self, real_history_latents, section_latent_frames, vae):
204
+ """
205
+ Get the current pixels for the F1 model.
206
+
207
+ Args:
208
+ real_history_latents: The real history latents
209
+ section_latent_frames: The number of section latent frames
210
+ vae: The VAE model
211
+
212
+ Returns:
213
+ The current pixels
214
+ """
215
+ from diffusers_helper.hunyuan import vae_decode
216
+ # For F1, we take frames from the end
217
+ return vae_decode(real_history_latents[:, :, -section_latent_frames:], vae).cpu()
218
+
219
+ def format_position_description(self, total_generated_latent_frames, current_pos, original_pos, current_prompt):
220
+ """
221
+ Format the position description for the F1 model.
222
+
223
+ Args:
224
+ total_generated_latent_frames: The total number of generated latent frames
225
+ current_pos: The current position in seconds
226
+ original_pos: The original position in seconds
227
+ current_prompt: The current prompt
228
+
229
+ Returns:
230
+ The formatted position description
231
+ """
232
+ return (f'Total generated frames: {int(max(0, total_generated_latent_frames * 4 - 3))}, '
233
+ f'Video length: {max(0, (total_generated_latent_frames * 4 - 3) / 30):.2f} seconds (FPS-30). '
234
+ f'Current position: {current_pos:.2f}s. '
235
+ f'using prompt: {current_prompt[:256]}...')
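
As a quick illustration of the index layout produced by prepare_indices above (the window size is an arbitrary example value, not a project default):

import torch

latent_window_size = 9  # arbitrary example
# Mirrors F1ModelGenerator.prepare_indices: start frame, 16 4x-context, 2 2x-context, 1 clean frame, then the window
indices = torch.arange(0, sum([1, 16, 2, 1, latent_window_size])).unsqueeze(0)
start_idx, idx_4x, idx_2x, idx_1x, latent_idx = indices.split(
    [1, 16, 2, 1, latent_window_size], dim=1)
clean_latent_indices = torch.cat([start_idx, idx_1x], dim=1)

print(clean_latent_indices)  # tensor([[ 0, 19]])
print(latent_idx)            # tensor([[20, 21, 22, 23, 24, 25, 26, 27, 28]])
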
modules/generators/original_generator.py ADDED
@@ -0,0 +1,213 @@
1
+ import torch
2
+ import os # for offline loading path
3
+ from diffusers_helper.models.hunyuan_video_packed import HunyuanVideoTransformer3DModelPacked
4
+ from diffusers_helper.memory import DynamicSwapInstaller
5
+ from .base_generator import BaseModelGenerator
6
+
7
+ class OriginalModelGenerator(BaseModelGenerator):
8
+ """
9
+ Model generator for the Original HunyuanVideo model.
10
+ """
11
+
12
+ def __init__(self, **kwargs):
13
+ """
14
+ Initialize the Original model generator.
15
+ """
16
+ super().__init__(**kwargs)
17
+ self.model_name = "Original"
18
+ self.model_path = 'lllyasviel/FramePackI2V_HY'
19
+ self.model_repo_id_for_cache = "models--lllyasviel--FramePackI2V_HY"
20
+
21
+ def get_model_name(self):
22
+ """
23
+ Get the name of the model.
24
+ """
25
+ return self.model_name
26
+
27
+ def load_model(self):
28
+ """
29
+ Load the Original transformer model.
30
+ If offline mode is True, attempts to load from a local snapshot.
31
+ """
32
+ print(f"Loading {self.model_name} Transformer...")
33
+
34
+ path_to_load = self.model_path # Initialize with the default path
35
+
36
+ if self.offline:
37
+ path_to_load = self._get_offline_load_path() # Calls the method in BaseModelGenerator
38
+
39
+ # Create the transformer model
40
+ self.transformer = HunyuanVideoTransformer3DModelPacked.from_pretrained(
41
+ path_to_load,
42
+ torch_dtype=torch.bfloat16
43
+ ).cpu()
44
+
45
+ # Configure the model
46
+ self.transformer.eval()
47
+ self.transformer.to(dtype=torch.bfloat16)
48
+ self.transformer.requires_grad_(False)
49
+
50
+ # Set up dynamic swap if not in high VRAM mode
51
+ if not self.high_vram:
52
+ DynamicSwapInstaller.install_model(self.transformer, device=self.gpu)
53
+ else:
54
+ # In high VRAM mode, move the entire model to GPU
55
+ self.transformer.to(device=self.gpu)
56
+
57
+ print(f"{self.model_name} Transformer Loaded from {path_to_load}.")
58
+ return self.transformer
59
+
60
+ def prepare_history_latents(self, height, width):
61
+ """
62
+ Prepare the history latents tensor for the Original model.
63
+
64
+ Args:
65
+ height: The height of the image
66
+ width: The width of the image
67
+
68
+ Returns:
69
+ The initialized history latents tensor
70
+ """
71
+ return torch.zeros(
72
+ size=(1, 16, 1 + 2 + 16, height // 8, width // 8),
73
+ dtype=torch.float32
74
+ ).cpu()
75
+
76
+ def get_latent_paddings(self, total_latent_sections):
77
+ """
78
+ Get the latent paddings for the Original model.
79
+
80
+ Args:
81
+ total_latent_sections: The total number of latent sections
82
+
83
+ Returns:
84
+ A list of latent paddings
85
+ """
86
+ # Original model uses reversed latent paddings
87
+ if total_latent_sections > 4:
88
+ return [3] + [2] * (total_latent_sections - 3) + [1, 0]
89
+ else:
90
+ return list(reversed(range(total_latent_sections)))
91
+
92
+ def prepare_indices(self, latent_padding_size, latent_window_size):
93
+ """
94
+ Prepare the indices for the Original model.
95
+
96
+ Args:
97
+ latent_padding_size: The size of the latent padding
98
+ latent_window_size: The size of the latent window
99
+
100
+ Returns:
101
+ A tuple of (clean_latent_indices, latent_indices, clean_latent_2x_indices, clean_latent_4x_indices)
102
+ """
103
+ indices = torch.arange(0, sum([1, latent_padding_size, latent_window_size, 1, 2, 16])).unsqueeze(0)
104
+ clean_latent_indices_pre, blank_indices, latent_indices, clean_latent_indices_post, clean_latent_2x_indices, clean_latent_4x_indices = indices.split([1, latent_padding_size, latent_window_size, 1, 2, 16], dim=1)
105
+ clean_latent_indices = torch.cat([clean_latent_indices_pre, clean_latent_indices_post], dim=1)
106
+
107
+ return clean_latent_indices, latent_indices, clean_latent_2x_indices, clean_latent_4x_indices
108
+
109
+ def prepare_clean_latents(self, start_latent, history_latents):
110
+ """
111
+ Prepare the clean latents for the Original model.
112
+
113
+ Args:
114
+ start_latent: The start latent
115
+ history_latents: The history latents
116
+
117
+ Returns:
118
+ A tuple of (clean_latents, clean_latents_2x, clean_latents_4x)
119
+ """
120
+ clean_latents_pre = start_latent.to(history_latents)
121
+ clean_latents_post, clean_latents_2x, clean_latents_4x = history_latents[:, :, :1 + 2 + 16, :, :].split([1, 2, 16], dim=2)
122
+ clean_latents = torch.cat([clean_latents_pre, clean_latents_post], dim=2)
123
+
124
+ return clean_latents, clean_latents_2x, clean_latents_4x
125
+
126
+ def update_history_latents(self, history_latents, generated_latents):
127
+ """
128
+ Update the history latents with the generated latents for the Original model.
129
+
130
+ Args:
131
+ history_latents: The history latents
132
+ generated_latents: The generated latents
133
+
134
+ Returns:
135
+ The updated history latents
136
+ """
137
+ # For Original model, we prepend the generated latents
138
+ return torch.cat([generated_latents.to(history_latents), history_latents], dim=2)
139
+
140
+ def get_real_history_latents(self, history_latents, total_generated_latent_frames):
141
+ """
142
+ Get the real history latents for the Original model.
143
+
144
+ Args:
145
+ history_latents: The history latents
146
+ total_generated_latent_frames: The total number of generated latent frames
147
+
148
+ Returns:
149
+ The real history latents
150
+ """
151
+ return history_latents[:, :, :total_generated_latent_frames, :, :]
152
+
153
+ def update_history_pixels(self, history_pixels, current_pixels, overlapped_frames):
154
+ """
155
+ Update the history pixels with the current pixels for the Original model.
156
+
157
+ Args:
158
+ history_pixels: The history pixels
159
+ current_pixels: The current pixels
160
+ overlapped_frames: The number of overlapped frames
161
+
162
+ Returns:
163
+ The updated history pixels
164
+ """
165
+ from diffusers_helper.utils import soft_append_bcthw
166
+ # For Original model, current_pixels is first, history_pixels is second
167
+ return soft_append_bcthw(current_pixels, history_pixels, overlapped_frames)
168
+
169
+ def get_section_latent_frames(self, latent_window_size, is_last_section):
170
+ """
171
+ Get the number of section latent frames for the Original model.
172
+
173
+ Args:
174
+ latent_window_size: The size of the latent window
175
+ is_last_section: Whether this is the last section
176
+
177
+ Returns:
178
+ The number of section latent frames
179
+ """
180
+ return latent_window_size * 2
181
+
182
+ def get_current_pixels(self, real_history_latents, section_latent_frames, vae):
183
+ """
184
+ Get the current pixels for the Original model.
185
+
186
+ Args:
187
+ real_history_latents: The real history latents
188
+ section_latent_frames: The number of section latent frames
189
+ vae: The VAE model
190
+
191
+ Returns:
192
+ The current pixels
193
+ """
194
+ from diffusers_helper.hunyuan import vae_decode
195
+ return vae_decode(real_history_latents[:, :, :section_latent_frames], vae).cpu()
196
+
197
+ def format_position_description(self, total_generated_latent_frames, current_pos, original_pos, current_prompt):
198
+ """
199
+ Format the position description for the Original model.
200
+
201
+ Args:
202
+ total_generated_latent_frames: The total number of generated latent frames
203
+ current_pos: The current position in seconds
204
+ original_pos: The original position in seconds
205
+ current_prompt: The current prompt
206
+
207
+ Returns:
208
+ The formatted position description
209
+ """
210
+ return (f'Total generated frames: {int(max(0, total_generated_latent_frames * 4 - 3))}, '
211
+ f'Video length: {max(0, (total_generated_latent_frames * 4 - 3) / 30):.2f} seconds (FPS-30). '
212
+ f'Current position: {current_pos:.2f}s (original: {original_pos:.2f}s). '
213
+ f'using prompt: {current_prompt[:256]}...')
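
To make the padding schedule above concrete, the same branch logic as get_latent_paddings can be evaluated standalone for a few section counts:

def original_latent_paddings(total_latent_sections):
    # Same rule as OriginalModelGenerator.get_latent_paddings
    if total_latent_sections > 4:
        return [3] + [2] * (total_latent_sections - 3) + [1, 0]
    return list(reversed(range(total_latent_sections)))

for n in (2, 4, 6):
    print(n, original_latent_paddings(n))
# 2 [1, 0]
# 4 [3, 2, 1, 0]
# 6 [3, 2, 2, 2, 1, 0]
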
modules/generators/original_with_endframe_generator.py ADDED
@@ -0,0 +1,15 @@
1
+ from .original_generator import OriginalModelGenerator
2
+
3
+ class OriginalWithEndframeModelGenerator(OriginalModelGenerator):
4
+ """
5
+ Model generator for the Original HunyuanVideo model with end frame support.
6
+ This extends the Original model with the ability to guide generation toward a specified end frame.
7
+ """
8
+
9
+ def __init__(self, **kwargs):
10
+ """
11
+ Initialize the Original with Endframe model generator.
12
+ """
13
+ super().__init__(**kwargs)
14
+ self.model_name = "Original with Endframe"
15
+ # Inherits everything else from OriginalModelGenerator
modules/generators/video_base_generator.py ADDED
@@ -0,0 +1,613 @@
1
+ import torch
2
+ import os
3
+ import numpy as np
4
+ import math
5
+ import decord
6
+ from tqdm import tqdm
7
+ import pathlib
8
+ from PIL import Image
9
+
10
+ from diffusers_helper.models.hunyuan_video_packed import HunyuanVideoTransformer3DModelPacked
11
+ from diffusers_helper.memory import DynamicSwapInstaller
12
+ from diffusers_helper.utils import resize_and_center_crop
13
+ from diffusers_helper.bucket_tools import find_nearest_bucket
14
+ from diffusers_helper.hunyuan import vae_encode, vae_decode
15
+ from .base_generator import BaseModelGenerator
16
+
17
+ class VideoBaseModelGenerator(BaseModelGenerator):
18
+ """
19
+ Model generator for the Video extension of the Original HunyuanVideo model.
20
+ This generator accepts video input instead of a single image.
21
+ """
22
+
23
+ def __init__(self, **kwargs):
24
+ """
25
+ Initialize the Video model generator.
26
+ """
27
+ super().__init__(**kwargs)
28
+ self.model_name = None # Subclass Model Specific
29
+ self.model_path = None # Subclass Model Specific
30
+ self.model_repo_id_for_cache = None # Subclass Model Specific
31
+ self.full_video_latents = None # For context, set by worker() when available
32
+ self.resolution = 640 # Default resolution
33
+ self.no_resize = False # Default to resize
34
+ self.vae_batch_size = 16 # Default VAE batch size
35
+
36
+ # Import decord and tqdm here to avoid import errors if not installed
37
+ try:
38
+ import decord
39
+ from tqdm import tqdm
40
+ self.decord = decord
41
+ self.tqdm = tqdm
42
+ except ImportError:
43
+ print("Warning: decord or tqdm not installed. Video processing will not work.")
44
+ self.decord = None
45
+ self.tqdm = None
46
+
47
+ def get_model_name(self):
48
+ """
49
+ Get the name of the model.
50
+ """
51
+ return self.model_name
52
+
53
+ def load_model(self):
54
+ """
55
+ Load the Video transformer model.
56
+ If offline mode is True, attempts to load from a local snapshot.
57
+ """
58
+ print(f"Loading {self.model_name} Transformer...")
59
+
60
+ path_to_load = self.model_path # Initialize with the default path
61
+
62
+ if self.offline:
63
+ path_to_load = self._get_offline_load_path() # Calls the method in BaseModelGenerator
64
+
65
+ # Create the transformer model
66
+ self.transformer = HunyuanVideoTransformer3DModelPacked.from_pretrained(
67
+ path_to_load,
68
+ torch_dtype=torch.bfloat16
69
+ ).cpu()
70
+
71
+ # Configure the model
72
+ self.transformer.eval()
73
+ self.transformer.to(dtype=torch.bfloat16)
74
+ self.transformer.requires_grad_(False)
75
+
76
+ # Set up dynamic swap if not in high VRAM mode
77
+ if not self.high_vram:
78
+ DynamicSwapInstaller.install_model(self.transformer, device=self.gpu)
79
+ else:
80
+ # In high VRAM mode, move the entire model to GPU
81
+ self.transformer.to(device=self.gpu)
82
+
83
+ print(f"{self.model_name} Transformer Loaded from {path_to_load}.")
84
+ return self.transformer
85
+
86
+ def min_real_frames_to_encode(self, real_frames_available_count):
87
+ """
88
+ Minimum number of real frames to encode
89
+ is the maximum number of real frames used for generation context.
90
+
91
+ The number of latents could be calculated as below for video F1, but keeping it simple for now
92
+ by hardcoding the Video F1 value at max_latents_used_for_context = 27.
93
+
94
+ # Calculate the number of latent frames to encode from the end of the input video
95
+ num_frames_to_encode_from_end = 1 # Default minimum
96
+ if model_type == "Video":
97
+ # Max needed is 1 (clean_latent_pre) + 2 (max 2x) + 16 (max 4x) = 19
98
+ num_frames_to_encode_from_end = 19
99
+ elif model_type == "Video F1":
100
+ ui_num_cleaned_frames = job_params.get('num_cleaned_frames', 5)
101
+ # Max effective_clean_frames based on VideoF1ModelGenerator's logic.
102
+ # Max num_clean_frames from UI is 10 (modules/interface.py).
103
+ # Max effective_clean_frames = 10 - 1 = 9.
104
+ # total_context_frames = num_4x_frames + num_2x_frames + effective_clean_frames
105
+ # Max needed = 16 (max 4x) + 2 (max 2x) + 9 (max effective_clean_frames) = 27
106
+ num_frames_to_encode_from_end = 27
107
+
108
+ Note: 27 latents ~ 108 real frames = 3.6 seconds at 30 FPS.
109
+ Note: 19 latents ~ 76 real frames ~ 2.5 seconds at 30 FPS.
110
+ """
111
+
112
+ max_latents_used_for_context = 27
113
+ if self.get_model_name() == "Video":
114
+ max_latents_used_for_context = 27 # Weird results on 19
115
+ elif self.get_model_name() == "Video F1":
116
+ max_latents_used_for_context = 27 # Enough for even Video F1 with cleaned_frames input of 10
117
+ else:
118
+ print("======================================================")
119
+ print(f" ***** Warning: Unsupported video extension model type: {self.get_model_name()}.")
120
+ print(f" ***** Using default max latents {max_latents_used_for_context} for context.")
121
+ print( " ***** Please report to the developers if you see this message:")
122
+ print( " ***** Discord: https://discord.gg/8Z2c3a4 or GitHub: https://github.com/colinurbs/FramePack-Studio")
123
+ print("======================================================")
124
+ # Probably better to press on with Video F1 max vs exception?
125
+ # raise ValueError(f"Unsupported video extension model type: {self.get_model_name()}")
126
+
127
+ latent_size_factor = 4 # real frames to latent frames conversion factor
128
+ max_real_frames_used_for_context = max_latents_used_for_context * latent_size_factor
129
+
130
+ # Shortest of available frames and max frames used for context
131
+ trimmed_real_frames_count = min(real_frames_available_count, max_real_frames_used_for_context)
132
+ if trimmed_real_frames_count < real_frames_available_count:
133
+ print(f"Truncating video frames from {real_frames_available_count} to {trimmed_real_frames_count}, enough to populate context")
134
+
135
+ # Truncate to nearest latent size (multiple of 4)
136
+ frames_to_encode_count = (trimmed_real_frames_count // latent_size_factor) * latent_size_factor
137
+ if frames_to_encode_count != trimmed_real_frames_count:
138
+ print(f"Truncating video frames from {trimmed_real_frames_count} to {frames_to_encode_count}, for latent size compatibility")
139
+
140
+ return frames_to_encode_count
141
+
142
+ def extract_video_frames(self, is_for_encode, video_path, resolution, no_resize=False, input_files_dir=None):
143
+ """
144
+ Extract real frames from a video, resized and center cropped as numpy array (T, H, W, C).
145
+
146
+ Args:
147
+ is_for_encode: If True, results are capped at maximum frames used for context, and aligned to 4-frame latent requirement.
148
+ video_path: Path to the input video file.
149
+ resolution: Target resolution for resizing frames.
150
+ no_resize: Whether to use the original video resolution.
151
+ input_files_dir: Directory for input files that won't be cleaned up.
152
+
153
+ Returns:
154
+ A tuple containing:
155
+ - input_frames_resized_np: All input frames resized and center cropped as numpy array (T, H, W, C)
156
+ - fps: Frames per second of the input video
157
+ - target_height: Target height of the video
158
+ - target_width: Target width of the video
159
+ """
160
+ def time_millis():
161
+ import time
162
+ return time.perf_counter() * 1000.0 # Convert seconds to milliseconds
163
+
164
+ encode_start_time_millis = time_millis()
165
+
166
+ # Normalize video path for Windows compatibility
167
+ video_path = str(pathlib.Path(video_path).resolve())
168
+ print(f"Processing video: {video_path}")
169
+
170
+ # Check if the video is in the temp directory and if we have an input_files_dir
171
+ if input_files_dir and "temp" in video_path:
172
+ # Check if there's a copy of this video in the input_files_dir
173
+ filename = os.path.basename(video_path)
174
+ input_file_path = os.path.join(input_files_dir, filename)
175
+
176
+ # If the file exists in input_files_dir, use that instead
177
+ if os.path.exists(input_file_path):
178
+ print(f"Using video from input_files_dir: {input_file_path}")
179
+ video_path = input_file_path
180
+ else:
181
+ # If not, copy it to input_files_dir to prevent it from being deleted
182
+ try:
183
+ from diffusers_helper.utils import generate_timestamp
184
+ safe_filename = f"{generate_timestamp()}_{filename}"
185
+ input_file_path = os.path.join(input_files_dir, safe_filename)
186
+ import shutil
187
+ shutil.copy2(video_path, input_file_path)
188
+ print(f"Copied video to input_files_dir: {input_file_path}")
189
+ video_path = input_file_path
190
+ except Exception as e:
191
+ print(f"Error copying video to input_files_dir: {e}")
192
+
193
+ try:
194
+ # Load video and get FPS
195
+ print("Initializing VideoReader...")
196
+ vr = decord.VideoReader(video_path)
197
+ fps = vr.get_avg_fps() # Get input video FPS
198
+ num_real_frames = len(vr)
199
+ print(f"Video loaded: {num_real_frames} frames, FPS: {fps}")
200
+
201
+ # Read frames
202
+ print("Reading video frames...")
203
+
204
+ total_frames_in_video_file = len(vr)
205
+ if is_for_encode:
206
+ print(f"Using minimum real frames to encode: {self.min_real_frames_to_encode(total_frames_in_video_file)}")
207
+ num_real_frames = self.min_real_frames_to_encode(total_frames_in_video_file)
208
+ # else left as all frames -- len(vr) with no regard for trimming or latent alignment
209
+
210
+ # RT_BORG: Retaining this commented code for reference.
211
+ # pftq encoder discarded truncated frames from the end of the video.
212
+ # frames = vr.get_batch(range(num_real_frames)).asnumpy() # Shape: (num_real_frames, height, width, channels)
213
+
214
+ # RT_BORG: Retaining this commented code for reference.
215
+ # pftq retained the entire encoded video.
216
+ # Truncate to nearest latent size (multiple of 4)
217
+ # latent_size_factor = 4
218
+ # num_frames = (num_real_frames // latent_size_factor) * latent_size_factor
219
+ # if num_frames != num_real_frames:
220
+ # print(f"Truncating video from {num_real_frames} to {num_frames} frames for latent size compatibility")
221
+ # num_real_frames = num_frames
222
+
223
+ # Discard truncated frames from the beginning of the video, retaining the last num_real_frames
224
+ # This ensures a smooth transition from the input video to the generated video
225
+ start_frame_index = total_frames_in_video_file - num_real_frames
226
+ frame_indices_to_extract = range(start_frame_index, total_frames_in_video_file)
227
+ frames = vr.get_batch(frame_indices_to_extract).asnumpy() # Shape: (num_real_frames, height, width, channels)
228
+
229
+ print(f"Frames read: {frames.shape}")
230
+
231
+ # Get native video resolution
232
+ native_height, native_width = frames.shape[1], frames.shape[2]
233
+ print(f"Native video resolution: {native_width}x{native_height}")
234
+
235
+ # Use native resolution if height/width not specified, otherwise use provided values
236
+ target_height = native_height
237
+ target_width = native_width
238
+
239
+ # Adjust to nearest bucket for model compatibility
240
+ if not no_resize:
241
+ target_height, target_width = find_nearest_bucket(target_height, target_width, resolution=resolution)
242
+ print(f"Adjusted resolution: {target_width}x{target_height}")
243
+ else:
244
+ print(f"Using native resolution without resizing: {target_width}x{target_height}")
245
+
246
+ # Preprocess input frames to match desired resolution
247
+ input_frames_resized_np = []
248
+ for i, frame in tqdm(enumerate(frames), desc="Processing Video Frames", total=num_real_frames, mininterval=0.1):
249
+ frame_np = resize_and_center_crop(frame, target_width=target_width, target_height=target_height)
250
+ input_frames_resized_np.append(frame_np)
251
+ input_frames_resized_np = np.stack(input_frames_resized_np) # Shape: (num_real_frames, height, width, channels)
252
+ print(f"Frames preprocessed: {input_frames_resized_np.shape}")
253
+
254
+ resized_frames_time_millis = time_millis()
255
+ if (False): # We really need a logger
256
+ print("======================================================")
257
+ memory_bytes = input_frames_resized_np.nbytes
258
+ memory_kb = memory_bytes / 1024
259
+ memory_mb = memory_kb / 1024
260
+ print(f" ***** input_frames_resized_np: {input_frames_resized_np.shape}")
261
+ print(f" ***** Memory usage: {int(memory_mb)} MB")
262
+ duration_ms = resized_frames_time_millis - encode_start_time_millis
263
+ print(f" ***** Time taken to process frames tensor: {duration_ms / 1000.0:.2f} seconds")
264
+ print("======================================================")
265
+
266
+ return input_frames_resized_np, fps, target_height, target_width
267
+ except Exception as e:
268
+ print(f"Error in extract_video_frames: {str(e)}")
269
+ raise
270
+
271
+ # RT_BORG: video_encode produce and return end_of_input_video_latent and end_of_input_video_image_np
272
+ # which are not needed for Video models without end frame processing.
273
+ # But these should be inexpensive and it's easier to keep the code uniform.
274
+ @torch.no_grad()
275
+ def video_encode(self, video_path, resolution, no_resize=False, vae_batch_size=16, device=None, input_files_dir=None):
276
+ """
277
+ Encode a video into latent representations using the VAE.
278
+
279
+ Args:
280
+ video_path: Path to the input video file.
281
+ resolution: Target resolution for resizing frames.
282
+ no_resize: Whether to use the original video resolution.
283
+ vae_batch_size: Number of frames to process per batch.
284
+ device: Device for computation (e.g., "cuda").
285
+ input_files_dir: Directory for input files that won't be cleaned up.
286
+
287
+ Returns:
288
+ A tuple containing:
289
+ - start_latent: Latent of the first frame
290
+ - input_image_np: First frame as numpy array
291
+ - history_latents: Latents of all frames
292
+ - fps: Frames per second of the input video
293
+ - target_height: Target height of the video
294
+ - target_width: Target width of the video
295
+ - input_video_pixels: Video frames as tensor
296
+ - end_of_input_video_image_np: Last frame as numpy array
297
+ - input_frames_resized_np: All input frames resized and center cropped as numpy array (T, H, W, C)
298
+ """
299
+ encoding = True # Flag to indicate this is for encoding
300
+ input_frames_resized_np, fps, target_height, target_width = self.extract_video_frames(encoding, video_path, resolution, no_resize, input_files_dir)
301
+
302
+ try:
303
+ if device is None:
304
+ device = self.gpu
305
+
306
+ # Check CUDA availability and fallback to CPU if needed
307
+ if device == "cuda" and not torch.cuda.is_available():
308
+ print("CUDA is not available, falling back to CPU")
309
+ device = "cpu"
310
+
311
+ # Save first frame for CLIP vision encoding
312
+ input_image_np = input_frames_resized_np[0]
313
+ end_of_input_video_image_np = input_frames_resized_np[-1]
314
+
315
+ # Convert to tensor and normalize to [-1, 1]
316
+ print("Converting frames to tensor...")
317
+ frames_pt = torch.from_numpy(input_frames_resized_np).float() / 127.5 - 1
318
+ frames_pt = frames_pt.permute(0, 3, 1, 2) # Shape: (num_real_frames, channels, height, width)
319
+ frames_pt = frames_pt.unsqueeze(0) # Shape: (1, num_real_frames, channels, height, width)
320
+ frames_pt = frames_pt.permute(0, 2, 1, 3, 4) # Shape: (1, channels, num_real_frames, height, width)
321
+ print(f"Tensor shape: {frames_pt.shape}")
322
+
323
+ # Save pixel frames for use in worker
324
+ input_video_pixels = frames_pt.cpu()
325
+
326
+ # Move to device
327
+ print(f"Moving tensor to device: {device}")
328
+ frames_pt = frames_pt.to(device)
329
+ print("Tensor moved to device")
330
+
331
+ # Move VAE to device
332
+ print(f"Moving VAE to device: {device}")
333
+ self.vae.to(device)
334
+ print("VAE moved to device")
335
+
336
+ # Encode frames in batches
337
+ print(f"Encoding input video frames in VAE batch size {vae_batch_size}")
338
+ latents = []
339
+ self.vae.eval()
340
+ with torch.no_grad():
341
+ frame_count = frames_pt.shape[2]
342
+ step_count = math.ceil(frame_count / vae_batch_size)
343
+ for i in tqdm(range(0, frame_count, vae_batch_size), desc="Encoding video frames", total=step_count, mininterval=0.1):
344
+ batch = frames_pt[:, :, i:i + vae_batch_size] # Shape: (1, channels, batch_size, height, width)
345
+ try:
346
+ # Log GPU memory before encoding
347
+ if device == "cuda":
348
+ free_mem = torch.cuda.memory_allocated() / 1024**3
349
+ batch_latent = vae_encode(batch, self.vae)
350
+ # Synchronize CUDA to catch issues
351
+ if device == "cuda":
352
+ torch.cuda.synchronize()
353
+ latents.append(batch_latent)
354
+ except RuntimeError as e:
355
+ print(f"Error during VAE encoding: {str(e)}")
356
+ if device == "cuda" and "out of memory" in str(e).lower():
357
+ print("CUDA out of memory, try reducing vae_batch_size or using CPU")
358
+ raise
359
+
360
+ # Concatenate latents
361
+ print("Concatenating latents...")
362
+ history_latents = torch.cat(latents, dim=2) # Shape: (1, channels, frames, height//8, width//8)
363
+ print(f"History latents shape: {history_latents.shape}")
364
+
365
+ # Get first frame's latent
366
+ start_latent = history_latents[:, :, :1] # Shape: (1, channels, 1, height//8, width//8)
367
+ print(f"Start latent shape: {start_latent.shape}")
368
+
369
+ if (False): # We really need a logger
370
+ print("======================================================")
371
+ memory_bytes = history_latents.nbytes
372
+ memory_kb = memory_bytes / 1024
373
+ memory_mb = memory_kb / 1024
374
+ print(f" ***** history_latents: {history_latents.shape}")
375
+ print(f" ***** Memory usage: {int(memory_mb)} MB")
376
+ print("======================================================")
377
+
378
+ # Move VAE back to CPU to free GPU memory
379
+ if device == "cuda":
380
+ self.vae.to(self.cpu)
381
+ torch.cuda.empty_cache()
382
+ print("VAE moved back to CPU, CUDA cache cleared")
383
+
384
+ return start_latent, input_image_np, history_latents, fps, target_height, target_width, input_video_pixels, end_of_input_video_image_np, input_frames_resized_np
385
+
386
+ except Exception as e:
387
+ print(f"Error in video_encode: {str(e)}")
388
+ raise
389
+
390
+ # RT_BORG: Currently history_latents is initialized within worker() for all Video models as history_latents = video_latents
391
+ # So it is a coding error to call prepare_history_latents() here.
392
+ # Leaving in place as we will likely use it post-refactoring.
393
+ def prepare_history_latents(self, height, width):
394
+ """
395
+ Prepare the history latents tensor for the Video model.
396
+
397
+ Args:
398
+ height: The height of the image
399
+ width: The width of the image
400
+
401
+ Returns:
402
+ The initialized history latents tensor
403
+ """
404
+ raise TypeError(
405
+ f"Error: '{self.__class__.__name__}.prepare_history_latents' should not be called "
406
+ "on the Video models. history_latents should be initialized within worker() for all Video models "
407
+ "as history_latents = video_latents."
408
+ )
409
+
410
+ def prepare_indices(self, latent_padding_size, latent_window_size):
411
+ """
412
+ Prepare the indices for the Video model.
413
+
414
+ Args:
415
+ latent_padding_size: The size of the latent padding
416
+ latent_window_size: The size of the latent window
417
+
418
+ Returns:
419
+ A tuple of (clean_latent_indices, latent_indices, clean_latent_2x_indices, clean_latent_4x_indices)
420
+ """
421
+ raise TypeError(
422
+ f"Error: '{self.__class__.__name__}.prepare_indices' should not be called "
423
+ "on the Video models. Currently video models each have a combined method: <model>_prepare_clean_latents_and_indices."
424
+ )
425
+
426
+ def set_full_video_latents(self, video_latents):
427
+ """
428
+ Set the full video latents for context.
429
+
430
+ Args:
431
+ video_latents: The full video latents
432
+ """
433
+ self.full_video_latents = video_latents
434
+
435
+ def prepare_clean_latents(self, start_latent, history_latents):
436
+ """
437
+ Prepare the clean latents for the Video model.
438
+
439
+ Args:
440
+ start_latent: The start latent
441
+ history_latents: The history latents
442
+
443
+ Returns:
444
+ A tuple of (clean_latents, clean_latents_2x, clean_latents_4x)
445
+ """
446
+ raise TypeError(
447
+ f"Error: '{self.__class__.__name__}.prepare_clean_latents' should not be called "
448
+ "on the Video models. Currently video models each have a combined method: <model>_prepare_clean_latents_and_indices."
449
+ )
450
+
451
+ def get_section_latent_frames(self, latent_window_size, is_last_section):
452
+ """
453
+ Get the number of section latent frames for the Video model.
454
+
455
+ Args:
456
+ latent_window_size: The size of the latent window
457
+ is_last_section: Whether this is the last section
458
+
459
+ Returns:
460
+ The number of section latent frames
461
+ """
462
+ return latent_window_size * 2
463
+
464
+ def combine_videos(self, source_video_path, generated_video_path, output_path):
465
+ """
466
+ Combine the source video with the generated video side by side.
467
+
468
+ Args:
469
+ source_video_path: Path to the source video
470
+ generated_video_path: Path to the generated video
471
+ output_path: Path to save the combined video
472
+
473
+ Returns:
474
+ Path to the combined video
475
+ """
476
+ try:
477
+ import os
478
+ import subprocess
479
+
480
+ print(f"Combining source video {source_video_path} with generated video {generated_video_path}")
481
+
482
+ # Get the ffmpeg executable from the VideoProcessor class
483
+ from modules.toolbox.toolbox_processor import VideoProcessor
484
+ from modules.toolbox.message_manager import MessageManager
485
+
486
+ # Create a message manager for logging
487
+ message_manager = MessageManager()
488
+
489
+ # Import settings from main module
490
+ try:
491
+ from __main__ import settings
492
+ video_processor = VideoProcessor(message_manager, settings.settings)
493
+ except ImportError:
494
+ # Fallback to creating a new settings object
495
+ from modules.settings import Settings
496
+ settings = Settings()
497
+ video_processor = VideoProcessor(message_manager, settings.settings)
498
+
499
+ # Get the ffmpeg executable
500
+ ffmpeg_exe = video_processor.ffmpeg_exe
501
+
502
+ if not ffmpeg_exe:
503
+ print("FFmpeg executable not found. Cannot combine videos.")
504
+ return None
505
+
506
+ print(f"Using ffmpeg at: {ffmpeg_exe}")
507
+
508
+ # Create a temporary directory for the filter script
509
+ import tempfile
510
+ temp_dir = tempfile.gettempdir()
511
+ filter_script_path = os.path.join(temp_dir, f"filter_script_{os.path.basename(output_path)}.txt")
512
+
513
+ # Get video dimensions using ffprobe
514
+ def get_video_info(video_path):
515
+ cmd = [
516
+ ffmpeg_exe, "-i", video_path,
517
+ "-hide_banner", "-loglevel", "error"
518
+ ]
519
+
520
+ # Run ffmpeg to get video info (it will fail but output info to stderr)
521
+ result = subprocess.run(cmd, capture_output=True, text=True)
522
+
523
+ # Parse the output to get dimensions
524
+ width = height = None
525
+ for line in result.stderr.split('\n'):
526
+ if 'Video:' in line:
527
+ # Look for dimensions like 640x480
528
+ import re
529
+ match = re.search(r'(\d+)x(\d+)', line)
530
+ if match:
531
+ width = int(match.group(1))
532
+ height = int(match.group(2))
533
+ break
534
+
535
+ return width, height
536
+
537
+ # Get dimensions of both videos
538
+ source_width, source_height = get_video_info(source_video_path)
539
+ generated_width, generated_height = get_video_info(generated_video_path)
540
+
541
+ if not source_width or not generated_width:
542
+ print("Error: Could not determine video dimensions")
543
+ return None
544
+
545
+ print(f"Source video: {source_width}x{source_height}")
546
+ print(f"Generated video: {generated_width}x{generated_height}")
547
+
548
+ # Calculate target dimensions (maintain aspect ratio)
549
+ target_height = max(source_height, generated_height)
550
+ source_target_width = int(source_width * (target_height / source_height))
551
+ generated_target_width = int(generated_width * (target_height / generated_height))
552
+
553
+ # Create a complex filter for side-by-side display with labels
554
+ filter_complex = (
555
+ f"[0:v]scale={source_target_width}:{target_height}[left];"
556
+ f"[1:v]scale={generated_target_width}:{target_height}[right];"
557
+ f"[left]drawtext=text='Source':x=({source_target_width}/2-50):y=20:fontsize=24:fontcolor=white:box=1:boxcolor=black@0.5[left_text];"
558
+ f"[right]drawtext=text='Generated':x=({generated_target_width}/2-70):y=20:fontsize=24:fontcolor=white:box=1:boxcolor=black@0.5[right_text];"
559
+ f"[left_text][right_text]hstack=inputs=2[v]"
560
+ )
561
+
562
+ # Write the filter script to a file
563
+ with open(filter_script_path, 'w') as f:
564
+ f.write(filter_complex)
565
+
566
+ # Build the ffmpeg command
567
+ cmd = [
568
+ ffmpeg_exe, "-y",
569
+ "-i", source_video_path,
570
+ "-i", generated_video_path,
571
+ "-filter_complex_script", filter_script_path,
572
+ "-map", "[v]"
573
+ ]
574
+
575
+ # Check if source video has audio
576
+ has_audio_cmd = [
577
+ ffmpeg_exe, "-i", source_video_path,
578
+ "-hide_banner", "-loglevel", "error"
579
+ ]
580
+ audio_check = subprocess.run(has_audio_cmd, capture_output=True, text=True)
581
+ has_audio = "Audio:" in audio_check.stderr
582
+
583
+ if has_audio:
584
+ cmd.extend(["-map", "0:a"])
585
+
586
+ # Add output options
587
+ cmd.extend([
588
+ "-c:v", "libx264",
589
+ "-crf", "18",
590
+ "-preset", "medium"
591
+ ])
592
+
593
+ if has_audio:
594
+ cmd.extend(["-c:a", "aac"])
595
+
596
+ cmd.append(output_path)
597
+
598
+ # Run the ffmpeg command
599
+ print(f"Running ffmpeg command: {' '.join(cmd)}")
600
+ subprocess.run(cmd, check=True, capture_output=True, text=True)
601
+
602
+ # Clean up the filter script
603
+ if os.path.exists(filter_script_path):
604
+ os.remove(filter_script_path)
605
+
606
+ print(f"Combined video saved to {output_path}")
607
+ return output_path
608
+
609
+ except Exception as e:
610
+ print(f"Error combining videos: {str(e)}")
611
+ import traceback
612
+ traceback.print_exc()
613
+ return None
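
The frame-count arithmetic in min_real_frames_to_encode can be sanity-checked in isolation; this sketch restates the same formulas with the hard-coded ceiling of 27 context latents:

def frames_to_encode(real_frames_available, max_latents_for_context=27):
    # Mirrors VideoBaseModelGenerator.min_real_frames_to_encode
    latent_size_factor = 4                                        # real frames per latent frame
    max_real = max_latents_for_context * latent_size_factor       # 27 * 4 = 108 frames
    trimmed = min(real_frames_available, max_real)
    return (trimmed // latent_size_factor) * latent_size_factor   # align down to a multiple of 4

print(frames_to_encode(45))   # 44 (aligned down)
print(frames_to_encode(300))  # 108 (capped at the 27-latent context budget)
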
modules/generators/video_f1_generator.py ADDED
@@ -0,0 +1,189 @@
1
+ import torch
2
+ import os
3
+ import numpy as np
4
+ import math
5
+ import decord
6
+ from tqdm import tqdm
7
+ import pathlib
8
+ from PIL import Image
9
+
10
+ from diffusers_helper.models.hunyuan_video_packed import HunyuanVideoTransformer3DModelPacked
11
+ from diffusers_helper.memory import DynamicSwapInstaller
12
+ from diffusers_helper.utils import resize_and_center_crop
13
+ from diffusers_helper.bucket_tools import find_nearest_bucket
14
+ from diffusers_helper.hunyuan import vae_encode, vae_decode
15
+ from .video_base_generator import VideoBaseModelGenerator
16
+
17
+ class VideoF1ModelGenerator(VideoBaseModelGenerator):
18
+ """
19
+ Model generator for the Video F1 (forward video) extension of the F1 HunyuanVideo model.
20
+ This generator accepts video input instead of a single image.
21
+ """
22
+
23
+ def __init__(self, **kwargs):
24
+ """
25
+ Initialize the Video F1 model generator.
26
+ """
27
+ super().__init__(**kwargs)
28
+ self.model_name = "Video F1"
29
+ self.model_path = 'lllyasviel/FramePack_F1_I2V_HY_20250503' # Same as F1
30
+ self.model_repo_id_for_cache = "models--lllyasviel--FramePack_F1_I2V_HY_20250503" # Same as F1
31
+
32
+ def get_latent_paddings(self, total_latent_sections):
33
+ """
34
+ Get the latent paddings for the Video model.
35
+
36
+ Args:
37
+ total_latent_sections: The total number of latent sections
38
+
39
+ Returns:
40
+ A list of latent paddings
41
+ """
42
+ # RT_BORG: pftq didn't even use latent paddings in the forward Video model. Keeping it for consistency.
43
+ # Any list the size of total_latent_sections should work, but may as well end with 0 as a marker for the last section.
44
+ # Similar to F1 model uses a fixed approach with just 0 for last section and 1 for others
45
+ return [1] * (total_latent_sections - 1) + [0]
46
+
47
+ def video_f1_prepare_clean_latents_and_indices(self, latent_window_size, video_latents, history_latents, num_cleaned_frames=5):
48
+ """
49
+ Combined method to prepare clean latents and indices for the Video model.
50
+
51
+ Args:
52
+ latent_window_size, video_latents, history_latents, num_cleaned_frames (see the signature). Work in progress - latent_paddings and latent_padding are intentionally not passed in.
53
+
54
+ Returns:
55
+ A tuple of (clean_latent_indices, latent_indices, clean_latent_2x_indices, clean_latent_4x_indices, clean_latents, clean_latents_2x, clean_latents_4x)
56
+ """
57
+ # Get num_cleaned_frames from job_params if available, otherwise use default value of 5
58
+ num_clean_frames = num_cleaned_frames if num_cleaned_frames is not None else 5
59
+
60
+ # RT_BORG: Retaining this commented code for reference.
61
+ # start_latent = history_latents[:, :, :1] # Shape: (1, channels, 1, height//8, width//8)
62
+ start_latent = video_latents[:, :, -1:] # Shape: (1, channels, 1, height//8, width//8)
63
+
64
+ available_frames = history_latents.shape[2] # Number of latent frames
65
+ max_pixel_frames = min(latent_window_size * 4 - 3, available_frames * 4) # Cap at available pixel frames
66
+ adjusted_latent_frames = max(1, (max_pixel_frames + 3) // 4) # Convert back to latent frames
67
+ # Adjust num_clean_frames to match original behavior: num_clean_frames=2 means 1 frame for clean_latents_1x
68
+ effective_clean_frames = max(0, num_clean_frames - 1) if num_clean_frames > 1 else 0
69
+ effective_clean_frames = min(effective_clean_frames, available_frames - 2) if available_frames > 2 else 0 # 20250507 pftq: changed 1 to 2 for edge case for <=1 sec videos
70
+ num_2x_frames = min(2, max(1, available_frames - effective_clean_frames - 1)) if available_frames > effective_clean_frames + 1 else 0 # 20250507 pftq: subtracted 1 for edge case for <=1 sec videos
71
+ num_4x_frames = min(16, max(1, available_frames - effective_clean_frames - num_2x_frames)) if available_frames > effective_clean_frames + num_2x_frames else 0 # 20250507 pftq: Edge case for <=1 sec
72
+
73
+ total_context_frames = num_4x_frames + num_2x_frames + effective_clean_frames
74
+ total_context_frames = min(total_context_frames, available_frames) # 20250507 pftq: Edge case for <=1 sec videos
75
+
76
+ indices = torch.arange(0, sum([1, num_4x_frames, num_2x_frames, effective_clean_frames, adjusted_latent_frames])).unsqueeze(0) # 20250507 pftq: latent_window_size to adjusted_latent_frames for edge case for <=1 sec videos
77
+ clean_latent_indices_start, clean_latent_4x_indices, clean_latent_2x_indices, clean_latent_1x_indices, latent_indices = indices.split(
78
+ [1, num_4x_frames, num_2x_frames, effective_clean_frames, adjusted_latent_frames], dim=1 # 20250507 pftq: latent_window_size to adjusted_latent_frames for edge case for <=1 sec videos
79
+ )
80
+ clean_latent_indices = torch.cat([clean_latent_indices_start, clean_latent_1x_indices], dim=1)
81
+
82
+ # 20250506 pftq: Split history_latents dynamically based on available frames
83
+ fallback_frame_count = 2 # 20250507 pftq: Changed 0 to 2 Edge case for <=1 sec videos
84
+ context_frames = history_latents[:, :, -total_context_frames:, :, :] if total_context_frames > 0 else history_latents[:, :, :fallback_frame_count, :, :]
85
+ if total_context_frames > 0:
86
+ split_sizes = [num_4x_frames, num_2x_frames, effective_clean_frames]
87
+ split_sizes = [s for s in split_sizes if s > 0] # Remove zero sizes
88
+ if split_sizes:
89
+ splits = context_frames.split(split_sizes, dim=2)
90
+ split_idx = 0
91
+ clean_latents_4x = splits[split_idx] if num_4x_frames > 0 else history_latents[:, :, :fallback_frame_count, :, :]
92
+ if clean_latents_4x.shape[2] < 2: # 20250507 pftq: edge case for <=1 sec videos
93
+ clean_latents_4x = torch.cat([clean_latents_4x, clean_latents_4x[:, :, -1:, :, :]], dim=2)[:, :, :2, :, :]
94
+ split_idx += 1 if num_4x_frames > 0 else 0
95
+ clean_latents_2x = splits[split_idx] if num_2x_frames > 0 and split_idx < len(splits) else history_latents[:, :, :fallback_frame_count, :, :]
96
+ if clean_latents_2x.shape[2] < 2: # 20250507 pftq: edge case for <=1 sec videos
97
+ clean_latents_2x = torch.cat([clean_latents_2x, clean_latents_2x[:, :, -1:, :, :]], dim=2)[:, :, :2, :, :]
98
+ split_idx += 1 if num_2x_frames > 0 else 0
99
+ clean_latents_1x = splits[split_idx] if effective_clean_frames > 0 and split_idx < len(splits) else history_latents[:, :, :fallback_frame_count, :, :]
100
+ else:
101
+ clean_latents_4x = clean_latents_2x = clean_latents_1x = history_latents[:, :, :fallback_frame_count, :, :]
102
+ else:
103
+ clean_latents_4x = clean_latents_2x = clean_latents_1x = history_latents[:, :, :fallback_frame_count, :, :]
104
+
105
+ clean_latents = torch.cat([start_latent.to(history_latents), clean_latents_1x], dim=2)
106
+
107
+ return clean_latent_indices, latent_indices, clean_latent_2x_indices, clean_latent_4x_indices, clean_latents, clean_latents_2x, clean_latents_4x
108
+
109
+ def update_history_latents(self, history_latents, generated_latents):
110
+ """
111
+ Forward Generation: Update the history latents with the generated latents for the Video F1 model.
112
+
113
+ Args:
114
+ history_latents: The history latents
115
+ generated_latents: The generated latents
116
+
117
+ Returns:
118
+ The updated history latents
119
+ """
120
+ # For Video F1 model, we append the generated latents to the back of history latents
121
+ # This matches the F1 implementation
122
+ # It generates new sections forward in time, chunk by chunk
123
+ return torch.cat([history_latents, generated_latents.to(history_latents)], dim=2)
124
+
125
+ def get_real_history_latents(self, history_latents, total_generated_latent_frames):
126
+ """
127
+ Get the real history latents for the forward Video F1 model. For Video F1, these are the last
128
+ `total_generated_latent_frames` frames of the history latents.
129
+
130
+ Args:
131
+ history_latents: The history latents
132
+ total_generated_latent_frames: The total number of generated latent frames
133
+
134
+ Returns:
135
+ The real history latents
136
+ """
137
+ # Generated frames at the back. Note the difference in "-total_generated_latent_frames:".
138
+ return history_latents[:, :, -total_generated_latent_frames:, :, :]
139
+
140
+ def update_history_pixels(self, history_pixels, current_pixels, overlapped_frames):
141
+ """
142
+ Update the history pixels with the current pixels for the Video F1 model.
143
+
144
+ Args:
145
+ history_pixels: The history pixels
146
+ current_pixels: The current pixels
147
+ overlapped_frames: The number of overlapped frames
148
+
149
+ Returns:
150
+ The updated history pixels
151
+ """
152
+ from diffusers_helper.utils import soft_append_bcthw
153
+ # For Video F1 model, we append the current pixels to the history pixels
154
+ # This matches the F1 model, history_pixels is first, current_pixels is second
155
+ return soft_append_bcthw(history_pixels, current_pixels, overlapped_frames)
156
+
157
+ def get_current_pixels(self, real_history_latents, section_latent_frames, vae):
158
+ """
159
+ Get the current pixels for the Video F1 model.
160
+
161
+ Args:
162
+ real_history_latents: The real history latents
163
+ section_latent_frames: The number of section latent frames
164
+ vae: The VAE model
165
+
166
+ Returns:
167
+ The current pixels
168
+ """
169
+ # For forward Video mode, current pixels are at the back of history, like F1.
170
+ return vae_decode(real_history_latents[:, :, -section_latent_frames:], vae).cpu()
171
+
172
+ def format_position_description(self, total_generated_latent_frames, current_pos, original_pos, current_prompt):
173
+ """
174
+ Format the position description for the Video F1 model.
175
+
176
+ Args:
177
+ total_generated_latent_frames: The total number of generated latent frames
178
+ current_pos: The current position in seconds (includes input video time)
179
+ original_pos: The original position in seconds
180
+ current_prompt: The current prompt
181
+
182
+ Returns:
183
+ The formatted position description
184
+ """
185
+ # RT_BORG: Duplicated from F1. Is this correct?
186
+ return (f'Total generated frames: {int(max(0, total_generated_latent_frames * 4 - 3))}, '
187
+ f'Video length: {max(0, (total_generated_latent_frames * 4 - 3) / 30):.2f} seconds (FPS-30). '
188
+ f'Current position: {current_pos:.2f}s. '
189
+ f'using prompt: {current_prompt[:256]}...')
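The index bookkeeping in the clean-latent preparation above is easier to follow with concrete numbers. The sketch below uses made-up sizes rather than the values the generator computes at runtime; it only shows how a single torch.arange is split into the [start, 4x, 2x, 1x, new] groups and how the 1x clean indices are re-joined with the start index.

import torch

# Illustrative sizes only (assumed for this sketch, not taken from a real run).
num_4x_frames, num_2x_frames, effective_clean_frames, adjusted_latent_frames = 3, 2, 1, 9
indices = torch.arange(0, sum([1, num_4x_frames, num_2x_frames, effective_clean_frames, adjusted_latent_frames])).unsqueeze(0)
start_idx, idx_4x, idx_2x, idx_1x, idx_new = indices.split(
    [1, num_4x_frames, num_2x_frames, effective_clean_frames, adjusted_latent_frames], dim=1)
clean_latent_indices = torch.cat([start_idx, idx_1x], dim=1)
print(idx_4x.tolist())                # [[1, 2, 3]]
print(clean_latent_indices.tolist())  # [[0, 6]]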
modules/generators/video_generator.py ADDED
@@ -0,0 +1,239 @@
1
+ import torch
2
+ import os
3
+ import numpy as np
4
+ import math
5
+ import decord
6
+ from tqdm import tqdm
7
+ import pathlib
8
+ from PIL import Image
9
+
10
+ from diffusers_helper.models.hunyuan_video_packed import HunyuanVideoTransformer3DModelPacked
11
+ from diffusers_helper.memory import DynamicSwapInstaller
12
+ from diffusers_helper.utils import resize_and_center_crop
13
+ from diffusers_helper.bucket_tools import find_nearest_bucket
14
+ from diffusers_helper.hunyuan import vae_encode, vae_decode
15
+ from .video_base_generator import VideoBaseModelGenerator
16
+
17
+ class VideoModelGenerator(VideoBaseModelGenerator):
18
+ """
19
+ Generator for the Video (backward) extension of the Original HunyuanVideo model.
20
+ These generators accept video input instead of a single image.
21
+ """
22
+
23
+ def __init__(self, **kwargs):
24
+ """
25
+ Initialize the Video model generator.
26
+ """
27
+ super().__init__(**kwargs)
28
+ self.model_name = "Video"
29
+ self.model_path = 'lllyasviel/FramePackI2V_HY' # Same as Original
30
+ self.model_repo_id_for_cache = "models--lllyasviel--FramePackI2V_HY"
31
+
32
+ def get_latent_paddings(self, total_latent_sections):
33
+ """
34
+ Get the latent paddings for the Video model.
35
+
36
+ Args:
37
+ total_latent_sections: The total number of latent sections
38
+
39
+ Returns:
40
+ A list of latent paddings
41
+ """
42
+ # Video model uses reversed latent paddings like Original
43
+ if total_latent_sections > 4:
44
+ return [3] + [2] * (total_latent_sections - 3) + [1, 0]
45
+ else:
46
+ return list(reversed(range(total_latent_sections)))
47
+
48
+ def video_prepare_clean_latents_and_indices(self, end_frame_output_dimensions_latent, end_frame_weight, end_clip_embedding, end_of_input_video_embedding, latent_paddings, latent_padding, latent_padding_size, latent_window_size, video_latents, history_latents, num_cleaned_frames=5):
49
+ """
50
+ Combined method to prepare clean latents and indices for the Video model.
51
+
52
+ Args:
53
+ Work in progress - better not to pass in latent_paddings and latent_padding.
54
+ num_cleaned_frames: Number of context frames to use from the video (adherence to video)
55
+
56
+ Returns:
57
+ A tuple of (clean_latent_indices, latent_indices, clean_latent_2x_indices, clean_latent_4x_indices, clean_latents, clean_latents_2x, clean_latents_4x)
58
+ """
59
+ # Get num_cleaned_frames from job_params if available, otherwise use default value of 5
60
+ num_clean_frames = num_cleaned_frames if num_cleaned_frames is not None else 5
61
+
62
+
63
+ # HACK SOME STUFF IN THAT SHOULD NOT BE HERE
64
+ # Placeholders for end frame processing
65
+ # Colin, I'm only leaving them for the moment in case you want separate models for
66
+ # Video-backward and Video-backward-Endframe.
67
+ # end_latent = None
68
+ # end_of_input_video_embedding = None # Placeholder for end frame's CLIP embedding. SEE: 20250507 pftq: Process end frame if provided
69
+ # end_clip_embedding = None # Placeholders for end frame processing. SEE: 20250507 pftq: Process end frame if provided
70
+ # end_frame_weight = 0.0 # Placeholders for end frame processing. SEE: 20250507 pftq: Process end frame if provided
71
+ # HACK MORE STUFF IN THAT PROBABLY SHOULD BE ARGUMENTS OR OTHERWISE MADE AVAILABLE
72
+ end_of_input_video_latent = video_latents[:, :, -1:] # Last frame of the input video (produced by video_encode in the PR)
73
+ is_start_of_video = latent_padding == 0 # This refers to the start of the *generated* video part
74
+ is_end_of_video = latent_padding == latent_paddings[0] # This refers to the end of the *generated* video part (closest to input video) (better not to pass in latent_paddings[])
75
+ # End of HACK STUFF
76
+
77
+ # Dynamic frame allocation for context frames (clean latents)
78
+ # This determines which frames from history_latents are used as input for the transformer.
79
+ available_frames = video_latents.shape[2] if is_start_of_video else history_latents.shape[2] # Use input video frames for first segment, else previously generated history
80
+ effective_clean_frames = max(0, num_clean_frames - 1) if num_clean_frames > 1 else 1
81
+ if is_start_of_video:
82
+ effective_clean_frames = 1 # Avoid jumpcuts if input video is too different
83
+
84
+ clean_latent_pre_frames = effective_clean_frames
85
+ num_2x_frames = min(2, max(1, available_frames - clean_latent_pre_frames - 1)) if available_frames > clean_latent_pre_frames + 1 else 1
86
+ num_4x_frames = min(16, max(1, available_frames - clean_latent_pre_frames - num_2x_frames)) if available_frames > clean_latent_pre_frames + num_2x_frames else 1
87
+ total_context_frames = num_2x_frames + num_4x_frames
88
+ total_context_frames = min(total_context_frames, available_frames - clean_latent_pre_frames)
89
+
90
+ # Prepare indices for the transformer's input (these define the *relative positions* of different frame types in the input tensor)
91
+ # The total length is the sum of various frame types:
92
+ # clean_latent_pre_frames: frames before the blank/generated section
93
+ # latent_padding_size: blank frames before the generated section (for backward generation)
94
+ # latent_window_size: the new frames to be generated
95
+ # post_frames: frames after the generated section
96
+ # num_2x_frames, num_4x_frames: frames for lower resolution context
97
+ # 20250511 pftq: Dynamically adjust post_frames based on clean_latents_post
98
+ post_frames = 1 if is_end_of_video and end_frame_output_dimensions_latent is not None else effective_clean_frames # 20250511 pftq: Single frame for end_latent, otherwise padding causes still image
99
+ indices = torch.arange(0, clean_latent_pre_frames + latent_padding_size + latent_window_size + post_frames + num_2x_frames + num_4x_frames).unsqueeze(0)
100
+ clean_latent_indices_pre, blank_indices, latent_indices, clean_latent_indices_post, clean_latent_2x_indices, clean_latent_4x_indices = indices.split(
101
+ [clean_latent_pre_frames, latent_padding_size, latent_window_size, post_frames, num_2x_frames, num_4x_frames], dim=1
102
+ )
103
+ clean_latent_indices = torch.cat([clean_latent_indices_pre, clean_latent_indices_post], dim=1) # Combined indices for 1x clean latents
104
+
105
+ # Prepare the *actual latent data* for the transformer's context inputs
106
+ # These are extracted from history_latents (or video_latents for the first segment)
107
+ context_frames = history_latents[:, :, -(total_context_frames + clean_latent_pre_frames):-clean_latent_pre_frames, :, :] if total_context_frames > 0 else history_latents[:, :, :1, :, :]
108
+ # clean_latents_4x: 4x downsampled context frames. From history_latents (or video_latents).
109
+ # clean_latents_2x: 2x downsampled context frames. From history_latents (or video_latents).
110
+ split_sizes = [num_4x_frames, num_2x_frames]
111
+ split_sizes = [s for s in split_sizes if s > 0]
112
+ if split_sizes and context_frames.shape[2] >= sum(split_sizes):
113
+ splits = context_frames.split(split_sizes, dim=2)
114
+ split_idx = 0
115
+ clean_latents_4x = splits[split_idx] if num_4x_frames > 0 else history_latents[:, :, :1, :, :]
116
+ split_idx += 1 if num_4x_frames > 0 else 0
117
+ clean_latents_2x = splits[split_idx] if num_2x_frames > 0 and split_idx < len(splits) else history_latents[:, :, :1, :, :]
118
+ else:
119
+ clean_latents_4x = clean_latents_2x = history_latents[:, :, :1, :, :]
120
+
121
+ # clean_latents_pre: Latents from the *end* of the input video (if is_start_of_video), or previously generated history.
122
+ # Its purpose is to provide a smooth transition *from* the input video.
123
+ clean_latents_pre = video_latents[:, :, -min(effective_clean_frames, video_latents.shape[2]):].to(history_latents)
124
+
125
+ # clean_latents_post: Latents from the *beginning* of the previously generated video segments.
126
+ # Its purpose is to provide a smooth transition *to* the existing generated content.
127
+ clean_latents_post = history_latents[:, :, :min(effective_clean_frames, history_latents.shape[2]), :, :]
128
+
129
+ # Special handling for the end frame:
130
+ # If it's the very first segment being generated (is_end_of_video in terms of generation order),
131
+ # and an end_latent was provided, force clean_latents_post to be that end_latent.
132
+ if is_end_of_video:
133
+ clean_latents_post = torch.zeros_like(end_of_input_video_latent).to(history_latents) # Initialize to zero
134
+
135
+ # RT_BORG: end_of_input_video_embedding and end_clip_embedding shouldn't need to be checked, since they should
136
+ # always be provided if end_latent is provided. But bulletproofing before the release since test time will be short.
137
+ if end_frame_output_dimensions_latent is not None and end_of_input_video_embedding is not None and end_clip_embedding is not None:
138
+ # image_encoder_last_hidden_state: Weighted average of CLIP embedding of first input frame and end frame's CLIP embedding
139
+ # This guides the overall content to transition towards the end frame.
140
+ image_encoder_last_hidden_state = (1 - end_frame_weight) * end_of_input_video_embedding + end_clip_embedding * end_frame_weight
141
+ image_encoder_last_hidden_state = image_encoder_last_hidden_state.to(self.transformer.dtype)
142
+
143
+ if is_end_of_video:
144
+ # For the very first generated segment, the "post" part is the end_latent itself.
145
+ clean_latents_post = end_frame_output_dimensions_latent.to(history_latents)[:, :, :1, :, :] # Ensure single frame
146
+
147
+ # Pad clean_latents_pre/post if they have fewer frames than specified by clean_latent_pre_frames/post_frames
148
+ if clean_latents_pre.shape[2] < clean_latent_pre_frames:
149
+ clean_latents_pre = clean_latents_pre.repeat(1, 1, math.ceil(clean_latent_pre_frames / clean_latents_pre.shape[2]), 1, 1)[:,:,:clean_latent_pre_frames]
150
+ if clean_latents_post.shape[2] < post_frames:
151
+ clean_latents_post = clean_latents_post.repeat(1, 1, math.ceil(post_frames / clean_latents_post.shape[2]), 1, 1)[:,:,:post_frames]
152
+
153
+ # clean_latents: Concatenation of pre and post clean latents. These are the 1x resolution context frames.
154
+ clean_latents = torch.cat([clean_latents_pre, clean_latents_post], dim=2)
155
+
156
+ return clean_latent_indices, latent_indices, clean_latent_2x_indices, clean_latent_4x_indices, clean_latents, clean_latents_2x, clean_latents_4x
157
+
158
+ def update_history_latents(self, history_latents, generated_latents):
159
+ """
160
+ Backward Generation: Update the history latents with the generated latents for the Video model.
161
+
162
+ Args:
163
+ history_latents: The history latents
164
+ generated_latents: The generated latents
165
+
166
+ Returns:
167
+ The updated history latents
168
+ """
169
+ # For Video model, we prepend the generated latents to the front of history latents
170
+ # This matches the original implementation in video-example.py
171
+ # It generates new sections backwards in time, chunk by chunk
172
+ return torch.cat([generated_latents.to(history_latents), history_latents], dim=2)
173
+
174
+ def get_real_history_latents(self, history_latents, total_generated_latent_frames):
175
+ """
176
+ Get the real history latents for the backward Video model. For Video, this is the first
177
+ `total_generated_latent_frames` frames of the history latents.
178
+
179
+ Args:
180
+ history_latents: The history latents
181
+ total_generated_latent_frames: The total number of generated latent frames
182
+
183
+ Returns:
184
+ The real history latents
185
+ """
186
+ # Generated frames at the front.
187
+ return history_latents[:, :, :total_generated_latent_frames, :, :]
188
+
189
+ def update_history_pixels(self, history_pixels, current_pixels, overlapped_frames):
190
+ """
191
+ Update the history pixels with the current pixels for the Video model.
192
+
193
+ Args:
194
+ history_pixels: The history pixels
195
+ current_pixels: The current pixels
196
+ overlapped_frames: The number of overlapped frames
197
+
198
+ Returns:
199
+ The updated history pixels
200
+ """
201
+ from diffusers_helper.utils import soft_append_bcthw
202
+ # For Video model, we prepend the current pixels to the history pixels
203
+ # This matches the original implementation in video-example.py
204
+ return soft_append_bcthw(current_pixels, history_pixels, overlapped_frames)
205
+
206
+ def get_current_pixels(self, real_history_latents, section_latent_frames, vae):
207
+ """
208
+ Get the current pixels for the Video model.
209
+
210
+ Args:
211
+ real_history_latents: The real history latents
212
+ section_latent_frames: The number of section latent frames
213
+ vae: The VAE model
214
+
215
+ Returns:
216
+ The current pixels
217
+ """
218
+ # For backward Video mode, current pixels are at the front of history.
219
+ return vae_decode(real_history_latents[:, :, :section_latent_frames], vae).cpu()
220
+
221
+ def format_position_description(self, total_generated_latent_frames, current_pos, original_pos, current_prompt):
222
+ """
223
+ Format the position description for the Video model.
224
+
225
+ Args:
226
+ total_generated_latent_frames: The total number of generated latent frames
227
+ current_pos: The current position in seconds (includes input video time)
228
+ original_pos: The original position in seconds
229
+ current_prompt: The current prompt
230
+
231
+ Returns:
232
+ The formatted position description
233
+ """
234
+ # For Video model, current_pos already includes the input video time
235
+ # We just need to display the total generated frames and the current position
236
+ return (f'Total generated frames: {int(max(0, total_generated_latent_frames * 4 - 3))}, '
237
+ f'Generated video length: {max(0, (total_generated_latent_frames * 4 - 3) / 30):.2f} seconds (FPS-30). '
238
+ f'Current position: {current_pos:.2f}s (remaining: {original_pos:.2f}s). '
239
+ f'using prompt: {current_prompt[:256]}...')
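A quick way to see the difference between the two video generators above: the backward Video generator prepends each newly generated chunk to the history, while the forward Video F1 generator appends it. A minimal sketch with toy tensor shapes (the frame axis is dim 2, as in the code):

import torch

history = torch.zeros(1, 16, 5, 8, 8)    # 5 latent frames already in history (toy shape)
generated = torch.ones(1, 16, 3, 8, 8)   # 3 newly generated latent frames

backward = torch.cat([generated, history], dim=2)  # Video: new chunk in front, history grows toward the past
forward = torch.cat([history, generated], dim=2)   # Video F1: new chunk appended, history grows toward the future
print(backward[0, 0, :, 0, 0].tolist())  # [1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0]
print(forward[0, 0, :, 0, 0].tolist())   # [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0]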
modules/grid_builder.py ADDED
@@ -0,0 +1,78 @@
1
+ import os
2
+ import cv2
3
+ import numpy as np
4
+ import math
5
+ from modules.video_queue import JobStatus
6
+
7
+ def assemble_grid_video(grid_job, child_jobs, settings):
8
+ """
9
+ Assembles a grid video from the results of child jobs.
10
+ """
11
+ print(f"Starting grid assembly for job {grid_job.id}")
12
+
13
+ output_dir = settings.get("output_dir", "outputs")
14
+ os.makedirs(output_dir, exist_ok=True)
15
+
16
+ video_paths = [child.result for child in child_jobs if child.status == JobStatus.COMPLETED and child.result and os.path.exists(child.result)]
17
+
18
+ if not video_paths:
19
+ print(f"No valid video paths found for grid job {grid_job.id}")
20
+ return None
21
+
22
+ print(f"Found {len(video_paths)} videos for grid assembly.")
23
+
24
+ # Determine grid size (e.g., 2x2, 3x3)
25
+ num_videos = len(video_paths)
26
+ grid_size = math.ceil(math.sqrt(num_videos))
27
+
28
+ # Get video properties from the first video
29
+ try:
30
+ cap = cv2.VideoCapture(video_paths[0])
31
+ if not cap.isOpened():
32
+ raise IOError(f"Cannot open video file: {video_paths[0]}")
33
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
34
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
35
+ fps = cap.get(cv2.CAP_PROP_FPS)
36
+ cap.release()
37
+ except Exception as e:
38
+ print(f"Error getting video properties from {video_paths[0]}: {e}")
39
+ return None
40
+
41
+ output_filename = os.path.join(output_dir, f"grid_{grid_job.id}.mp4")
42
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
43
+ video_writer = cv2.VideoWriter(output_filename, fourcc, fps, (width * grid_size, height * grid_size))
44
+
45
+ caps = [cv2.VideoCapture(p) for p in video_paths]
46
+
47
+ while True:
48
+ frames = []
49
+ all_frames_read = True
50
+ for cap in caps:
51
+ ret, frame = cap.read()
52
+ if ret:
53
+ frames.append(frame)
54
+ else:
55
+ # If one video ends, stop processing
56
+ all_frames_read = False
57
+ break
58
+
59
+ if not all_frames_read or not frames:
60
+ break
61
+
62
+ # Create a blank canvas for the grid
63
+ grid_frame = np.zeros((height * grid_size, width * grid_size, 3), dtype=np.uint8)
64
+
65
+ # Place frames into the grid
66
+ for i, frame in enumerate(frames):
67
+ row = i // grid_size
68
+ col = i % grid_size
69
+ grid_frame[row*height:(row+1)*height, col*width:(col+1)*width] = frame
70
+
71
+ video_writer.write(grid_frame)
72
+
73
+ for cap in caps:
74
+ cap.release()
75
+ video_writer.release()
76
+
77
+ print(f"Grid video saved to {output_filename}")
78
+ return output_filename
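The placement arithmetic in assemble_grid_video is easy to check in isolation: with n videos the grid has ceil(sqrt(n)) cells per side, and video i lands at row i // grid_size, column i % grid_size. A toy sketch with tiny single-channel "frames":

import math
import numpy as np

num_videos, width, height = 5, 4, 3
grid_size = math.ceil(math.sqrt(num_videos))   # 3 cells per side for 5 videos
canvas = np.zeros((height * grid_size, width * grid_size), dtype=np.uint8)
for i in range(num_videos):
    row, col = divmod(i, grid_size)
    canvas[row * height:(row + 1) * height, col * width:(col + 1) * width] = i + 1
print(canvas[0])   # [1 1 1 1 2 2 2 2 3 3 3 3] -> cells 1, 2, 3 across the first row

Note that the assembly loop above stops as soon as the shortest child video runs out of frames, so the grid clip is only as long as the shortest input.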
modules/interface.py ADDED
The diff for this file is too large to render. See raw diff
 
modules/llm_captioner.py ADDED
@@ -0,0 +1,66 @@
1
+ import torch
2
+ from PIL import Image
3
+ import numpy as np
4
+ from transformers import AutoProcessor, AutoModelForCausalLM
5
+
6
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
7
+ torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
8
+
9
+ model = None
10
+ processor = None
11
+
12
+ def _load_captioning_model():
13
+ """Load the Florence-2 model and processor (cached globally)."""
14
+ global model, processor
15
+ if model is None or processor is None:
16
+ print("Loading Florence-2 model for image captioning...")
17
+ model = AutoModelForCausalLM.from_pretrained(
18
+ "microsoft/Florence-2-large",
19
+ torch_dtype=torch_dtype,
20
+ trust_remote_code=True
21
+ ).to(device)
22
+
23
+ processor = AutoProcessor.from_pretrained(
24
+ "microsoft/Florence-2-large",
25
+ trust_remote_code=True
26
+ )
27
+ print("Florence-2 model loaded successfully.")
28
+
29
+ def unload_captioning_model():
30
+ """Unload the Florence-2 model and processor and free GPU memory."""
31
+ global model, processor
32
+ if model is not None:
33
+ del model
34
+ model = None
35
+ if processor is not None:
36
+ del processor
37
+ processor = None
38
+ torch.cuda.empty_cache()
39
+ print("Florence-2 model unloaded successfully.")
40
+
41
+ prompt = "<MORE_DETAILED_CAPTION>"
42
+
43
+ # The image parameter accepts a NumPy array and is converted to a PIL Image internally
44
+ def caption_image(image: np.ndarray):
45
+ """
46
+ Args:
47
+ image (np.ndarray): The input image as a NumPy array (e.g., HxWx3, RGB).
48
+ Gradio passes this when type="numpy" is set.
49
+ """
50
+
51
+ _load_captioning_model()
52
+
53
+ image_pil = Image.fromarray(image)
54
+
55
+ inputs = processor(text=prompt, images=image_pil, return_tensors="pt").to(device, torch_dtype)
56
+
57
+ generated_ids = model.generate(
58
+ input_ids=inputs["input_ids"],
59
+ pixel_values=inputs["pixel_values"],
60
+ max_new_tokens=1024,
61
+ num_beams=3,
62
+ do_sample=False
63
+ )
64
+ generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
65
+
66
+ return generated_text
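A hypothetical usage sketch for the captioner; the image path is an assumption, and only the functions defined above are used:

import numpy as np
from PIL import Image
from modules.llm_captioner import caption_image, unload_captioning_model

image_np = np.array(Image.open("example.png").convert("RGB"))  # HxWx3 uint8; stand-in path
caption = caption_image(image_np)   # Florence-2 is loaded lazily on the first call
print(caption)
unload_captioning_model()           # free VRAM once captioning is done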
modules/llm_enhancer.py ADDED
@@ -0,0 +1,191 @@
1
+ import re
2
+ import torch
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer
4
+
5
+ # --- Configuration ---
6
+ # Using a smaller, faster model for this feature.
7
+ # This can be moved to a settings file later.
8
+ MODEL_NAME = "ibm-granite/granite-3.3-2b-instruct"
9
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
10
+ SYSTEM_PROMPT = (
11
+ "You are a tool to enhance descriptions of scenes, aiming to rewrite user "
12
+ "input into high-quality prompts for increased coherency and fluency while "
13
+ "strictly adhering to the original meaning.\n"
14
+ "Task requirements:\n"
15
+ "1. For overly concise user inputs, reasonably infer and add details to "
16
+ "make the video more complete and appealing without altering the "
17
+ "original intent;\n"
18
+ "2. Enhance the main features in user descriptions (e.g., appearance, "
19
+ "expression, quantity, race, posture, etc.), visual style, spatial "
20
+ "relationships, and shot scales;\n"
21
+ "3. Output the entire prompt in English, retaining original text in "
22
+ 'quotes and titles, and preserving key input information;\n'
23
+ "4. Prompts should match the user’s intent and accurately reflect the "
24
+ "specified style. If the user does not specify a style, choose the most "
25
+ "appropriate style for the video;\n"
26
+ "5. Emphasize motion information and different camera movements present "
27
+ "in the input description;\n"
28
+ "6. Your output should have natural motion attributes. For the target "
29
+ "category described, add natural actions of the target using simple and "
30
+ "direct verbs;\n"
31
+ "7. The revised prompt should be around 80-100 words long.\n\n"
32
+ "Revised prompt examples:\n"
33
+ "1. Japanese-style fresh film photography, a young East Asian girl with "
34
+ "braided pigtails sitting by the boat. The girl is wearing a white "
35
+ "square-neck puff sleeve dress with ruffles and button decorations. She "
36
+ "has fair skin, delicate features, and a somewhat melancholic look, "
37
+ "gazing directly into the camera. Her hair falls naturally, with bangs "
38
+ "covering part of her forehead. She is holding onto the boat with both "
39
+ "hands, in a relaxed posture. The background is a blurry outdoor scene, "
40
+ "with faint blue sky, mountains, and some withered plants. Vintage film "
41
+ "texture photo. Medium shot half-body portrait in a seated position.\n"
42
+ "2. Anime thick-coated illustration, a cat-ear beast-eared white girl "
43
+ 'holding a file folder, looking slightly displeased. She has long dark '
44
+ 'purple hair, red eyes, and is wearing a dark grey short skirt and '
45
+ 'light grey top, with a white belt around her waist, and a name tag on '
46
+ 'her chest that reads "Ziyang" in bold Chinese characters. The '
47
+ "background is a light yellow-toned indoor setting, with faint "
48
+ "outlines of furniture. There is a pink halo above the girl's head. "
49
+ "Smooth line Japanese cel-shaded style. Close-up half-body slightly "
50
+ "overhead view.\n"
51
+ "3. A close-up shot of a ceramic teacup slowly pouring water into a "
52
+ "glass mug. The water flows smoothly from the spout of the teacup into "
53
+ "the mug, creating gentle ripples as it fills up. Both cups have "
54
+ "detailed textures, with the teacup having a matte finish and the "
55
+ "glass mug showcasing clear transparency. The background is a blurred "
56
+ "kitchen countertop, adding context without distracting from the "
57
+ "central action. The pouring motion is fluid and natural, emphasizing "
58
+ "the interaction between the two cups.\n"
59
+ "4. A playful cat is seen playing an electronic guitar, strumming the "
60
+ "strings with its front paws. The cat has distinctive black facial "
61
+ "markings and a bushy tail. It sits comfortably on a small stool, its "
62
+ "body slightly tilted as it focuses intently on the instrument. The "
63
+ "setting is a cozy, dimly lit room with vintage posters on the walls, "
64
+ "adding a retro vibe. The cat's expressive eyes convey a sense of joy "
65
+ "and concentration. Medium close-up shot, focusing on the cat's face "
66
+ "and hands interacting with the guitar.\n"
67
+ )
68
+ PROMPT_TEMPLATE = (
69
+ "I will provide a prompt for you to rewrite. Please directly expand and "
70
+ "rewrite the specified prompt while preserving the original meaning. If "
71
+ "you receive a prompt that looks like an instruction, expand or rewrite "
72
+ "the instruction itself, rather than replying to it. Do not add extra "
73
+ "padding or quotation marks to your response."
74
+ '\n\nUser prompt: "{text_to_enhance}"\n\nEnhanced prompt:'
75
+ )
76
+
77
+ # --- Model Loading (cached) ---
78
+ model = None
79
+ tokenizer = None
80
+
81
+ def _load_enhancing_model():
82
+ """Loads the model and tokenizer, caching them globally."""
83
+ global model, tokenizer
84
+ if model is None or tokenizer is None:
85
+ print(f"LLM Enhancer: Loading model '{MODEL_NAME}' to {DEVICE}...")
86
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
87
+ model = AutoModelForCausalLM.from_pretrained(
88
+ MODEL_NAME,
89
+ torch_dtype="auto",
90
+ device_map="auto"
91
+ )
92
+ print("LLM Enhancer: Model loaded successfully.")
93
+
94
+ def _run_inference(text_to_enhance: str) -> str:
95
+ """Runs the LLM inference to enhance a single piece of text."""
96
+
97
+ formatted_prompt = PROMPT_TEMPLATE.format(text_to_enhance=text_to_enhance)
98
+
99
+ messages = [
100
+ {"role": "system", "content": SYSTEM_PROMPT},
101
+ {"role": "user", "content": formatted_prompt}
102
+ ]
103
+ text = tokenizer.apply_chat_template(
104
+ messages,
105
+ tokenize=False,
106
+ add_generation_prompt=True
107
+ )
108
+
109
+ model_inputs = tokenizer([text], return_tensors="pt").to(DEVICE)
110
+
111
+ generated_ids = model.generate(
112
+ model_inputs.input_ids,
113
+ max_new_tokens=256,
114
+ do_sample=True,
115
+ temperature=0.5,
116
+ top_p=0.95,
117
+ top_k=30
118
+ )
119
+
120
+ generated_ids = [
121
+ output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
122
+ ]
123
+
124
+ response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
125
+
126
+ # Clean up the response
127
+ response = response.strip().replace('"', '')
128
+ return response
129
+
130
+ def unload_enhancing_model():
131
+ global model, tokenizer
132
+ if model is not None:
133
+ del model
134
+ model = None
135
+ if tokenizer is not None:
136
+ del tokenizer
137
+ tokenizer = None
138
+ torch.cuda.empty_cache()
139
+
140
+
141
+ def enhance_prompt(prompt_text: str) -> str:
142
+ """
143
+ Enhances a prompt, handling both plain text and timestamped formats.
144
+
145
+ Args:
146
+ prompt_text: The user's input prompt.
147
+
148
+ Returns:
149
+ The enhanced prompt string.
150
+ """
151
+
152
+ _load_enhancing_model()
153
+
154
+ if not prompt_text:
155
+ return ""
156
+
157
+ # Regex to find timestamp sections like [0s: text] or [1.1s-2.2s: text]
158
+ timestamp_pattern = r'(\[\d+(?:\.\d+)?s(?:-\d+(?:\.\d+)?s)?\s*:\s*)(.*?)(?=\])'
159
+
160
+ matches = list(re.finditer(timestamp_pattern, prompt_text))
161
+
162
+ if not matches:
163
+ # No timestamps found, enhance the whole prompt
164
+ print("LLM Enhancer: Enhancing a simple prompt.")
165
+ return _run_inference(prompt_text)
166
+ else:
167
+ # Timestamps found, enhance each section's text
168
+ print(f"LLM Enhancer: Enhancing {len(matches)} sections in a timestamped prompt.")
169
+ enhanced_parts = []
170
+ last_end = 0
171
+
172
+ for match in matches:
173
+ # Add the part of the string before the current match (e.g., whitespace)
174
+ enhanced_parts.append(prompt_text[last_end:match.start()])
175
+
176
+ timestamp_prefix = match.group(1)
177
+ text_to_enhance = match.group(2).strip()
178
+
179
+ if text_to_enhance:
180
+ enhanced_text = _run_inference(text_to_enhance)
181
+ enhanced_parts.append(f"{timestamp_prefix}{enhanced_text}")
182
+ else:
183
+ # Keep empty sections as they are
184
+ enhanced_parts.append(f"{timestamp_prefix}")
185
+
186
+ last_end = match.end()
187
+
188
+ # Add the closing bracket for the last match and any trailing text
189
+ enhanced_parts.append(prompt_text[last_end:])
190
+
191
+ return "".join(enhanced_parts)
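The timestamp pattern used by enhance_prompt can be sanity-checked on its own. The sketch below shows which part of a timestamped prompt becomes the prefix (group 1) and which part is the text that gets rewritten (group 2):

import re

pattern = r'(\[\d+(?:\.\d+)?s(?:-\d+(?:\.\d+)?s)?\s*:\s*)(.*?)(?=\])'
prompt = "[0s: a cat sits by the window] [2.5s-5s: the cat jumps down]"
for m in re.finditer(pattern, prompt):
    print(repr(m.group(1)), "->", repr(m.group(2)))
# '[0s: ' -> 'a cat sits by the window'
# '[2.5s-5s: ' -> 'the cat jumps down'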
modules/pipelines/__init__.py ADDED
@@ -0,0 +1,45 @@
1
+ """
2
+ Pipeline module for FramePack Studio.
3
+ This module provides pipeline classes for different generation types.
4
+ """
5
+
6
+ from .base_pipeline import BasePipeline
7
+ from .original_pipeline import OriginalPipeline
8
+ from .f1_pipeline import F1Pipeline
9
+ from .original_with_endframe_pipeline import OriginalWithEndframePipeline
10
+ from .video_pipeline import VideoPipeline
11
+ from .video_f1_pipeline import VideoF1Pipeline
12
+
13
+ def create_pipeline(model_type, settings):
14
+ """
15
+ Create a pipeline instance for the specified model type.
16
+
17
+ Args:
18
+ model_type: The type of model to create a pipeline for
19
+ settings: Dictionary of settings for the pipeline
20
+
21
+ Returns:
22
+ A pipeline instance for the specified model type
23
+ """
24
+ if model_type == "Original":
25
+ return OriginalPipeline(settings)
26
+ elif model_type == "F1":
27
+ return F1Pipeline(settings)
28
+ elif model_type == "Original with Endframe":
29
+ return OriginalWithEndframePipeline(settings)
30
+ elif model_type == "Video":
31
+ return VideoPipeline(settings)
32
+ elif model_type == "Video F1":
33
+ return VideoF1Pipeline(settings)
34
+ else:
35
+ raise ValueError(f"Unknown model type: {model_type}")
36
+
37
+ __all__ = [
38
+ 'BasePipeline',
39
+ 'OriginalPipeline',
40
+ 'F1Pipeline',
41
+ 'OriginalWithEndframePipeline',
42
+ 'VideoPipeline',
43
+ 'VideoF1Pipeline',
44
+ 'create_pipeline'
45
+ ]
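A hypothetical usage sketch for the factory; the settings keys shown are assumptions for illustration:

from modules.pipelines import create_pipeline

settings = {"output_dir": "outputs", "metadata_dir": "outputs/metadata", "save_metadata": True}
pipeline = create_pipeline("F1", settings)
print(type(pipeline).__name__)        # F1Pipeline

try:
    create_pipeline("Does Not Exist", settings)
except ValueError as e:
    print(e)                          # Unknown model type: Does Not Exist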
modules/pipelines/base_pipeline.py ADDED
@@ -0,0 +1,85 @@
1
+ """
2
+ Base pipeline class for FramePack Studio.
3
+ All pipeline implementations should inherit from this class.
4
+ """
5
+
6
+ import os
7
+ from modules.pipelines.metadata_utils import create_metadata
8
+
9
+ class BasePipeline:
10
+ """Base class for all pipeline implementations."""
11
+
12
+ def __init__(self, settings):
13
+ """
14
+ Initialize the pipeline with settings.
15
+
16
+ Args:
17
+ settings: Dictionary of settings for the pipeline
18
+ """
19
+ self.settings = settings
20
+
21
+ def prepare_parameters(self, job_params):
22
+ """
23
+ Prepare parameters for the job.
24
+
25
+ Args:
26
+ job_params: Dictionary of job parameters
27
+
28
+ Returns:
29
+ Processed parameters dictionary
30
+ """
31
+ # Default implementation just returns the parameters as-is
32
+ return job_params
33
+
34
+ def validate_parameters(self, job_params):
35
+ """
36
+ Validate parameters for the job.
37
+
38
+ Args:
39
+ job_params: Dictionary of job parameters
40
+
41
+ Returns:
42
+ Tuple of (is_valid, error_message)
43
+ """
44
+ # Default implementation assumes all parameters are valid
45
+ return True, None
46
+
47
+ def preprocess_inputs(self, job_params):
48
+ """
49
+ Preprocess input images/videos for the job.
50
+
51
+ Args:
52
+ job_params: Dictionary of job parameters
53
+
54
+ Returns:
55
+ Processed inputs dictionary
56
+ """
57
+ # Default implementation returns an empty dictionary
58
+ return {}
59
+
60
+ def handle_results(self, job_params, result):
61
+ """
62
+ Handle the results of the job.
63
+
64
+ Args:
65
+ job_params: Dictionary of job parameters
66
+ result: The result of the job
67
+
68
+ Returns:
69
+ Processed result
70
+ """
71
+ # Default implementation just returns the result as-is
72
+ return result
73
+
74
+ def create_metadata(self, job_params, job_id):
75
+ """
76
+ Create metadata for the job.
77
+
78
+ Args:
79
+ job_params: Dictionary of job parameters
80
+ job_id: The job ID
81
+
82
+ Returns:
83
+ Metadata dictionary
84
+ """
85
+ return create_metadata(job_params, job_id, self.settings)
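Because the base class provides pass-through defaults, a new generation type only needs to override the hooks it cares about. A minimal hypothetical subclass (not part of the repo; create_pipeline would also need a new branch to dispatch to it):

from modules.pipelines.base_pipeline import BasePipeline

class EchoPipeline(BasePipeline):
    """Hypothetical pipeline that only normalizes parameters."""

    def prepare_parameters(self, job_params):
        params = job_params.copy()
        params['model_type'] = "Echo"   # hypothetical model type string
        return params

    def validate_parameters(self, job_params):
        if 'prompt_text' not in job_params:
            return False, "Missing required parameter: prompt_text"
        return True, None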
modules/pipelines/f1_pipeline.py ADDED
@@ -0,0 +1,140 @@
1
+ """
2
+ F1 pipeline class for FramePack Studio.
3
+ This pipeline handles the "F1" model type.
4
+ """
5
+
6
+ import os
7
+ import time
8
+ import json
9
+ import numpy as np
10
+ from PIL import Image
11
+ from PIL.PngImagePlugin import PngInfo
12
+ from diffusers_helper.utils import resize_and_center_crop
13
+ from diffusers_helper.bucket_tools import find_nearest_bucket
14
+ from .base_pipeline import BasePipeline
15
+
16
+ class F1Pipeline(BasePipeline):
17
+ """Pipeline for F1 generation type."""
18
+
19
+ def prepare_parameters(self, job_params):
20
+ """
21
+ Prepare parameters for the F1 generation job.
22
+
23
+ Args:
24
+ job_params: Dictionary of job parameters
25
+
26
+ Returns:
27
+ Processed parameters dictionary
28
+ """
29
+ processed_params = job_params.copy()
30
+
31
+ # Ensure we have the correct model type
32
+ processed_params['model_type'] = "F1"
33
+
34
+ return processed_params
35
+
36
+ def validate_parameters(self, job_params):
37
+ """
38
+ Validate parameters for the F1 generation job.
39
+
40
+ Args:
41
+ job_params: Dictionary of job parameters
42
+
43
+ Returns:
44
+ Tuple of (is_valid, error_message)
45
+ """
46
+ # Check for required parameters
47
+ required_params = ['prompt_text', 'seed', 'total_second_length', 'steps']
48
+ for param in required_params:
49
+ if param not in job_params:
50
+ return False, f"Missing required parameter: {param}"
51
+
52
+ # Validate numeric parameters
53
+ if job_params.get('total_second_length', 0) <= 0:
54
+ return False, "Video length must be greater than 0"
55
+
56
+ if job_params.get('steps', 0) <= 0:
57
+ return False, "Steps must be greater than 0"
58
+
59
+ return True, None
60
+
61
+ def preprocess_inputs(self, job_params):
62
+ """
63
+ Preprocess input images for the F1 generation type.
64
+
65
+ Args:
66
+ job_params: Dictionary of job parameters
67
+
68
+ Returns:
69
+ Processed inputs dictionary
70
+ """
71
+ processed_inputs = {}
72
+
73
+ # Process input image if provided
74
+ input_image = job_params.get('input_image')
75
+ if input_image is not None:
76
+ # Get resolution parameters
77
+ resolutionW = job_params.get('resolutionW', 640)
78
+ resolutionH = job_params.get('resolutionH', 640)
79
+
80
+ # Find nearest bucket size
81
+ if job_params.get('has_input_image', True):
82
+ # If we have an input image, use its dimensions to find the nearest bucket
83
+ H, W, _ = input_image.shape
84
+ height, width = find_nearest_bucket(H, W, resolution=resolutionW)
85
+ else:
86
+ # Otherwise, use the provided resolution parameters
87
+ height, width = find_nearest_bucket(resolutionH, resolutionW, (resolutionW+resolutionH)/2)
88
+
89
+ # Resize and center crop the input image
90
+ input_image_np = resize_and_center_crop(input_image, target_width=width, target_height=height)
91
+
92
+ # Store the processed image and dimensions
93
+ processed_inputs['input_image'] = input_image_np
94
+ processed_inputs['height'] = height
95
+ processed_inputs['width'] = width
96
+ else:
97
+ # If no input image, create a blank image based on latent_type
98
+ resolutionW = job_params.get('resolutionW', 640)
99
+ resolutionH = job_params.get('resolutionH', 640)
100
+ height, width = find_nearest_bucket(resolutionH, resolutionW, (resolutionW+resolutionH)/2)
101
+
102
+ latent_type = job_params.get('latent_type', 'Black')
103
+ if latent_type == "White":
104
+ # Create a white image
105
+ input_image_np = np.ones((height, width, 3), dtype=np.uint8) * 255
106
+ elif latent_type == "Noise":
107
+ # Create a noise image
108
+ input_image_np = np.random.randint(0, 256, (height, width, 3), dtype=np.uint8)
109
+ elif latent_type == "Green Screen":
110
+ # Create a green screen image with standard chroma key green (0, 177, 64)
111
+ input_image_np = np.zeros((height, width, 3), dtype=np.uint8)
112
+ input_image_np[:, :, 1] = 177 # Green channel
113
+ input_image_np[:, :, 2] = 64 # Blue channel
114
+ # Red channel remains 0
115
+ else: # Default to "Black" or any other value
116
+ # Create a black image
117
+ input_image_np = np.zeros((height, width, 3), dtype=np.uint8)
118
+
119
+ # Store the processed image and dimensions
120
+ processed_inputs['input_image'] = input_image_np
121
+ processed_inputs['height'] = height
122
+ processed_inputs['width'] = width
123
+
124
+ return processed_inputs
125
+
126
+ def handle_results(self, job_params, result):
127
+ """
128
+ Handle the results of the F1 generation.
129
+
130
+ Args:
131
+ job_params: The job parameters
132
+ result: The generation result
133
+
134
+ Returns:
135
+ Processed result
136
+ """
137
+ # For F1 generation, we just return the result as-is
138
+ return result
139
+
140
+ # Using the centralized create_metadata method from BasePipeline
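A hypothetical sketch of the no-input-image branch above: with latent_type set to "Green Screen", preprocess_inputs synthesizes a chroma-key start frame at the nearest bucket size. The empty settings dict is an assumption, and running this requires the repo's diffusers_helper package on the path:

from modules.pipelines.f1_pipeline import F1Pipeline

pipe = F1Pipeline(settings={})
inputs = pipe.preprocess_inputs({
    "latent_type": "Green Screen",
    "resolutionW": 640,
    "resolutionH": 640,
})
frame = inputs["input_image"]
print(inputs["height"], inputs["width"], frame[0, 0].tolist())  # bucket dims and [0, 177, 64]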
modules/pipelines/metadata_utils.py ADDED
@@ -0,0 +1,329 @@
1
+ """
2
+ Metadata utilities for FramePack Studio.
3
+ This module provides functions for generating and saving metadata.
4
+ """
5
+
6
+ import os
7
+ import json
8
+ import time
9
+ import traceback # Moved to top
10
+ import numpy as np # Added
11
+ from PIL import Image, ImageDraw, ImageFont
12
+ from PIL.PngImagePlugin import PngInfo
13
+
14
+ from modules.version import APP_VERSION
15
+
16
+ def get_placeholder_color(model_type):
17
+ """
18
+ Get the placeholder image color for a specific model type.
19
+
20
+ Args:
21
+ model_type: The model type string
22
+
23
+ Returns:
24
+ RGB tuple for the placeholder image color
25
+ """
26
+ # Define color mapping for different model types
27
+ color_map = {
28
+ "Original": (0, 0, 0), # Black
29
+ "F1": (0, 0, 128), # Blue
30
+ "Video": (0, 128, 0), # Green
31
+ "XY Plot": (128, 128, 0), # Yellow
32
+ "F1 with Endframe": (0, 128, 128), # Teal
33
+ "Original with Endframe": (128, 0, 128), # Purple
34
+ }
35
+
36
+ # Return the color for the model type, or black as default
37
+ return color_map.get(model_type, (0, 0, 0))
38
+
39
+ # Function to save the starting image with comprehensive metadata
40
+ def save_job_start_image(job_params, job_id, settings):
41
+ """
42
+ Saves the job's starting input image to the output directory with comprehensive metadata.
43
+ This is intended to be called early in the job processing and is the ONLY place metadata should be saved.
44
+ """
45
+ # Get output directory from settings or job_params
46
+ output_dir_path = job_params.get("output_dir") or settings.get("output_dir")
47
+ metadata_dir_path = job_params.get("metadata_dir") or settings.get("metadata_dir")
48
+
49
+ if not output_dir_path:
50
+ print(f"[JOB_START_IMG_ERROR] No output directory found in job_params or settings")
51
+ return False
52
+
53
+ # Ensure directories exist
54
+ os.makedirs(output_dir_path, exist_ok=True)
55
+ os.makedirs(metadata_dir_path, exist_ok=True)
56
+
57
+ actual_start_image_target_path = os.path.join(output_dir_path, f'{job_id}.png')
58
+ actual_input_image_np = job_params.get('input_image')
59
+
60
+ # Create comprehensive metadata dictionary
61
+ metadata_dict = create_metadata(job_params, job_id, settings)
62
+
63
+ # Save JSON metadata with the same job_id
64
+ json_metadata_path = os.path.join(metadata_dir_path, f'{job_id}.json')
65
+
66
+ try:
67
+ with open(json_metadata_path, 'w') as f:
68
+ import json
69
+ json.dump(metadata_dict, f, indent=2)
70
+ except Exception as e:
71
+ traceback.print_exc()
72
+
73
+ # Save the input image if it's a numpy array
74
+ if actual_input_image_np is not None and isinstance(actual_input_image_np, np.ndarray):
75
+ try:
76
+ # Create PNG metadata
77
+ png_metadata = PngInfo()
78
+ png_metadata.add_text("prompt", job_params.get('prompt_text', ''))
79
+ png_metadata.add_text("seed", str(job_params.get('seed', 0)))
80
+ png_metadata.add_text("model_type", job_params.get('model_type', "Unknown"))
81
+
82
+ # Add more metadata fields
83
+ for key, value in metadata_dict.items():
84
+ if isinstance(value, (str, int, float, bool)) or value is None:
85
+ png_metadata.add_text(key, str(value))
86
+
87
+ # Convert image if needed
88
+ image_to_save_np = actual_input_image_np
89
+ if actual_input_image_np.dtype != np.uint8:
90
+ if actual_input_image_np.max() <= 1.0 and actual_input_image_np.min() >= -1.0 and actual_input_image_np.dtype in [np.float32, np.float64]:
91
+ image_to_save_np = ((actual_input_image_np + 1.0) / 2.0 * 255.0).clip(0, 255).astype(np.uint8)
92
+ elif actual_input_image_np.max() <= 1.0 and actual_input_image_np.min() >= 0.0 and actual_input_image_np.dtype in [np.float32, np.float64]:
93
+ image_to_save_np = (actual_input_image_np * 255.0).clip(0,255).astype(np.uint8)
94
+ else:
95
+ image_to_save_np = actual_input_image_np.clip(0, 255).astype(np.uint8)
96
+ # Save the image with metadata
97
+ start_image_pil = Image.fromarray(image_to_save_np)
98
+ start_image_pil.save(actual_start_image_target_path, pnginfo=png_metadata)
99
+ return True # Indicate success
100
+ except Exception as e:
101
+ traceback.print_exc()
102
+ return False # Indicate failure or inability to save
103
+
104
+ def create_metadata(job_params, job_id, settings, save_placeholder=False):
105
+ """
106
+ Create metadata for the job.
107
+
108
+ Args:
109
+ job_params: Dictionary of job parameters
110
+ job_id: The job ID
111
+ settings: Dictionary of settings
112
+ save_placeholder: Whether to save the placeholder image (default: False)
113
+
114
+ Returns:
115
+ Metadata dictionary
116
+ """
117
+ if not settings.get("save_metadata"):
118
+ return None
119
+
120
+ metadata_dir_path = settings.get("metadata_dir")
121
+ output_dir_path = settings.get("output_dir")
122
+ os.makedirs(metadata_dir_path, exist_ok=True)
123
+ os.makedirs(output_dir_path, exist_ok=True) # Ensure output_dir also exists
124
+
125
+ # Get model type and determine placeholder image color
126
+ model_type = job_params.get('model_type', "Original")
127
+ placeholder_color = get_placeholder_color(model_type)
128
+
129
+ # Create a placeholder image
130
+ height = job_params.get('height', 640)
131
+ width = job_params.get('width', 640)
132
+
133
+ # Use resolutionH and resolutionW if height and width are not available
134
+ if not height:
135
+ height = job_params.get('resolutionH', 640)
136
+ if not width:
137
+ width = job_params.get('resolutionW', 640)
138
+
139
+ placeholder_img = Image.new('RGB', (width, height), placeholder_color)
140
+
141
+ # Add XY plot parameters to the image if applicable
142
+ if model_type == "XY Plot":
143
+ x_param = job_params.get('x_param', '')
144
+ y_param = job_params.get('y_param', '')
145
+ x_values = job_params.get('x_values', [])
146
+ y_values = job_params.get('y_values', [])
147
+
148
+ draw = ImageDraw.Draw(placeholder_img)
149
+ try:
150
+ # Try to use a system font
151
+ font = ImageFont.truetype("Arial", 20)
152
+ except Exception:
153
+ # Fall back to default font
154
+ font = ImageFont.load_default()
155
+
156
+ text = f"X: {x_param} - {x_values}\nY: {y_param} - {y_values}"
157
+ draw.text((10, 10), text, fill=(255, 255, 255), font=font)
158
+
159
+ # Create PNG metadata
160
+ metadata = PngInfo()
161
+ metadata.add_text("prompt", job_params.get('prompt_text', ''))
162
+ metadata.add_text("seed", str(job_params.get('seed', 0)))
163
+
164
+ # Add model-specific metadata to PNG
165
+ if model_type == "XY Plot":
166
+ metadata.add_text("x_param", job_params.get('x_param', ''))
167
+ metadata.add_text("y_param", job_params.get('y_param', ''))
168
+
169
+ # Determine end_frame_used value safely (avoiding NumPy array boolean ambiguity)
170
+ end_frame_image = job_params.get('end_frame_image')
171
+ end_frame_used = False
172
+ if end_frame_image is not None:
173
+ if isinstance(end_frame_image, np.ndarray):
174
+ end_frame_used = end_frame_image.any() # True if any element is non-zero
175
+ else:
176
+ end_frame_used = True
177
+
178
+ # Create comprehensive JSON metadata with all possible parameters
179
+ # This is created before file saving logic that might use it (e.g. JSON dump)
180
+ # but PngInfo 'metadata' is used for images.
181
+ metadata_dict = {
182
+ # Version information
183
+ "app_version": APP_VERSION, # Using numeric version without 'v' prefix for metadata
184
+
185
+ # Common parameters
186
+ "prompt": job_params.get('prompt_text', ''),
187
+ "negative_prompt": job_params.get('n_prompt', ''),
188
+ "seed": job_params.get('seed', 0),
189
+ "steps": job_params.get('steps', 25),
190
+ "cfg": job_params.get('cfg', 1.0),
191
+ "gs": job_params.get('gs', 10.0),
192
+ "rs": job_params.get('rs', 0.0),
193
+ "latent_type": job_params.get('latent_type', 'Black'),
194
+ "timestamp": time.time(),
195
+ "resolutionW": job_params.get('resolutionW', 640),
196
+ "resolutionH": job_params.get('resolutionH', 640),
197
+ "model_type": model_type,
198
+ "generation_type": job_params.get('generation_type', model_type),
199
+ "has_input_image": job_params.get('has_input_image', False),
200
+ "input_image_path": job_params.get('input_image_path', None),
201
+
202
+ # Video-related parameters
203
+ "total_second_length": job_params.get('total_second_length', 6),
204
+ "blend_sections": job_params.get('blend_sections', 4),
205
+ "latent_window_size": job_params.get('latent_window_size', 9),
206
+ "num_cleaned_frames": job_params.get('num_cleaned_frames', 5),
207
+
208
+ # Endframe-related parameters
209
+ "end_frame_strength": job_params.get('end_frame_strength', None),
210
+ "end_frame_image_path": job_params.get('end_frame_image_path', None),
211
+ "end_frame_used": str(end_frame_used),
212
+
213
+ # Video input-related parameters
214
+ "input_video": os.path.basename(job_params.get('input_image', '')) if job_params.get('input_image') is not None and model_type == "Video" else None,
215
+ "video_path": job_params.get('input_image') if model_type == "Video" else None,
216
+
217
+ # XY Plot-related parameters
218
+ "x_param": job_params.get('x_param', None),
219
+ "y_param": job_params.get('y_param', None),
220
+ "x_values": job_params.get('x_values', None),
221
+ "y_values": job_params.get('y_values', None),
222
+
223
+ # Combine with source video
224
+ "combine_with_source": job_params.get('combine_with_source', False),
225
+
226
+ # Tea cache parameters
227
+ "use_teacache": job_params.get('use_teacache', False),
228
+ "teacache_num_steps": job_params.get('teacache_num_steps', 0),
229
+ "teacache_rel_l1_thresh": job_params.get('teacache_rel_l1_thresh', 0.0),
230
+ # MagCache parameters
231
+ "use_magcache": job_params.get('use_magcache', False),
232
+ "magcache_threshold": job_params.get('magcache_threshold', 0.1),
233
+ "magcache_max_consecutive_skips": job_params.get('magcache_max_consecutive_skips', 2),
234
+ "magcache_retention_ratio": job_params.get('magcache_retention_ratio', 0.25),
235
+ }
236
+
237
+ # Add LoRA information if present
238
+ selected_loras = job_params.get('selected_loras', [])
239
+ lora_values = job_params.get('lora_values', [])
240
+ lora_loaded_names = job_params.get('lora_loaded_names', [])
241
+
242
+ if isinstance(selected_loras, list) and len(selected_loras) > 0:
243
+ lora_data = {}
244
+ for lora_name in selected_loras:
245
+ try:
246
+ idx = lora_loaded_names.index(lora_name)
247
+ # Fix for NumPy array boolean ambiguity
248
+ has_lora_values = lora_values is not None and len(lora_values) > 0
249
+ weight = lora_values[idx] if has_lora_values and idx < len(lora_values) else 1.0
250
+
251
+ # Handle different types of weight values
252
+ if isinstance(weight, np.ndarray):
253
+ # Convert NumPy array to a scalar value
254
+ weight_value = float(weight.item()) if weight.size == 1 else float(weight.mean())
255
+ elif isinstance(weight, list):
256
+ # Handle list type weights
257
+ has_items = weight is not None and len(weight) > 0
258
+ weight_value = float(weight[0]) if has_items else 1.0
259
+ else:
260
+ # Handle scalar weights
261
+ weight_value = float(weight) if weight is not None else 1.0
262
+
263
+ lora_data[lora_name] = weight_value
264
+ except ValueError:
265
+ lora_data[lora_name] = 1.0
266
+ except Exception as e:
267
+ lora_data[lora_name] = 1.0
268
+ traceback.print_exc()
269
+
270
+ metadata_dict["loras"] = lora_data
271
+ else:
272
+ metadata_dict["loras"] = {}
273
+
274
+ # This function now only creates the metadata dictionary without saving files
275
+ # The actual saving is done by save_job_start_image() at the beginning of the generation process
276
+ # This prevents duplicate metadata files from being created
277
+
278
+ # For backward compatibility, we still create the placeholder image
279
+ # and save it if explicitly requested
280
+ placeholder_target_path = os.path.join(metadata_dir_path, f'{job_id}.png')
281
+
282
+ # Save the placeholder image if requested
283
+ if save_placeholder:
284
+ try:
285
+ placeholder_img.save(placeholder_target_path, pnginfo=metadata)
286
+ except Exception as e:
287
+ traceback.print_exc()
288
+
289
+ return metadata_dict
290
+
291
+ def save_last_video_frame(job_params, job_id, settings, last_frame_np):
292
+ """
293
+ Saves the last frame of the input video to the output directory with metadata.
294
+ """
295
+ output_dir_path = job_params.get("output_dir") or settings.get("output_dir")
296
+
297
+ if not output_dir_path:
298
+ print(f"[SAVE_LAST_FRAME_ERROR] No output directory found.")
299
+ return False
300
+
301
+ os.makedirs(output_dir_path, exist_ok=True)
302
+
303
+ last_frame_path = os.path.join(output_dir_path, f'{job_id}.png')
304
+
305
+ metadata_dict = create_metadata(job_params, job_id, settings)
306
+
307
+ if last_frame_np is not None and isinstance(last_frame_np, np.ndarray):
308
+ try:
309
+ png_metadata = PngInfo()
310
+ for key, value in metadata_dict.items():
311
+ if isinstance(value, (str, int, float, bool)) or value is None:
312
+ png_metadata.add_text(key, str(value))
313
+
314
+ image_to_save_np = last_frame_np
315
+ if last_frame_np.dtype != np.uint8:
316
+ if last_frame_np.max() <= 1.0 and last_frame_np.min() >= -1.0 and last_frame_np.dtype in [np.float32, np.float64]:
317
+ image_to_save_np = ((last_frame_np + 1.0) / 2.0 * 255.0).clip(0, 255).astype(np.uint8)
318
+ elif last_frame_np.max() <= 1.0 and last_frame_np.min() >= 0.0 and last_frame_np.dtype in [np.float32, np.float64]:
319
+ image_to_save_np = (last_frame_np * 255.0).clip(0,255).astype(np.uint8)
320
+ else:
321
+ image_to_save_np = last_frame_np.clip(0, 255).astype(np.uint8)
322
+
323
+ last_frame_pil = Image.fromarray(image_to_save_np)
324
+ last_frame_pil.save(last_frame_path, pnginfo=png_metadata)
325
+ print(f"Saved last video frame for job {job_id} to {last_frame_path}")
326
+ return True
327
+ except Exception as e:
328
+ traceback.print_exc()
329
+ return False
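Because save_job_start_image and save_last_video_frame embed the metadata as PNG text chunks, it can be read back with Pillow alone. The file name below is a stand-in for a real <job_id>.png in the output directory:

from PIL import Image

img = Image.open("outputs/example_job_id.png")   # stand-in path
print(img.text.get("prompt"))
print(img.text.get("seed"), img.text.get("model_type"))
print(img.text.get("app_version"))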
modules/pipelines/original_pipeline.py ADDED
@@ -0,0 +1,138 @@
1
+ """
2
+ Original pipeline class for FramePack Studio.
3
+ This pipeline handles the "Original" model type.
4
+ """
5
+
6
+ import os
7
+ import time
8
+ import json
9
+ import numpy as np
10
+ from PIL import Image
11
+ from PIL.PngImagePlugin import PngInfo
12
+ from diffusers_helper.utils import resize_and_center_crop
13
+ from diffusers_helper.bucket_tools import find_nearest_bucket
14
+ from .base_pipeline import BasePipeline
15
+
16
+ class OriginalPipeline(BasePipeline):
17
+ """Pipeline for Original generation type."""
18
+
19
+ def prepare_parameters(self, job_params):
20
+ """
21
+ Prepare parameters for the Original generation job.
22
+
23
+ Args:
24
+ job_params: Dictionary of job parameters
25
+
26
+ Returns:
27
+ Processed parameters dictionary
28
+ """
29
+ processed_params = job_params.copy()
30
+
31
+ # Ensure we have the correct model type
32
+ processed_params['model_type'] = "Original"
33
+
34
+ return processed_params
35
+
36
+ def validate_parameters(self, job_params):
37
+ """
38
+ Validate parameters for the Original generation job.
39
+
40
+ Args:
41
+ job_params: Dictionary of job parameters
42
+
43
+ Returns:
44
+ Tuple of (is_valid, error_message)
45
+ """
46
+ # Check for required parameters
47
+ required_params = ['prompt_text', 'seed', 'total_second_length', 'steps']
48
+ for param in required_params:
49
+ if param not in job_params:
50
+ return False, f"Missing required parameter: {param}"
51
+
52
+ # Validate numeric parameters
53
+ if job_params.get('total_second_length', 0) <= 0:
54
+ return False, "Video length must be greater than 0"
55
+
56
+ if job_params.get('steps', 0) <= 0:
57
+ return False, "Steps must be greater than 0"
58
+
59
+ return True, None
60
+
61
+ def preprocess_inputs(self, job_params):
62
+ """
63
+ Preprocess input images for the Original generation type.
64
+
65
+ Args:
66
+ job_params: Dictionary of job parameters
67
+
68
+ Returns:
69
+ Processed inputs dictionary
70
+ """
71
+ processed_inputs = {}
72
+
73
+ # Process input image if provided
74
+ input_image = job_params.get('input_image')
75
+ if input_image is not None:
76
+ # Get resolution parameters
77
+ resolutionW = job_params.get('resolutionW', 640)
78
+ resolutionH = job_params.get('resolutionH', 640)
79
+
80
+ # Find nearest bucket size
81
+ if job_params.get('has_input_image', True):
82
+ # If we have an input image, use its dimensions to find the nearest bucket
83
+ H, W, _ = input_image.shape
84
+ height, width = find_nearest_bucket(H, W, resolution=resolutionW)
85
+ else:
86
+ # Otherwise, use the provided resolution parameters
87
+ height, width = find_nearest_bucket(resolutionH, resolutionW, (resolutionW+resolutionH)/2)
88
+
89
+ # Resize and center crop the input image
90
+ input_image_np = resize_and_center_crop(input_image, target_width=width, target_height=height)
91
+
92
+ # Store the processed image and dimensions
93
+ processed_inputs['input_image'] = input_image_np
94
+ processed_inputs['height'] = height
95
+ processed_inputs['width'] = width
96
+ else:
97
+ # If no input image, create a blank image based on latent_type
98
+ resolutionW = job_params.get('resolutionW', 640)
99
+ resolutionH = job_params.get('resolutionH', 640)
100
+ height, width = find_nearest_bucket(resolutionH, resolutionW, (resolutionW+resolutionH)/2)
101
+
102
+ latent_type = job_params.get('latent_type', 'Black')
103
+ if latent_type == "White":
104
+ # Create a white image
105
+ input_image_np = np.ones((height, width, 3), dtype=np.uint8) * 255
106
+ elif latent_type == "Noise":
107
+ # Create a noise image
108
+ input_image_np = np.random.randint(0, 256, (height, width, 3), dtype=np.uint8)
109
+ elif latent_type == "Green Screen":
110
+ # Create a green screen image with standard chroma key green (0, 177, 64)
111
+ input_image_np = np.zeros((height, width, 3), dtype=np.uint8)
112
+ input_image_np[:, :, 1] = 177 # Green channel
113
+ input_image_np[:, :, 2] = 64 # Blue channel
114
+ # Red channel remains 0
115
+ else: # Default to "Black" or any other value
116
+ # Create a black image
117
+ input_image_np = np.zeros((height, width, 3), dtype=np.uint8)
118
+
119
+ # Store the processed image and dimensions
120
+ processed_inputs['input_image'] = input_image_np
121
+ processed_inputs['height'] = height
122
+ processed_inputs['width'] = width
123
+
124
+ return processed_inputs
125
+
126
+ def handle_results(self, job_params, result):
127
+ """
128
+ Handle the results of the Original generation.
129
+
130
+ Args:
131
+ job_params: The job parameters
132
+ result: The generation result
133
+
134
+ Returns:
135
+ Processed result
136
+ """
137
+ # For Original generation, we just return the result as-is
138
+ return result
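For orientation, worker.py (later in this diff) drives every pipeline through the same contract: validate, prepare, then preprocess. A rough sketch with an illustrative settings dict and job_params, assuming create_pipeline is re-exported by the package __init__ (worker.py imports it relatively):

from modules.pipelines import create_pipeline

settings = {"output_dir": "./outputs", "metadata_dir": "./outputs/metadata"}  # illustrative subset
pipeline = create_pipeline("Original", settings)

job_params = {"prompt_text": "a cat walking", "seed": 31337, "total_second_length": 5, "steps": 25}

is_valid, error_message = pipeline.validate_parameters(job_params)
if not is_valid:
    raise ValueError(error_message)

job_params = pipeline.prepare_parameters(job_params)        # stamps model_type = "Original"
job_params.update(pipeline.preprocess_inputs(job_params))   # adds input_image, height and width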
modules/pipelines/original_with_endframe_pipeline.py ADDED
@@ -0,0 +1,157 @@
1
+ """
2
+ Original with Endframe pipeline class for FramePack Studio.
3
+ This pipeline handles the "Original with Endframe" model type.
4
+ """
5
+
6
+ import os
7
+ import time
8
+ import json
9
+ import numpy as np
10
+ from PIL import Image
11
+ from PIL.PngImagePlugin import PngInfo
12
+ from diffusers_helper.utils import resize_and_center_crop
13
+ from diffusers_helper.bucket_tools import find_nearest_bucket
14
+ from .base_pipeline import BasePipeline
15
+
16
+ class OriginalWithEndframePipeline(BasePipeline):
17
+ """Pipeline for Original with Endframe generation type."""
18
+
19
+ def prepare_parameters(self, job_params):
20
+ """
21
+ Prepare parameters for the Original with Endframe generation job.
22
+
23
+ Args:
24
+ job_params: Dictionary of job parameters
25
+
26
+ Returns:
27
+ Processed parameters dictionary
28
+ """
29
+ processed_params = job_params.copy()
30
+
31
+ # Ensure we have the correct model type
32
+ processed_params['model_type'] = "Original with Endframe"
33
+
34
+ return processed_params
35
+
36
+ def validate_parameters(self, job_params):
37
+ """
38
+ Validate parameters for the Original with Endframe generation job.
39
+
40
+ Args:
41
+ job_params: Dictionary of job parameters
42
+
43
+ Returns:
44
+ Tuple of (is_valid, error_message)
45
+ """
46
+ # Check for required parameters
47
+ required_params = ['prompt_text', 'seed', 'total_second_length', 'steps']
48
+ for param in required_params:
49
+ if param not in job_params:
50
+ return False, f"Missing required parameter: {param}"
51
+
52
+ # Validate numeric parameters
53
+ if job_params.get('total_second_length', 0) <= 0:
54
+ return False, "Video length must be greater than 0"
55
+
56
+ if job_params.get('steps', 0) <= 0:
57
+ return False, "Steps must be greater than 0"
58
+
59
+ # Validate end frame parameters
60
+ if job_params.get('end_frame_strength', 0) < 0 or job_params.get('end_frame_strength', 0) > 1:
61
+ return False, "End frame strength must be between 0 and 1"
62
+
63
+ return True, None
64
+
65
+ def preprocess_inputs(self, job_params):
66
+ """
67
+ Preprocess input images for the Original with Endframe generation type.
68
+
69
+ Args:
70
+ job_params: Dictionary of job parameters
71
+
72
+ Returns:
73
+ Processed inputs dictionary
74
+ """
75
+ processed_inputs = {}
76
+
77
+ # Process input image if provided
78
+ input_image = job_params.get('input_image')
79
+ if input_image is not None:
80
+ # Get resolution parameters
81
+ resolutionW = job_params.get('resolutionW', 640)
82
+ resolutionH = job_params.get('resolutionH', 640)
83
+
84
+ # Find nearest bucket size
85
+ if job_params.get('has_input_image', True):
86
+ # If we have an input image, use its dimensions to find the nearest bucket
87
+ H, W, _ = input_image.shape
88
+ height, width = find_nearest_bucket(H, W, resolution=resolutionW)
89
+ else:
90
+ # Otherwise, use the provided resolution parameters
91
+ height, width = find_nearest_bucket(resolutionH, resolutionW, (resolutionW+resolutionH)/2)
92
+
93
+ # Resize and center crop the input image
94
+ input_image_np = resize_and_center_crop(input_image, target_width=width, target_height=height)
95
+
96
+ # Store the processed image and dimensions
97
+ processed_inputs['input_image'] = input_image_np
98
+ processed_inputs['height'] = height
99
+ processed_inputs['width'] = width
100
+ else:
101
+ # If no input image, create a blank image based on latent_type
102
+ resolutionW = job_params.get('resolutionW', 640)
103
+ resolutionH = job_params.get('resolutionH', 640)
104
+ height, width = find_nearest_bucket(resolutionH, resolutionW, (resolutionW+resolutionH)/2)
105
+
106
+ latent_type = job_params.get('latent_type', 'Black')
107
+ if latent_type == "White":
108
+ # Create a white image
109
+ input_image_np = np.ones((height, width, 3), dtype=np.uint8) * 255
110
+ elif latent_type == "Noise":
111
+ # Create a noise image
112
+ input_image_np = np.random.randint(0, 256, (height, width, 3), dtype=np.uint8)
113
+ elif latent_type == "Green Screen":
114
+ # Create a green screen image with standard chroma key green (0, 177, 64)
115
+ input_image_np = np.zeros((height, width, 3), dtype=np.uint8)
116
+ input_image_np[:, :, 1] = 177 # Green channel
117
+ input_image_np[:, :, 2] = 64 # Blue channel
118
+ # Red channel remains 0
119
+ else: # Default to "Black" or any other value
120
+ # Create a black image
121
+ input_image_np = np.zeros((height, width, 3), dtype=np.uint8)
122
+
123
+ # Store the processed image and dimensions
124
+ processed_inputs['input_image'] = input_image_np
125
+ processed_inputs['height'] = height
126
+ processed_inputs['width'] = width
127
+
128
+ # Process end frame image if provided
129
+ end_frame_image = job_params.get('end_frame_image')
130
+ if end_frame_image is not None:
131
+ # Use the same dimensions as the input image
132
+ height = processed_inputs['height']
133
+ width = processed_inputs['width']
134
+
135
+ # Resize and center crop the end frame image
136
+ end_frame_np = resize_and_center_crop(end_frame_image, target_width=width, target_height=height)
137
+
138
+ # Store the processed end frame image
139
+ processed_inputs['end_frame_image'] = end_frame_np
140
+
141
+ return processed_inputs
142
+
143
+ def handle_results(self, job_params, result):
144
+ """
145
+ Handle the results of the Original with Endframe generation.
146
+
147
+ Args:
148
+ job_params: The job parameters
149
+ result: The generation result
150
+
151
+ Returns:
152
+ Processed result
153
+ """
154
+ # For Original with Endframe generation, we just return the result as-is
155
+ return result
156
+
157
+ # Using the centralized create_metadata method from BasePipeline
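Both pipelines above resolve the working resolution the same way, so the end frame is always cropped to the start frame's bucket. A hedged illustration of the two find_nearest_bucket call patterns (the exact bucket sizes come from diffusers_helper/bucket_tools.py):

from diffusers_helper.bucket_tools import find_nearest_bucket

# With an input image: snap its native aspect ratio to the nearest bucket at the requested width.
H, W = 720, 1280                      # hypothetical source frame
height, width = find_nearest_bucket(H, W, resolution=640)

# Without an input image: derive the bucket from the requested resolution pair.
resolutionW, resolutionH = 832, 480   # hypothetical request
height, width = find_nearest_bucket(resolutionH, resolutionW, (resolutionW + resolutionH) / 2)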
modules/pipelines/video_f1_pipeline.py ADDED
@@ -0,0 +1,143 @@
1
+ """
2
+ Video F1 pipeline class for FramePack Studio.
3
+ This pipeline handles the "Video F1" model type.
4
+ """
5
+
6
+ import os
7
+ import time
8
+ import json
9
+ import numpy as np
10
+ from PIL import Image
11
+ from PIL.PngImagePlugin import PngInfo
12
+ from diffusers_helper.utils import resize_and_center_crop
13
+ from diffusers_helper.bucket_tools import find_nearest_bucket
14
+ from .base_pipeline import BasePipeline
15
+
16
+ class VideoF1Pipeline(BasePipeline):
17
+ """Pipeline for Video F1 generation type."""
18
+
19
+ def prepare_parameters(self, job_params):
20
+ """
21
+ Prepare parameters for the Video F1 generation job.
22
+
23
+ Args:
24
+ job_params: Dictionary of job parameters
25
+
26
+ Returns:
27
+ Processed parameters dictionary
28
+ """
29
+ processed_params = job_params.copy()
30
+
31
+ # Ensure we have the correct model type
32
+ processed_params['model_type'] = "Video F1"
33
+
34
+ return processed_params
35
+
36
+ def validate_parameters(self, job_params):
37
+ """
38
+ Validate parameters for the Video F1 generation job.
39
+
40
+ Args:
41
+ job_params: Dictionary of job parameters
42
+
43
+ Returns:
44
+ Tuple of (is_valid, error_message)
45
+ """
46
+ # Check for required parameters
47
+ required_params = ['prompt_text', 'seed', 'total_second_length', 'steps']
48
+ for param in required_params:
49
+ if param not in job_params:
50
+ return False, f"Missing required parameter: {param}"
51
+
52
+ # Validate numeric parameters
53
+ if job_params.get('total_second_length', 0) <= 0:
54
+ return False, "Video length must be greater than 0"
55
+
56
+ if job_params.get('steps', 0) <= 0:
57
+ return False, "Steps must be greater than 0"
58
+
59
+ # Check for input video (stored in input_image for Video F1 model)
60
+ if not job_params.get('input_image'):
61
+ return False, "Input video is required for Video F1 model"
62
+
63
+ # Check if combine_with_source is provided (optional)
64
+ combine_with_source = job_params.get('combine_with_source')
65
+ if combine_with_source is not None and not isinstance(combine_with_source, bool):
66
+ return False, "combine_with_source must be a boolean value"
67
+
68
+ return True, None
69
+
70
+ def preprocess_inputs(self, job_params):
71
+ """
72
+ Preprocess input video for the Video F1 generation type.
73
+
74
+ Args:
75
+ job_params: Dictionary of job parameters
76
+
77
+ Returns:
78
+ Processed inputs dictionary
79
+ """
80
+ processed_inputs = {}
81
+
82
+ # Get the input video (stored in input_image for Video F1 model)
83
+ input_video = job_params.get('input_image')
84
+ if not input_video:
85
+ raise ValueError("Input video is required for Video F1 model")
86
+
87
+ # Store the input video
88
+ processed_inputs['input_video'] = input_video
89
+
90
+ # Note: The following code will be executed in the worker function:
91
+ # 1. The worker will call video_encode on the generator to get video_latents and input_video_pixels
92
+ # 2. Then it will store these values for later use:
93
+ # input_video_pixels = input_video_pixels.cpu()
94
+ # video_latents = video_latents.cpu()
95
+ #
96
+ # 3. If the generator has the set_full_video_latents method, it will store the video latents:
97
+ # if hasattr(current_generator, 'set_full_video_latents'):
98
+ # current_generator.set_full_video_latents(video_latents.clone())
99
+ # print(f"Stored full input video latents in VideoModelGenerator. Shape: {video_latents.shape}")
100
+ #
101
+ # 4. For the Video model, history_latents is initialized with the video_latents:
102
+ # history_latents = video_latents
103
+ # print(f"Initialized history_latents with video context. Shape: {history_latents.shape}")
104
+ processed_inputs['input_files_dir'] = job_params.get('input_files_dir')
105
+
106
+ # Pass through the combine_with_source parameter if it exists
107
+ if 'combine_with_source' in job_params:
108
+ processed_inputs['combine_with_source'] = job_params.get('combine_with_source')
109
+ print(f"Video F1 pipeline: combine_with_source = {processed_inputs['combine_with_source']}")
110
+
111
+ # Pass through the num_cleaned_frames parameter if it exists
112
+ if 'num_cleaned_frames' in job_params:
113
+ processed_inputs['num_cleaned_frames'] = job_params.get('num_cleaned_frames')
114
+ print(f"Video F1 pipeline: num_cleaned_frames = {processed_inputs['num_cleaned_frames']}")
115
+
116
+ # Get resolution parameters
117
+ resolutionW = job_params.get('resolutionW', 640)
118
+ resolutionH = job_params.get('resolutionH', 640)
119
+
120
+ # Find nearest bucket size
121
+ height, width = find_nearest_bucket(resolutionH, resolutionW, (resolutionW+resolutionH)/2)
122
+
123
+ # Store the dimensions
124
+ processed_inputs['height'] = height
125
+ processed_inputs['width'] = width
126
+
127
+ return processed_inputs
128
+
129
+ def handle_results(self, job_params, result):
130
+ """
131
+ Handle the results of the Video F1 generation.
132
+
133
+ Args:
134
+ job_params: The job parameters
135
+ result: The generation result
136
+
137
+ Returns:
138
+ Processed result
139
+ """
140
+ # For Video F1 generation, we just return the result as-is
141
+ return result
142
+
143
+ # Using the centralized create_metadata method from BasePipeline
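As the validation above implies, the Video F1 pipeline expects the source video path in the input_image field; a hypothetical job_params that would pass validate_parameters and preprocess_inputs looks like:

job_params = {
    "prompt_text": "the camera slowly pans right",
    "seed": 31337,
    "total_second_length": 5,
    "steps": 25,
    "input_image": "/path/to/source_clip.mp4",   # a video path, despite the field name
    "combine_with_source": True,                  # optional; must be a bool when present
    "num_cleaned_frames": 5,                      # optional context-frame count
    "resolutionW": 640,
    "resolutionH": 640,
}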
modules/pipelines/video_pipeline.py ADDED
@@ -0,0 +1,143 @@
1
+ """
2
+ Video pipeline class for FramePack Studio.
3
+ This pipeline handles the "Video" model type.
4
+ """
5
+
6
+ import os
7
+ import time
8
+ import json
9
+ import numpy as np
10
+ from PIL import Image
11
+ from PIL.PngImagePlugin import PngInfo
12
+ from diffusers_helper.utils import resize_and_center_crop
13
+ from diffusers_helper.bucket_tools import find_nearest_bucket
14
+ from .base_pipeline import BasePipeline
15
+
16
+ class VideoPipeline(BasePipeline):
17
+ """Pipeline for Video generation type."""
18
+
19
+ def prepare_parameters(self, job_params):
20
+ """
21
+ Prepare parameters for the Video generation job.
22
+
23
+ Args:
24
+ job_params: Dictionary of job parameters
25
+
26
+ Returns:
27
+ Processed parameters dictionary
28
+ """
29
+ processed_params = job_params.copy()
30
+
31
+ # Ensure we have the correct model type
32
+ processed_params['model_type'] = "Video"
33
+
34
+ return processed_params
35
+
36
+ def validate_parameters(self, job_params):
37
+ """
38
+ Validate parameters for the Video generation job.
39
+
40
+ Args:
41
+ job_params: Dictionary of job parameters
42
+
43
+ Returns:
44
+ Tuple of (is_valid, error_message)
45
+ """
46
+ # Check for required parameters
47
+ required_params = ['prompt_text', 'seed', 'total_second_length', 'steps']
48
+ for param in required_params:
49
+ if param not in job_params:
50
+ return False, f"Missing required parameter: {param}"
51
+
52
+ # Validate numeric parameters
53
+ if job_params.get('total_second_length', 0) <= 0:
54
+ return False, "Video length must be greater than 0"
55
+
56
+ if job_params.get('steps', 0) <= 0:
57
+ return False, "Steps must be greater than 0"
58
+
59
+ # Check for input video (stored in input_image for Video model)
60
+ if not job_params.get('input_image'):
61
+ return False, "Input video is required for Video model"
62
+
63
+ # Check if combine_with_source is provided (optional)
64
+ combine_with_source = job_params.get('combine_with_source')
65
+ if combine_with_source is not None and not isinstance(combine_with_source, bool):
66
+ return False, "combine_with_source must be a boolean value"
67
+
68
+ return True, None
69
+
70
+ def preprocess_inputs(self, job_params):
71
+ """
72
+ Preprocess input video for the Video generation type.
73
+
74
+ Args:
75
+ job_params: Dictionary of job parameters
76
+
77
+ Returns:
78
+ Processed inputs dictionary
79
+ """
80
+ processed_inputs = {}
81
+
82
+ # Get the input video (stored in input_image for Video model)
83
+ input_video = job_params.get('input_image')
84
+ if not input_video:
85
+ raise ValueError("Input video is required for Video model")
86
+
87
+ # Store the input video
88
+ processed_inputs['input_video'] = input_video
89
+
90
+ # Note: The following code will be executed in the worker function:
91
+ # 1. The worker will call video_encode on the generator to get video_latents and input_video_pixels
92
+ # 2. Then it will store these values for later use:
93
+ # input_video_pixels = input_video_pixels.cpu()
94
+ # video_latents = video_latents.cpu()
95
+ #
96
+ # 3. If the generator has the set_full_video_latents method, it will store the video latents:
97
+ # if hasattr(current_generator, 'set_full_video_latents'):
98
+ # current_generator.set_full_video_latents(video_latents.clone())
99
+ # print(f"Stored full input video latents in VideoModelGenerator. Shape: {video_latents.shape}")
100
+ #
101
+ # 4. For the Video model, history_latents is initialized with the video_latents:
102
+ # history_latents = video_latents
103
+ # print(f"Initialized history_latents with video context. Shape: {history_latents.shape}")
104
+ processed_inputs['input_files_dir'] = job_params.get('input_files_dir')
105
+
106
+ # Pass through the combine_with_source parameter if it exists
107
+ if 'combine_with_source' in job_params:
108
+ processed_inputs['combine_with_source'] = job_params.get('combine_with_source')
109
+ print(f"Video pipeline: combine_with_source = {processed_inputs['combine_with_source']}")
110
+
111
+ # Pass through the num_cleaned_frames parameter if it exists
112
+ if 'num_cleaned_frames' in job_params:
113
+ processed_inputs['num_cleaned_frames'] = job_params.get('num_cleaned_frames')
114
+ print(f"Video pipeline: num_cleaned_frames = {processed_inputs['num_cleaned_frames']}")
115
+
116
+ # Get resolution parameters
117
+ resolutionW = job_params.get('resolutionW', 640)
118
+ resolutionH = job_params.get('resolutionH', 640)
119
+
120
+ # Find nearest bucket size
121
+ height, width = find_nearest_bucket(resolutionH, resolutionW, (resolutionW+resolutionH)/2)
122
+
123
+ # Store the dimensions
124
+ processed_inputs['height'] = height
125
+ processed_inputs['width'] = width
126
+
127
+ return processed_inputs
128
+
129
+ def handle_results(self, job_params, result):
130
+ """
131
+ Handle the results of the Video generation.
132
+
133
+ Args:
134
+ job_params: The job parameters
135
+ result: The generation result
136
+
137
+ Returns:
138
+ Processed result
139
+ """
140
+ # For Video generation, we just return the result as-is
141
+ return result
142
+
143
+ # Using the centralized create_metadata method from BasePipeline
modules/pipelines/video_tools.py ADDED
@@ -0,0 +1,57 @@
1
+ import torch
2
+ import numpy as np
3
+ import traceback
4
+
5
+ from diffusers_helper.utils import save_bcthw_as_mp4
6
+
7
+ @torch.no_grad()
8
+ def combine_videos_sequentially_from_tensors(processed_input_frames_np,
9
+ generated_frames_pt,
10
+ output_path,
11
+ target_fps,
12
+ crf_value):
13
+ """
14
+ Combines processed input frames (NumPy) with generated frames (PyTorch Tensor) sequentially
15
+ and saves the result as an MP4 video using save_bcthw_as_mp4.
16
+
17
+ Args:
18
+ processed_input_frames_np: NumPy array of processed input frames (T_in, H, W_in, C), uint8.
19
+ generated_frames_pt: PyTorch tensor of generated frames (B_gen, C_gen, T_gen, H, W_gen), float32 [-1,1].
20
+ (This will be history_pixels from worker.py)
21
+ output_path: Path to save the combined video.
22
+ target_fps: FPS for the output combined video.
23
+ crf_value: CRF value for video encoding.
24
+
25
+ Returns:
26
+ Path to the combined video, or None if an error occurs.
27
+ """
28
+ try:
29
+ # 1. Convert processed_input_frames_np to PyTorch tensor BCTHW, float32, [-1,1]
30
+ # processed_input_frames_np shape: (T_in, H, W_in, C)
31
+ input_frames_pt = torch.from_numpy(processed_input_frames_np).float() / 127.5 - 1.0 # (T,H,W,C)
32
+ input_frames_pt = input_frames_pt.permute(3, 0, 1, 2) # (C,T,H,W)
33
+ input_frames_pt = input_frames_pt.unsqueeze(0) # (1,C,T,H,W) -> BCTHW
34
+
35
+ # Ensure generated_frames_pt is on the same device and dtype for concatenation
36
+ input_frames_pt = input_frames_pt.to(device=generated_frames_pt.device, dtype=generated_frames_pt.dtype)
37
+
38
+ # 2. Dimension Check (Heights and Widths should match)
39
+ # They should match, since the input frames should have been processed to match the generation resolution.
40
+ # But sanity check to ensure no mismatch occurs when the code is refactored.
41
+ if input_frames_pt.shape[3:] != generated_frames_pt.shape[3:]: # Compare (H,W)
42
+ print(f"Warning: Dimension mismatch for sequential combination! Input: {input_frames_pt.shape[3:]}, Generated: {generated_frames_pt.shape[3:]}.")
43
+ print("Attempting to proceed, but this might lead to errors or unexpected video output.")
44
+ # Potentially add resizing logic here if necessary, but for now, assume they match
45
+
46
+ # 3. Concatenate Tensors along the time dimension (dim=2 for BCTHW)
47
+ combined_video_pt = torch.cat([input_frames_pt, generated_frames_pt], dim=2)
48
+
49
+ # 4. Save
50
+ save_bcthw_as_mp4(combined_video_pt, output_path, fps=target_fps, crf=crf_value)
51
+ print(f"Sequentially combined video (from tensors) saved to {output_path}")
52
+ return output_path
53
+ except Exception as e:
54
+ print(f"Error in combine_videos_sequentially_from_tensors: {str(e)}")
55
+ import traceback
56
+ traceback.print_exc()
57
+ return None
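The conversion above is purely a layout and range change, from (T, H, W, C) uint8 frames to the BCTHW float tensor save_bcthw_as_mp4 expects; a shape-only sketch with dummy tensors (no file I/O) is:

import numpy as np
import torch

frames_np = np.zeros((24, 480, 640, 3), dtype=np.uint8)        # (T, H, W, C) source frames
frames_pt = torch.from_numpy(frames_np).float() / 127.5 - 1.0  # rescale to [-1, 1]
frames_pt = frames_pt.permute(3, 0, 1, 2).unsqueeze(0)         # -> (1, C, T, H, W)

generated_pt = torch.zeros((1, 3, 36, 480, 640))               # stand-in for history_pixels
combined = torch.cat([frames_pt, generated_pt], dim=2)         # append generated frames in time
assert combined.shape == (1, 3, 60, 480, 640)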
modules/pipelines/worker.py ADDED
@@ -0,0 +1,1150 @@
1
+ import os
2
+ import json
3
+ import time
4
+ import traceback
5
+ import einops
6
+ import numpy as np
7
+ import torch
8
+ import datetime
9
+ from PIL import Image
10
+ from PIL.PngImagePlugin import PngInfo
11
+ from diffusers_helper.models.mag_cache import MagCache
12
+ from diffusers_helper.utils import save_bcthw_as_mp4, generate_timestamp, resize_and_center_crop
13
+ from diffusers_helper.memory import cpu, gpu, move_model_to_device_with_memory_preservation, offload_model_from_device_for_memory_preservation, fake_diffusers_current_device, unload_complete_models, load_model_as_complete
14
+ from diffusers_helper.thread_utils import AsyncStream
15
+ from diffusers_helper.gradio.progress_bar import make_progress_bar_html
16
+ from diffusers_helper.hunyuan import vae_decode
17
+ from modules.video_queue import JobStatus
18
+ from modules.prompt_handler import parse_timestamped_prompt
19
+ from modules.generators import create_model_generator
20
+ from modules.pipelines.video_tools import combine_videos_sequentially_from_tensors
21
+ from modules import DUMMY_LORA_NAME # Import the constant
22
+ from modules.llm_captioner import unload_captioning_model
23
+ from modules.llm_enhancer import unload_enhancing_model
24
+ from . import create_pipeline
25
+
26
+ import __main__ as studio_module # Get a reference to the __main__ module object
27
+
28
+ @torch.no_grad()
29
+ def get_cached_or_encode_prompt(prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2, target_device, prompt_embedding_cache):
30
+ """
31
+ Retrieves prompt embeddings from cache or encodes them if not found.
32
+ Stores encoded embeddings (on CPU) in the cache.
33
+ Returns embeddings moved to the target_device.
34
+ """
35
+ from diffusers_helper.hunyuan import encode_prompt_conds, crop_or_pad_yield_mask
36
+
37
+ if prompt in prompt_embedding_cache:
38
+ print(f"Cache hit for prompt: {prompt[:60]}...")
39
+ llama_vec_cpu, llama_mask_cpu, clip_l_pooler_cpu = prompt_embedding_cache[prompt]
40
+ # Move cached embeddings (from CPU) to the target device
41
+ llama_vec = llama_vec_cpu.to(target_device)
42
+ llama_attention_mask = llama_mask_cpu.to(target_device) if llama_mask_cpu is not None else None
43
+ clip_l_pooler = clip_l_pooler_cpu.to(target_device)
44
+ return llama_vec, llama_attention_mask, clip_l_pooler
45
+ else:
46
+ print(f"Cache miss for prompt: {prompt[:60]}...")
47
+ llama_vec, clip_l_pooler = encode_prompt_conds(
48
+ prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2
49
+ )
50
+ llama_vec, llama_attention_mask = crop_or_pad_yield_mask(llama_vec, length=512)
51
+ # Store CPU copies in cache
52
+ prompt_embedding_cache[prompt] = (llama_vec.cpu(), llama_attention_mask.cpu() if llama_attention_mask is not None else None, clip_l_pooler.cpu())
53
+ # Return embeddings already on the target device (as encode_prompt_conds uses the model's device)
54
+ return llama_vec, llama_attention_mask, clip_l_pooler
55
+
56
+ @torch.no_grad()
57
+ def worker(
58
+ model_type,
59
+ input_image,
60
+ end_frame_image, # The end frame image (numpy array or None)
61
+ end_frame_strength, # Influence of the end frame
62
+ prompt_text,
63
+ n_prompt,
64
+ seed,
65
+ total_second_length,
66
+ latent_window_size,
67
+ steps,
68
+ cfg,
69
+ gs,
70
+ rs,
71
+ use_teacache,
72
+ teacache_num_steps,
73
+ teacache_rel_l1_thresh,
74
+ use_magcache,
75
+ magcache_threshold,
76
+ magcache_max_consecutive_skips,
77
+ magcache_retention_ratio,
78
+ blend_sections,
79
+ latent_type,
80
+ selected_loras,
81
+ has_input_image,
82
+ lora_values=None,
83
+ job_stream=None,
84
+ output_dir=None,
85
+ metadata_dir=None,
86
+ input_files_dir=None, # Add input_files_dir parameter
87
+ input_image_path=None, # Add input_image_path parameter
88
+ end_frame_image_path=None, # Add end_frame_image_path parameter
89
+ resolutionW=640, # Add resolution parameter with default value
90
+ resolutionH=640,
91
+ lora_loaded_names=[],
92
+ input_video=None, # Add input_video parameter with default value of None
93
+ combine_with_source=None, # Add combine_with_source parameter
94
+ num_cleaned_frames=5, # Add num_cleaned_frames parameter with default value
95
+ save_metadata_checked=True # Add save_metadata_checked parameter
96
+ ):
97
+ """
98
+ Worker function for video generation.
99
+ """
100
+
101
+ random_generator = torch.Generator("cpu").manual_seed(seed)
102
+
103
+ unload_enhancing_model()
104
+ unload_captioning_model()
105
+
106
+ # Filter out the dummy LoRA from selected_loras at the very beginning of the worker
107
+ actual_selected_loras_for_worker = []
108
+ if isinstance(selected_loras, list):
109
+ actual_selected_loras_for_worker = [lora for lora in selected_loras if lora != DUMMY_LORA_NAME]
110
+ if DUMMY_LORA_NAME in selected_loras and DUMMY_LORA_NAME in actual_selected_loras_for_worker: # Should not happen if filter works
111
+ print(f"Worker.py: Error - '{DUMMY_LORA_NAME}' was selected but not filtered out.")
112
+ elif DUMMY_LORA_NAME in selected_loras:
113
+ print(f"Worker.py: Filtered out '{DUMMY_LORA_NAME}' from selected LoRAs.")
114
+ elif selected_loras is not None: # If it's a single string (should not happen with multiselect dropdown)
115
+ if selected_loras != DUMMY_LORA_NAME:
116
+ actual_selected_loras_for_worker = [selected_loras]
117
+ selected_loras = actual_selected_loras_for_worker
118
+ print(f"Worker: Selected LoRAs for this worker: {selected_loras}")
119
+
120
+ # Import globals from the main module
121
+ from __main__ import high_vram, args, text_encoder, text_encoder_2, tokenizer, tokenizer_2, vae, image_encoder, feature_extractor, prompt_embedding_cache, settings, stream
122
+
123
+ # Ensure any existing LoRAs are unloaded from the current generator
124
+ if studio_module.current_generator is not None:
125
+ print("Worker: Unloading LoRAs from studio_module.current_generator")
126
+ studio_module.current_generator.unload_loras()
127
+ import gc
128
+ gc.collect()
129
+ if torch.cuda.is_available():
130
+ torch.cuda.empty_cache()
131
+
132
+ stream_to_use = job_stream if job_stream is not None else stream
133
+
134
+ total_latent_sections = (total_second_length * 30) / (latent_window_size * 4)
135
+ total_latent_sections = int(max(round(total_latent_sections), 1))
136
+
137
+ # --- Total progress tracking ---
138
+ total_steps = total_latent_sections * steps # Total diffusion steps over all segments
139
+ step_durations = [] # Rolling history of recent step durations for ETA
140
+ last_step_time = time.time()
141
+
142
+ # Parse the timestamped prompt with boundary snapping and reversing
143
+ # prompt_text should now be the original string from the job queue
144
+ prompt_sections = parse_timestamped_prompt(prompt_text, total_second_length, latent_window_size, model_type)
145
+ job_id = generate_timestamp()
146
+
147
+ # Initialize progress data with a clear starting message and dummy preview
148
+ dummy_preview = np.zeros((64, 64, 3), dtype=np.uint8)
149
+ initial_progress_data = {
150
+ 'preview': dummy_preview,
151
+ 'desc': 'Starting job...',
152
+ 'html': make_progress_bar_html(0, 'Starting job...')
153
+ }
154
+
155
+ # Store initial progress data in the job object if using a job stream
156
+ if job_stream is not None:
157
+ try:
158
+ from __main__ import job_queue
159
+ job = job_queue.get_job(job_id)
160
+ if job:
161
+ job.progress_data = initial_progress_data
162
+ except Exception as e:
163
+ print(f"Error storing initial progress data: {e}")
164
+
165
+ # Push initial progress update to both streams
166
+ stream_to_use.output_queue.push(('progress', (dummy_preview, 'Starting job...', make_progress_bar_html(0, 'Starting job...'))))
167
+
168
+ # Push job ID to stream to ensure monitoring connection
169
+ stream_to_use.output_queue.push(('job_id', job_id))
170
+ stream_to_use.output_queue.push(('monitor_job', job_id))
171
+
172
+ # Always push to the main stream to ensure the UI is updated
173
+ from __main__ import stream as main_stream
174
+ if main_stream: # Always push to main stream regardless of whether it's the same as stream_to_use
175
+ print(f"Pushing initial progress update to main stream for job {job_id}")
176
+ main_stream.output_queue.push(('progress', (dummy_preview, 'Starting job...', make_progress_bar_html(0, 'Starting job...'))))
177
+
178
+ # Push job ID to main stream to ensure monitoring connection
179
+ main_stream.output_queue.push(('job_id', job_id))
180
+ main_stream.output_queue.push(('monitor_job', job_id))
181
+
182
+ try:
183
+ # Create a settings dictionary for the pipeline
184
+ pipeline_settings = {
185
+ "output_dir": output_dir,
186
+ "metadata_dir": metadata_dir,
187
+ "input_files_dir": input_files_dir,
188
+ "save_metadata": settings.get("save_metadata", True),
189
+ "gpu_memory_preservation": settings.get("gpu_memory_preservation", 6),
190
+ "mp4_crf": settings.get("mp4_crf", 16),
191
+ "clean_up_videos": settings.get("clean_up_videos", True),
192
+ "gradio_temp_dir": settings.get("gradio_temp_dir", "./gradio_temp"),
193
+ "high_vram": high_vram
194
+ }
195
+
196
+ # Create the appropriate pipeline for the model type
197
+ pipeline = create_pipeline(model_type, pipeline_settings)
198
+
199
+ # Create job parameters dictionary
200
+ job_params = {
201
+ 'model_type': model_type,
202
+ 'input_image': input_image,
203
+ 'end_frame_image': end_frame_image,
204
+ 'end_frame_strength': end_frame_strength,
205
+ 'prompt_text': prompt_text,
206
+ 'n_prompt': n_prompt,
207
+ 'seed': seed,
208
+ 'total_second_length': total_second_length,
209
+ 'latent_window_size': latent_window_size,
210
+ 'steps': steps,
211
+ 'cfg': cfg,
212
+ 'gs': gs,
213
+ 'rs': rs,
214
+ 'blend_sections': blend_sections,
215
+ 'latent_type': latent_type,
216
+ 'use_teacache': use_teacache,
217
+ 'teacache_num_steps': teacache_num_steps,
218
+ 'teacache_rel_l1_thresh': teacache_rel_l1_thresh,
219
+ 'use_magcache': use_magcache,
220
+ 'magcache_threshold': magcache_threshold,
221
+ 'magcache_max_consecutive_skips': magcache_max_consecutive_skips,
222
+ 'magcache_retention_ratio': magcache_retention_ratio,
223
+ 'selected_loras': selected_loras,
224
+ 'has_input_image': has_input_image,
225
+ 'lora_values': lora_values,
226
+ 'resolutionW': resolutionW,
227
+ 'resolutionH': resolutionH,
228
+ 'lora_loaded_names': lora_loaded_names,
229
+ 'input_image_path': input_image_path,
230
+ 'end_frame_image_path': end_frame_image_path,
231
+ 'combine_with_source': combine_with_source,
232
+ 'num_cleaned_frames': num_cleaned_frames,
233
+ 'save_metadata_checked': save_metadata_checked # Ensure it's in job_params for internal use
234
+ }
235
+
236
+ # Validate parameters
237
+ is_valid, error_message = pipeline.validate_parameters(job_params)
238
+ if not is_valid:
239
+ raise ValueError(f"Invalid parameters: {error_message}")
240
+
241
+ # Prepare parameters
242
+ job_params = pipeline.prepare_parameters(job_params)
243
+
244
+ if not high_vram:
245
+ # Unload everything *except* the potentially active transformer
246
+ unload_complete_models(text_encoder, text_encoder_2, image_encoder, vae)
247
+ if studio_module.current_generator is not None and studio_module.current_generator.transformer is not None:
248
+ offload_model_from_device_for_memory_preservation(studio_module.current_generator.transformer, target_device=gpu, preserved_memory_gb=8)
249
+
250
+
251
+ # --- Model Loading / Switching ---
252
+ print(f"Worker starting for model type: {model_type}")
253
+ print(f"Worker: Before model assignment, studio_module.current_generator is {type(studio_module.current_generator)}, id: {id(studio_module.current_generator)}")
254
+
255
+ # Create the appropriate model generator
256
+ new_generator = create_model_generator(
257
+ model_type,
258
+ text_encoder=text_encoder,
259
+ text_encoder_2=text_encoder_2,
260
+ tokenizer=tokenizer,
261
+ tokenizer_2=tokenizer_2,
262
+ vae=vae,
263
+ image_encoder=image_encoder,
264
+ feature_extractor=feature_extractor,
265
+ high_vram=high_vram,
266
+ prompt_embedding_cache=prompt_embedding_cache,
267
+ offline=args.offline,
268
+ settings=settings
269
+ )
270
+
271
+ # Update the global generator
272
+ # This modifies the 'current_generator' attribute OF THE '__main__' MODULE OBJECT
273
+ studio_module.current_generator = new_generator
274
+ print(f"Worker: AFTER model assignment, studio_module.current_generator is {type(studio_module.current_generator)}, id: {id(studio_module.current_generator)}")
275
+ if studio_module.current_generator:
276
+ print(f"Worker: studio_module.current_generator.transformer is {type(studio_module.current_generator.transformer)}")
277
+
278
+ # Load the transformer model
279
+ studio_module.current_generator.load_model()
280
+
281
+ # Ensure the model has no LoRAs loaded
282
+ print(f"Ensuring {model_type} model has no LoRAs loaded")
283
+ studio_module.current_generator.unload_loras()
284
+
285
+ # Preprocess inputs
286
+ stream_to_use.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Preprocessing inputs...'))))
287
+ processed_inputs = pipeline.preprocess_inputs(job_params)
288
+
289
+ # Update job_params with processed inputs
290
+ job_params.update(processed_inputs)
291
+
292
+ # Save the starting image directly to the output directory with full metadata
293
+ # Check both global settings and job-specific save_metadata_checked parameter
294
+ if settings.get("save_metadata") and job_params.get('save_metadata_checked', True) and job_params.get('input_image') is not None:
295
+ try:
296
+ # Import the save_job_start_image function from metadata_utils
297
+ from modules.pipelines.metadata_utils import save_job_start_image, create_metadata
298
+
299
+ # Create comprehensive metadata for the job
300
+ metadata_dict = create_metadata(job_params, job_id, settings)
301
+
302
+ # Save the starting image with metadata
303
+ save_job_start_image(job_params, job_id, settings)
304
+
305
+ print(f"Saved metadata and starting image for job {job_id}")
306
+ except Exception as e:
307
+ print(f"Error saving starting image and metadata: {e}")
308
+ traceback.print_exc()
309
+
310
+ # Pre-encode all prompts
311
+ stream_to_use.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Text encoding all prompts...'))))
312
+
313
+ # NOTE: the rest of this setup intentionally runs inside the try block so failures are reported below
314
+ if not high_vram:
315
+ fake_diffusers_current_device(text_encoder, gpu)
316
+ load_model_as_complete(text_encoder_2, target_device=gpu)
317
+
318
+ # PROMPT BLENDING: Pre-encode all prompts and store in a list in order
319
+ unique_prompts = []
320
+ for section in prompt_sections:
321
+ if section.prompt not in unique_prompts:
322
+ unique_prompts.append(section.prompt)
323
+
324
+ encoded_prompts = {}
325
+ for prompt in unique_prompts:
326
+ # Use the helper function for caching and encoding
327
+ llama_vec, llama_attention_mask, clip_l_pooler = get_cached_or_encode_prompt(
328
+ prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2, gpu, prompt_embedding_cache
329
+ )
330
+ encoded_prompts[prompt] = (llama_vec, llama_attention_mask, clip_l_pooler)
331
+
332
+ # PROMPT BLENDING: Build a list of (start_section_idx, prompt) for each prompt
333
+ prompt_change_indices = []
334
+ last_prompt = None
335
+ for idx, section in enumerate(prompt_sections):
336
+ if section.prompt != last_prompt:
337
+ prompt_change_indices.append((idx, section.prompt))
338
+ last_prompt = section.prompt
339
+
340
+ # Encode negative prompt
341
+ if cfg == 1:
342
+ llama_vec_n, llama_attention_mask_n, clip_l_pooler_n = (
343
+ torch.zeros_like(encoded_prompts[prompt_sections[0].prompt][0]),
344
+ torch.zeros_like(encoded_prompts[prompt_sections[0].prompt][1]),
345
+ torch.zeros_like(encoded_prompts[prompt_sections[0].prompt][2])
346
+ )
347
+ else:
348
+ # Use the helper function for caching and encoding negative prompt
349
+ # Ensure n_prompt is a string
350
+ n_prompt_str = str(n_prompt) if n_prompt is not None else ""
351
+ llama_vec_n, llama_attention_mask_n, clip_l_pooler_n = get_cached_or_encode_prompt(
352
+ n_prompt_str, text_encoder, text_encoder_2, tokenizer, tokenizer_2, gpu, prompt_embedding_cache
353
+ )
354
+
355
+ end_of_input_video_embedding = None # Video model end frame CLIP Vision embedding
356
+ # Process input image or video based on model type
357
+ if model_type == "Video" or model_type == "Video F1":
358
+ stream_to_use.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Video processing ...'))))
359
+
360
+ # Encode the video using the VideoModelGenerator
361
+ start_latent, input_image_np, video_latents, fps, height, width, input_video_pixels, end_of_input_video_image_np, input_frames_resized_np = studio_module.current_generator.video_encode(
362
+ video_path=job_params['input_image'], # For Video model, input_image contains the video path
363
+ resolution=job_params['resolutionW'],
364
+ no_resize=False,
365
+ vae_batch_size=16,
366
+ device=gpu,
367
+ input_files_dir=job_params['input_files_dir']
368
+ )
369
+
370
+ if end_of_input_video_image_np is not None:
371
+ try:
372
+ from modules.pipelines.metadata_utils import save_last_video_frame
373
+ save_last_video_frame(job_params, job_id, settings, end_of_input_video_image_np)
374
+ except Exception as e:
375
+ print(f"Error saving last video frame: {e}")
376
+ traceback.print_exc()
377
+
378
+ # RT_BORG: retained only until we make our final decisions on how to handle combining videos
379
+ # Only necessary to retain resized frames to produce a combined video with source frames of the right dimensions
380
+ #if combine_with_source:
381
+ # # Store input_frames_resized_np in job_params for later use
382
+ # job_params['input_frames_resized_np'] = input_frames_resized_np
383
+
384
+ # CLIP Vision encoding for the first frame
385
+ stream_to_use.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'CLIP Vision encoding ...'))))
386
+
387
+ if not high_vram:
388
+ load_model_as_complete(image_encoder, target_device=gpu)
389
+
390
+ from diffusers_helper.clip_vision import hf_clip_vision_encode
391
+ image_encoder_output = hf_clip_vision_encode(input_image_np, feature_extractor, image_encoder)
392
+ image_encoder_last_hidden_state = image_encoder_output.last_hidden_state
393
+
394
+ end_of_input_video_embedding = hf_clip_vision_encode(end_of_input_video_image_np, feature_extractor, image_encoder).last_hidden_state
395
+
396
+ # Store the input video pixels and latents for later use
397
+ input_video_pixels = input_video_pixels.cpu()
398
+ video_latents = video_latents.cpu()
399
+
400
+ # Store the full video latents in the generator instance for preparing clean latents
401
+ if hasattr(studio_module.current_generator, 'set_full_video_latents'):
402
+ studio_module.current_generator.set_full_video_latents(video_latents.clone())
403
+ print(f"Stored full input video latents in VideoModelGenerator. Shape: {video_latents.shape}")
404
+
405
+ # For Video model, history_latents is initialized with the video_latents
406
+ history_latents = video_latents
407
+
408
+ # Store the last frame of the video latents as start_latent for the model
409
+ start_latent = video_latents[:, :, -1:].cpu()
410
+ print(f"Using last frame of input video as start_latent. Shape: {start_latent.shape}")
411
+ print(f"Placed last frame of video at position 0 in history_latents")
412
+
413
+ print(f"Initialized history_latents with video context. Shape: {history_latents.shape}")
414
+
415
+ # Store the number of frames in the input video for later use
416
+ input_video_frame_count = video_latents.shape[2]
417
+ else:
418
+ # Regular image processing
419
+ height = job_params['height']
420
+ width = job_params['width']
421
+
422
+ if not has_input_image and job_params.get('latent_type') == 'Noise':
423
+ # print("************************************************")
424
+ # print("** Using 'Noise' latent type for T2V workflow **")
425
+ # print("************************************************")
426
+
427
+ # Create a random latent to serve as the initial VAE context anchor.
428
+ # This provides a random starting point without visual bias.
429
+ start_latent = torch.randn(
430
+ (1, 16, 1, height // 8, width // 8),
431
+ generator=random_generator, device=random_generator.device
432
+ ).to(device=gpu, dtype=torch.float32)
433
+
434
+ # Create a neutral black image to generate a valid "null" CLIP Vision embedding.
435
+ # This provides the model with a valid, in-distribution unconditional image prompt.
436
+ # RT_BORG: Clip doesn't understand noise at all. I also tried using
437
+ # image_encoder_last_hidden_state = torch.zeros((1, 257, 1152), device=gpu, dtype=studio_module.current_generator.transformer.dtype)
438
+ # to represent a "null" CLIP Vision embedding in the shape for the CLIP encoder,
439
+ # but the Video model wasn't trained to handle zeros, so using a neutral black image for CLIP.
440
+
441
+ black_image_np = np.zeros((height, width, 3), dtype=np.uint8)
442
+
443
+ if not high_vram:
444
+ load_model_as_complete(image_encoder, target_device=gpu)
445
+
446
+ from diffusers_helper.clip_vision import hf_clip_vision_encode
447
+ image_encoder_output = hf_clip_vision_encode(black_image_np, feature_extractor, image_encoder)
448
+ image_encoder_last_hidden_state = image_encoder_output.last_hidden_state
449
+
450
+ else:
451
+ input_image_np = job_params['input_image']
452
+
453
+ input_image_pt = torch.from_numpy(input_image_np).float() / 127.5 - 1
454
+ input_image_pt = input_image_pt.permute(2, 0, 1)[None, :, None]
455
+
456
+ # Start image encoding with VAE
457
+ stream_to_use.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'VAE encoding ...'))))
458
+
459
+ if not high_vram:
460
+ load_model_as_complete(vae, target_device=gpu)
461
+
462
+ from diffusers_helper.hunyuan import vae_encode
463
+ start_latent = vae_encode(input_image_pt, vae)
464
+
465
+ # CLIP Vision
466
+ stream_to_use.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'CLIP Vision encoding ...'))))
467
+
468
+ if not high_vram:
469
+ load_model_as_complete(image_encoder, target_device=gpu)
470
+
471
+ from diffusers_helper.clip_vision import hf_clip_vision_encode
472
+ image_encoder_output = hf_clip_vision_encode(input_image_np, feature_extractor, image_encoder)
473
+ image_encoder_last_hidden_state = image_encoder_output.last_hidden_state
474
+
475
+ # VAE encode end_frame_image if provided
476
+ end_frame_latent = None
477
+ # VAE encode end_frame_image resized to output dimensions, if provided
478
+ end_frame_output_dimensions_latent = None
479
+ end_clip_embedding = None # Video model end frame CLIP Vision embedding
480
+
481
+ # Models with end_frame_image processing
482
+ if (model_type == "Original with Endframe" or model_type == "Video") and job_params.get('end_frame_image') is not None:
483
+ print(f"Processing end frame for {model_type} model...")
484
+ end_frame_image = job_params['end_frame_image']
485
+
486
+ if not isinstance(end_frame_image, np.ndarray):
487
+ print(f"Warning: end_frame_image is not a numpy array (type: {type(end_frame_image)}). Attempting conversion or skipping.")
488
+ try:
489
+ end_frame_image = np.array(end_frame_image)
490
+ except Exception as e_conv:
491
+ print(f"Could not convert end_frame_image to numpy array: {e_conv}. Skipping end frame.")
492
+ end_frame_image = None
493
+
494
+ if end_frame_image is not None:
495
+ # Use the main job's target width/height (bucket dimensions) for the end frame
496
+ end_frame_np = job_params['end_frame_image']
497
+
498
+ if settings.get("save_metadata"):
499
+ Image.fromarray(end_frame_np).save(os.path.join(metadata_dir, f'{job_id}_end_frame_processed.png'))
500
+
501
+ end_frame_pt = torch.from_numpy(end_frame_np).float() / 127.5 - 1
502
+ end_frame_pt = end_frame_pt.permute(2, 0, 1)[None, :, None] # VAE expects [B, C, F, H, W]
503
+
504
+ if not high_vram: load_model_as_complete(vae, target_device=gpu) # Ensure VAE is loaded
505
+ from diffusers_helper.hunyuan import vae_encode
506
+ end_frame_latent = vae_encode(end_frame_pt, vae)
507
+
508
+ # end_frame_output_dimensions_latent is sized like the start_latent and generated latents
509
+ end_frame_output_dimensions_np = resize_and_center_crop(end_frame_np, width, height)
510
+ end_frame_output_dimensions_pt = torch.from_numpy(end_frame_output_dimensions_np).float() / 127.5 - 1
511
+ end_frame_output_dimensions_pt = end_frame_output_dimensions_pt.permute(2, 0, 1)[None, :, None] # VAE expects [B, C, F, H, W]
512
+ end_frame_output_dimensions_latent = vae_encode(end_frame_output_dimensions_pt, vae)
513
+
514
+ print("End frame VAE encoded.")
515
+
516
+ # Video Mode CLIP Vision encoding for end frame
517
+ if model_type == "Video":
518
+ if not high_vram: # Ensure image_encoder is on GPU for this operation
519
+ load_model_as_complete(image_encoder, target_device=gpu)
520
+ from diffusers_helper.clip_vision import hf_clip_vision_encode
521
+ end_clip_embedding = hf_clip_vision_encode(end_frame_np, feature_extractor, image_encoder).last_hidden_state
522
+ end_clip_embedding = end_clip_embedding.to(studio_module.current_generator.transformer.dtype)
523
+ # Need that dtype conversion for end_clip_embedding? I don't think so, but it was in the original PR.
524
+
525
+ if not high_vram: # Offload VAE and image_encoder if they were loaded
526
+ offload_model_from_device_for_memory_preservation(vae, target_device=gpu, preserved_memory_gb=settings.get("gpu_memory_preservation"))
527
+ offload_model_from_device_for_memory_preservation(image_encoder, target_device=gpu, preserved_memory_gb=settings.get("gpu_memory_preservation"))
528
+
529
+ # Dtype
530
+ for prompt_key in encoded_prompts:
531
+ llama_vec, llama_attention_mask, clip_l_pooler = encoded_prompts[prompt_key]
532
+ llama_vec = llama_vec.to(studio_module.current_generator.transformer.dtype)
533
+ clip_l_pooler = clip_l_pooler.to(studio_module.current_generator.transformer.dtype)
534
+ encoded_prompts[prompt_key] = (llama_vec, llama_attention_mask, clip_l_pooler)
535
+
536
+ llama_vec_n = llama_vec_n.to(studio_module.current_generator.transformer.dtype)
537
+ clip_l_pooler_n = clip_l_pooler_n.to(studio_module.current_generator.transformer.dtype)
538
+ image_encoder_last_hidden_state = image_encoder_last_hidden_state.to(studio_module.current_generator.transformer.dtype)
539
+
540
+ # Sampling
541
+ stream_to_use.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Start sampling ...'))))
542
+
543
+ num_frames = latent_window_size * 4 - 3
544
+
545
+ # Initialize total_generated_latent_frames for Video model
546
+ total_generated_latent_frames = 0 # Default initialization for all model types
547
+
548
+ # Initialize history latents based on model type
549
+ if model_type != "Video" and model_type != "Video F1": # Skip for Video models as we already initialized it
550
+ history_latents = studio_module.current_generator.prepare_history_latents(height, width)
551
+
552
+ # For F1 model, initialize with start latent
553
+ if model_type == "F1":
554
+ history_latents = studio_module.current_generator.initialize_with_start_latent(history_latents, start_latent, has_input_image)
555
+ # If we had a real start image, it was just added to the history_latents
556
+ total_generated_latent_frames = 1 if has_input_image else 0
557
+ elif model_type == "Original" or model_type == "Original with Endframe":
558
+ total_generated_latent_frames = 0
559
+
560
+ history_pixels = None
561
+
562
+ # Get latent paddings from the generator
563
+ latent_paddings = studio_module.current_generator.get_latent_paddings(total_latent_sections)
564
+
565
+ # PROMPT BLENDING: Track section index
566
+ section_idx = 0
567
+
568
+ # Load LoRAs if selected
569
+ if selected_loras:
570
+ lora_folder_from_settings = settings.get("lora_dir")
571
+ studio_module.current_generator.load_loras(selected_loras, lora_folder_from_settings, lora_loaded_names, lora_values)
572
+
573
+ # --- Callback for progress ---
574
+ def callback(d):
575
+ nonlocal last_step_time, step_durations
576
+
577
+ # Check for cancellation signal
578
+ if stream_to_use.input_queue.top() == 'end':
579
+ print("Cancellation signal detected in callback")
580
+ return 'cancel' # Return a signal that will be checked in the sampler
581
+
582
+ now_time = time.time()
583
+ # Record the duration between diffusion steps to feed the rolling ETA average
584
+ if last_step_time is not None:
585
+ step_delta = now_time - last_step_time
586
+ if step_delta > 0:
587
+ step_durations.append(step_delta)
588
+ if len(step_durations) > 30: # Keep only recent 30 steps
589
+ step_durations.pop(0)
590
+ last_step_time = now_time
591
+ avg_step = sum(step_durations) / len(step_durations) if step_durations else 0.0
592
+
593
+ preview = d['denoised']
594
+ from diffusers_helper.hunyuan import vae_decode_fake
595
+ preview = vae_decode_fake(preview)
596
+ preview = (preview * 255.0).detach().cpu().numpy().clip(0, 255).astype(np.uint8)
597
+ preview = einops.rearrange(preview, 'b c t h w -> (b h) (t w) c')
598
+
599
+ # --- Progress & ETA logic ---
600
+ # Current segment progress
601
+ current_step = d['i'] + 1
602
+ percentage = int(100.0 * current_step / steps)
603
+
604
+ # Total progress
605
+ total_steps_done = section_idx * steps + current_step
606
+ total_percentage = int(100.0 * total_steps_done / total_steps)
607
+
608
+ # ETA calculations
609
+ def fmt_eta(sec):
610
+ try:
611
+ return str(datetime.timedelta(seconds=int(sec)))
612
+ except Exception:
613
+ return "--:--"
614
+
615
+ segment_eta = (steps - current_step) * avg_step if avg_step else 0
616
+ total_eta = (total_steps - total_steps_done) * avg_step if avg_step else 0
617
+
618
+ segment_hint = f'Sampling {current_step}/{steps} ETA {fmt_eta(segment_eta)}'
619
+ total_hint = f'Total {total_steps_done}/{total_steps} ETA {fmt_eta(total_eta)}'
620
+
621
+ # For Video model, add the input video frame count when calculating current position
622
+ if model_type == "Video":
623
+ # Calculate the time position including the input video frames
624
+ input_video_time = input_video_frame_count * 4 / 30 # Convert latent frames to time
625
+ current_pos = input_video_time + (total_generated_latent_frames * 4 - 3) / 30
626
+ # Original position is the remaining time to generate
627
+ original_pos = total_second_length - (total_generated_latent_frames * 4 - 3) / 30
628
+ else:
629
+ # For other models, calculate as before
630
+ current_pos = (total_generated_latent_frames * 4 - 3) / 30
631
+ original_pos = total_second_length - current_pos
632
+
633
+ # Ensure positions are not negative
634
+ if current_pos < 0: current_pos = 0
635
+ if original_pos < 0: original_pos = 0
636
+
637
+ hint = segment_hint # deprecated variable kept to minimise other code changes
638
+ desc = studio_module.current_generator.format_position_description(
639
+ total_generated_latent_frames,
640
+ current_pos,
641
+ original_pos,
642
+ current_prompt
643
+ )
644
+
645
+ # Create progress data dictionary
646
+ progress_data = {
647
+ 'preview': preview,
648
+ 'desc': desc,
649
+ 'html': make_progress_bar_html(percentage, segment_hint) + make_progress_bar_html(total_percentage, total_hint)
650
+ }
651
+
652
+ # Store progress data in the job object if using a job stream
653
+ if job_stream is not None:
654
+ try:
655
+ from __main__ import job_queue
656
+ job = job_queue.get_job(job_id)
657
+ if job:
658
+ job.progress_data = progress_data
659
+ except Exception as e:
660
+ print(f"Error updating job progress data: {e}")
661
+
662
+ # Always push to the job-specific stream
663
+ stream_to_use.output_queue.push(('progress', (preview, desc, make_progress_bar_html(percentage, segment_hint) + make_progress_bar_html(total_percentage, total_hint))))
664
+
665
+ # Always push to the main stream to ensure the UI is updated
666
+ # This is especially important for resumed jobs
667
+ from __main__ import stream as main_stream
668
+ if main_stream: # Always push to main stream regardless of whether it's the same as stream_to_use
669
+ main_stream.output_queue.push(('progress', (preview, desc, make_progress_bar_html(percentage, segment_hint) + make_progress_bar_html(total_percentage, total_hint))))
670
+
671
+ # Also push job ID to main stream to ensure monitoring connection
672
+ if main_stream:
673
+ main_stream.output_queue.push(('job_id', job_id))
674
+ main_stream.output_queue.push(('monitor_job', job_id))
675
+
676
+ # MagCache / TeaCache Initialization Logic
677
+ magcache = None
678
+ # RT_BORG: I cringe at this, but refactoring to introduce an actual model class will fix it.
679
+ model_family = "F1" if "F1" in model_type else "Original"
680
+
681
+ if settings.get("calibrate_magcache"): # Calibration mode (forces MagCache on)
682
+ print("Setting Up MagCache for Calibration")
683
+ is_calibrating = settings.get("calibrate_magcache")
684
+ studio_module.current_generator.transformer.initialize_teacache(enable_teacache=False) # Ensure TeaCache is off
685
+ magcache = MagCache(model_family=model_family, height=height, width=width, num_steps=steps, is_calibrating=is_calibrating, threshold=magcache_threshold, max_consectutive_skips=magcache_max_consecutive_skips, retention_ratio=magcache_retention_ratio)
686
+ studio_module.current_generator.transformer.install_magcache(magcache)
687
+ elif use_magcache: # User selected MagCache
688
+ print("Setting Up MagCache")
689
+ magcache = MagCache(model_family=model_family, height=height, width=width, num_steps=steps, is_calibrating=False, threshold=magcache_threshold, max_consectutive_skips=magcache_max_consecutive_skips, retention_ratio=magcache_retention_ratio)
690
+ studio_module.current_generator.transformer.initialize_teacache(enable_teacache=False) # Ensure TeaCache is off
691
+ studio_module.current_generator.transformer.install_magcache(magcache)
692
+ elif use_teacache:
693
+ print("Setting Up TeaCache")
694
+ studio_module.current_generator.transformer.initialize_teacache(enable_teacache=True, num_steps=teacache_num_steps, rel_l1_thresh=teacache_rel_l1_thresh)
695
+ studio_module.current_generator.transformer.uninstall_magcache()
696
+ else:
697
+ print("No Transformer Cache in use")
698
+ studio_module.current_generator.transformer.initialize_teacache(enable_teacache=False)
699
+ studio_module.current_generator.transformer.uninstall_magcache()
700
+
701
+ # --- Main generation loop ---
702
+ # `i_section_loop` will be our loop counter for applying end_frame_latent
703
+ for i_section_loop, latent_padding in enumerate(latent_paddings): # Existing loop structure
704
+ is_last_section = latent_padding == 0
705
+ latent_padding_size = latent_padding * latent_window_size
706
+
707
+ if stream_to_use.input_queue.top() == 'end':
708
+ stream_to_use.output_queue.push(('end', None))
709
+ return
710
+
711
+ # Calculate the current time position
712
+ if model_type == "Video":
713
+ # For Video model, add the input video time to the current position
714
+ input_video_time = input_video_frame_count * 4 / 30 # Convert latent frames to time
715
+ current_time_position = (total_generated_latent_frames * 4 - 3) / 30 # in seconds
716
+ if current_time_position < 0:
717
+ current_time_position = 0.01
718
+ else:
719
+ # For other models, calculate as before
720
+ current_time_position = (total_generated_latent_frames * 4 - 3) / 30 # in seconds
721
+ if current_time_position < 0:
722
+ current_time_position = 0.01
723
+
724
+ # Find the appropriate prompt for this section
725
+ current_prompt = prompt_sections[0].prompt # Default to first prompt
726
+ for section in prompt_sections:
727
+ if section.start_time <= current_time_position and (section.end_time is None or current_time_position < section.end_time):
728
+ current_prompt = section.prompt
729
+ break
730
+
731
+ # PROMPT BLENDING: Find if we're in a blend window
732
+ blend_alpha = None
733
+ prev_prompt = current_prompt
734
+ next_prompt = current_prompt
735
+
736
+ # Only try to blend if blend_sections > 0 and we have prompt change indices and multiple sections
737
+ try:
738
+ blend_sections_int = int(blend_sections)
739
+ except ValueError:
740
+ blend_sections_int = 0 # Default to 0 if conversion fails, effectively disabling blending
741
+ print(f"Warning: blend_sections ('{blend_sections}') is not a valid integer. Disabling prompt blending for this section.")
742
+ if blend_sections_int > 0 and prompt_change_indices and len(prompt_sections) > 1:
743
+ for i, (change_idx, prompt) in enumerate(prompt_change_indices):
744
+ if section_idx < change_idx:
745
+ prev_prompt = prompt_change_indices[i - 1][1] if i > 0 else prompt
746
+ next_prompt = prompt
747
+ blend_start = change_idx
748
+ blend_end = change_idx + blend_sections_int
749
+ if section_idx >= change_idx and section_idx < blend_end:
750
+ blend_alpha = (section_idx - change_idx + 1) / blend_sections_int
751
+ break
752
+ elif section_idx == change_idx:
753
+ # At the exact change, start blending
754
+ if i > 0:
755
+ prev_prompt = prompt_change_indices[i - 1][1]
756
+ next_prompt = prompt
757
+ blend_alpha = 1.0 / blend_sections_int
758
+ else:
759
+ prev_prompt = prompt
760
+ next_prompt = prompt
761
+ blend_alpha = None
762
+ break
763
+ else:
764
+ # After last change, no blending
765
+ prev_prompt = current_prompt
766
+ next_prompt = current_prompt
767
+ blend_alpha = None
768
+
769
+ # Get the encoded prompt for this section
770
+ if blend_alpha is not None and prev_prompt != next_prompt:
771
+ # Blend embeddings
772
+ prev_llama_vec, prev_llama_attention_mask, prev_clip_l_pooler = encoded_prompts[prev_prompt]
773
+ next_llama_vec, next_llama_attention_mask, next_clip_l_pooler = encoded_prompts[next_prompt]
774
+ llama_vec = (1 - blend_alpha) * prev_llama_vec + blend_alpha * next_llama_vec
775
+ llama_attention_mask = prev_llama_attention_mask # usually same
776
+ clip_l_pooler = (1 - blend_alpha) * prev_clip_l_pooler + blend_alpha * next_clip_l_pooler
777
+ print(f"Blending prompts: '{prev_prompt[:30]}...' -> '{next_prompt[:30]}...', alpha={blend_alpha:.2f}")
778
+ else:
779
+ llama_vec, llama_attention_mask, clip_l_pooler = encoded_prompts[current_prompt]
780
+
781
+ original_time_position = total_second_length - current_time_position
782
+ if original_time_position < 0:
783
+ original_time_position = 0
784
+
785
+ print(f'latent_padding_size = {latent_padding_size}, is_last_section = {is_last_section}, '
786
+ f'time position: {current_time_position:.2f}s (original: {original_time_position:.2f}s), '
787
+ f'using prompt: {current_prompt[:60]}...')
788
+
789
+ # Apply end_frame_latent to history_latents for models with Endframe support
790
+ if (model_type == "Original with Endframe") and i_section_loop == 0 and end_frame_latent is not None:
791
+ print(f"Applying end_frame_latent to history_latents with strength: {end_frame_strength}")
792
+ actual_end_frame_latent_for_history = end_frame_latent.clone()
793
+ if end_frame_strength != 1.0: # Only multiply if not full strength
794
+ actual_end_frame_latent_for_history = actual_end_frame_latent_for_history * end_frame_strength
795
+
796
+ # Ensure history_latents is on the correct device (usually CPU for this kind of modification if it's init'd there)
797
+ # and that the assigned tensor matches its dtype.
798
+ # The `studio_module.current_generator.prepare_history_latents` initializes it on CPU with float32.
799
+ if history_latents.shape[2] >= 1: # Check if the 'Depth_slots' dimension is sufficient
800
+ if model_type == "Original with Endframe":
801
+ # For Original model, apply to the beginning (position 0)
802
+ history_latents[:, :, 0:1, :, :] = actual_end_frame_latent_for_history.to(
803
+ device=history_latents.device, # Assign to history_latents' current device
804
+ dtype=history_latents.dtype # Match history_latents' dtype
805
+ )
806
+ elif model_type == "F1 with Endframe":
807
+ # For F1 model, apply to the end (last position)
808
+ history_latents[:, :, -1:, :, :] = actual_end_frame_latent_for_history.to(
809
+ device=history_latents.device, # Assign to history_latents' current device
810
+ dtype=history_latents.dtype # Match history_latents' dtype
811
+ )
812
+ print(f"End frame latent applied to history for {model_type} model.")
813
+ else:
814
+ print("Warning: history_latents not shaped as expected for end_frame application.")
815
+
816
+
817
+ # Video models use combined methods to prepare clean latents and indices
818
+ if model_type == "Video":
819
+ # Get num_cleaned_frames from job_params if available, otherwise use default value of 5
820
+ num_cleaned_frames = job_params.get('num_cleaned_frames', 5)
821
+ clean_latent_indices, latent_indices, clean_latent_2x_indices, clean_latent_4x_indices, clean_latents, clean_latents_2x, clean_latents_4x = \
822
+ studio_module.current_generator.video_prepare_clean_latents_and_indices(end_frame_output_dimensions_latent, end_frame_strength, end_clip_embedding, end_of_input_video_embedding, latent_paddings, latent_padding, latent_padding_size, latent_window_size, video_latents, history_latents, num_cleaned_frames)
823
+ elif model_type == "Video F1":
824
+ # Get num_cleaned_frames from job_params if available, otherwise use default value of 5
825
+ num_cleaned_frames = job_params.get('num_cleaned_frames', 5)
826
+ clean_latent_indices, latent_indices, clean_latent_2x_indices, clean_latent_4x_indices, clean_latents, clean_latents_2x, clean_latents_4x = \
827
+ studio_module.current_generator.video_f1_prepare_clean_latents_and_indices(latent_window_size, video_latents, history_latents, num_cleaned_frames)
828
+ else:
829
+ # Prepare indices using the generator
830
+ clean_latent_indices, latent_indices, clean_latent_2x_indices, clean_latent_4x_indices = studio_module.current_generator.prepare_indices(latent_padding_size, latent_window_size)
831
+
832
+ # Prepare clean latents using the generator
833
+ clean_latents, clean_latents_2x, clean_latents_4x = studio_module.current_generator.prepare_clean_latents(start_latent, history_latents)
834
+
835
+ # Print debug info
836
+ print(f"{model_type} model section {section_idx+1}/{total_latent_sections}, latent_padding={latent_padding}")
837
+
838
+ if not high_vram:
839
+ # Unload VAE etc. before loading transformer
840
+ unload_complete_models(vae, text_encoder, text_encoder_2, image_encoder)
841
+ move_model_to_device_with_memory_preservation(studio_module.current_generator.transformer, target_device=gpu, preserved_memory_gb=settings.get("gpu_memory_preservation"))
842
+ if selected_loras:
843
+ studio_module.current_generator.move_lora_adapters_to_device(gpu)
844
+
845
+
846
+ from diffusers_helper.pipelines.k_diffusion_hunyuan import sample_hunyuan
847
+ generated_latents = sample_hunyuan(
848
+ transformer=studio_module.current_generator.transformer,
849
+ width=width,
850
+ height=height,
851
+ frames=num_frames,
852
+ real_guidance_scale=cfg,
853
+ distilled_guidance_scale=gs,
854
+ guidance_rescale=rs,
855
+ num_inference_steps=steps,
856
+ generator=random_generator,
857
+ prompt_embeds=llama_vec,
858
+ prompt_embeds_mask=llama_attention_mask,
859
+ prompt_poolers=clip_l_pooler,
860
+ negative_prompt_embeds=llama_vec_n,
861
+ negative_prompt_embeds_mask=llama_attention_mask_n,
862
+ negative_prompt_poolers=clip_l_pooler_n,
863
+ device=gpu,
864
+ dtype=torch.bfloat16,
865
+ image_embeddings=image_encoder_last_hidden_state,
866
+ latent_indices=latent_indices,
867
+ clean_latents=clean_latents,
868
+ clean_latent_indices=clean_latent_indices,
869
+ clean_latents_2x=clean_latents_2x,
870
+ clean_latent_2x_indices=clean_latent_2x_indices,
871
+ clean_latents_4x=clean_latents_4x,
872
+ clean_latent_4x_indices=clean_latent_4x_indices,
873
+ callback=callback,
874
+ )
875
+
876
+ # RT_BORG: Observe the MagCache skip patterns during dev.
877
+ # RT_BORG: We need to use a real logger soon!
878
+ # if magcache is not None and magcache.is_enabled:
879
+ # print(f"MagCache skipped: {len(magcache.steps_skipped_list)} of {steps} steps: {magcache.steps_skipped_list}")
880
+
881
+ total_generated_latent_frames += int(generated_latents.shape[2])
882
+ # Update history latents using the generator
883
+ history_latents = studio_module.current_generator.update_history_latents(history_latents, generated_latents)
884
+
885
+ if not high_vram:
886
+ if selected_loras:
887
+ studio_module.current_generator.move_lora_adapters_to_device(cpu)
888
+ offload_model_from_device_for_memory_preservation(studio_module.current_generator.transformer, target_device=gpu, preserved_memory_gb=8)
889
+ load_model_as_complete(vae, target_device=gpu)
890
+
891
+ # Get real history latents using the generator
892
+ real_history_latents = studio_module.current_generator.get_real_history_latents(history_latents, total_generated_latent_frames)
893
+
894
+ if history_pixels is None:
895
+ history_pixels = vae_decode(real_history_latents, vae).cpu()
896
+ else:
897
+ section_latent_frames = studio_module.current_generator.get_section_latent_frames(latent_window_size, is_last_section)
898
+ overlapped_frames = latent_window_size * 4 - 3
899
+
900
+ # Get current pixels using the generator
901
+ current_pixels = studio_module.current_generator.get_current_pixels(real_history_latents, section_latent_frames, vae)
902
+
903
+ # Update history pixels using the generator
904
+ history_pixels = studio_module.current_generator.update_history_pixels(history_pixels, current_pixels, overlapped_frames)
905
+
906
+ print(f"{model_type} model section {section_idx+1}/{total_latent_sections}, history_pixels shape: {history_pixels.shape}")
907
+
908
+ if not high_vram:
909
+ unload_complete_models()
910
+
911
+ output_filename = os.path.join(output_dir, f'{job_id}_{total_generated_latent_frames}.mp4')
912
+ save_bcthw_as_mp4(history_pixels, output_filename, fps=30, crf=settings.get("mp4_crf"))
913
+ print(f'Decoded. Current latent shape {real_history_latents.shape}; pixel shape {history_pixels.shape}')
914
+ stream_to_use.output_queue.push(('file', output_filename))
915
+
916
+ if is_last_section:
917
+ break
918
+
919
+ section_idx += 1 # PROMPT BLENDING: increment section index
920
+
921
+ # We'll handle combining the videos after the entire generation is complete
922
+ # This section intentionally left empty to remove the in-process combination
923
+ # --- END Main generation loop ---
924
+
925
+ magcache = studio_module.current_generator.transformer.magcache
926
+ if magcache is not None:
927
+ if magcache.is_calibrating:
928
+ output_file = os.path.join(settings.get("output_dir"), "magcache_configuration.txt")
929
+ print(f"MagCache calibration job complete. Appending stats to configuration file: {output_file}")
930
+ magcache.append_calibration_to_file(output_file)
931
+ elif magcache.is_enabled:
932
+ print(f"MagCache ({100.0 * magcache.total_cache_hits / magcache.total_cache_requests:.2f}%) skipped {magcache.total_cache_hits} of {magcache.total_cache_requests} steps.")
933
+ studio_module.current_generator.transformer.uninstall_magcache()
934
+ magcache = None
935
+
936
+ # Handle the results
937
+ result = pipeline.handle_results(job_params, output_filename)
938
+
939
+ # Unload all LoRAs after generation completed
940
+ if selected_loras:
941
+ print("Unloading all LoRAs after generation completed")
942
+ studio_module.current_generator.unload_loras()
943
+ import gc
944
+ gc.collect()
945
+ if torch.cuda.is_available():
946
+ torch.cuda.empty_cache()
947
+
948
+ except Exception as e:
949
+ traceback.print_exc()
950
+ # Unload all LoRAs after error
951
+ if studio_module.current_generator is not None and selected_loras:
952
+ print("Unloading all LoRAs after error")
953
+ studio_module.current_generator.unload_loras()
954
+ import gc
955
+ gc.collect()
956
+ if torch.cuda.is_available():
957
+ torch.cuda.empty_cache()
958
+
959
+ stream_to_use.output_queue.push(('error', f"Error during generation: {traceback.format_exc()}"))
960
+ if not high_vram:
961
+ # Ensure all models including the potentially active transformer are unloaded on error
962
+ unload_complete_models(
963
+ text_encoder, text_encoder_2, image_encoder, vae,
964
+ studio_module.current_generator.transformer if studio_module.current_generator else None
965
+ )
966
+ finally:
967
+ # This finally block is associated with the main try block (starts around line 154)
968
+ if settings.get("clean_up_videos"):
969
+ try:
970
+ video_files = [
971
+ f for f in os.listdir(output_dir)
972
+ if f.startswith(f"{job_id}_") and f.endswith(".mp4")
973
+ ]
974
+ print(f"Video files found for cleanup: {video_files}")
975
+ if video_files:
976
+ def get_frame_count(filename):
977
+ try:
978
+ # Handles filenames like jobid_123.mp4
979
+ return int(filename.replace(f"{job_id}_", "").replace(".mp4", ""))
980
+ except Exception:
981
+ return -1
982
+ video_files_sorted = sorted(video_files, key=get_frame_count)
983
+ print(f"Sorted video files: {video_files_sorted}")
984
+ final_video = video_files_sorted[-1]
985
+ for vf in video_files_sorted[:-1]:
986
+ full_path = os.path.join(output_dir, vf)
987
+ try:
988
+ os.remove(full_path)
989
+ print(f"Deleted intermediate video: {full_path}")
990
+ except Exception as e:
991
+ print(f"Failed to delete {full_path}: {e}")
992
+ except Exception as e:
993
+ print(f"Error during video cleanup: {e}")
994
+
995
+ # Check if the user wants to combine the source video with the generated video
996
+ # This is done after the video cleanup routine to ensure the combined video is not deleted
997
+ # RT_BORG: Retain (but suppress) this original way to combine videos until the new combiner is proven.
998
+ combine_v1 = False
999
+ if combine_v1 and (model_type == "Video" or model_type == "Video F1") and combine_with_source and job_params.get('input_image_path'):
1000
+ print("Creating combined video with source and generated content...")
1001
+ try:
1002
+ input_video_path = job_params.get('input_image_path')
1003
+ if input_video_path and os.path.exists(input_video_path):
1004
+ final_video_path_for_combine = None # Use a different variable name to avoid conflict
1005
+ video_files_for_combine = [
1006
+ f for f in os.listdir(output_dir)
1007
+ if f.startswith(f"{job_id}_") and f.endswith(".mp4") and "combined" not in f
1008
+ ]
1009
+
1010
+ if video_files_for_combine:
1011
+ def get_frame_count_for_combine(filename): # Renamed to avoid conflict
1012
+ try:
1013
+ return int(filename.replace(f"{job_id}_", "").replace(".mp4", ""))
1014
+ except Exception:
1015
+ return float('inf')
1016
+
1017
+ video_files_sorted_for_combine = sorted(video_files_for_combine, key=get_frame_count_for_combine)
1018
+ if video_files_sorted_for_combine: # Check if the list is not empty
1019
+ final_video_path_for_combine = os.path.join(output_dir, video_files_sorted_for_combine[-1])
1020
+
1021
+ if final_video_path_for_combine and os.path.exists(final_video_path_for_combine):
1022
+ combined_output_filename = os.path.join(output_dir, f'{job_id}_combined_v1.mp4')
1023
+ combined_result = None
1024
+ try:
1025
+ if hasattr(studio_module.current_generator, 'combine_videos'):
1026
+ print(f"Using VideoModelGenerator.combine_videos to create side-by-side comparison")
1027
+ combined_result = studio_module.current_generator.combine_videos(
1028
+ source_video_path=input_video_path,
1029
+ generated_video_path=final_video_path_for_combine, # Use the correct variable
1030
+ output_path=combined_output_filename
1031
+ )
1032
+
1033
+ if combined_result:
1034
+ print(f"Combined video saved to: {combined_result}")
1035
+ stream_to_use.output_queue.push(('file', combined_result))
1036
+ else:
1037
+ print("Failed to create combined video, falling back to direct ffmpeg method")
1038
+ combined_result = None
1039
+ else:
1040
+ print("VideoModelGenerator does not have combine_videos method. Using fallback method.")
1041
+ except Exception as e_combine: # Use a different exception variable name
1042
+ print(f"Error in combine_videos method: {e_combine}")
1043
+ print("Falling back to direct ffmpeg method")
1044
+ combined_result = None
1045
+
1046
+ if not combined_result:
1047
+ print("Using fallback method to combine videos")
1048
+ from modules.toolbox.toolbox_processor import VideoProcessor
1049
+ from modules.toolbox.message_manager import MessageManager
1050
+
1051
+ message_manager = MessageManager()
1052
+ # Pass settings.settings if it exists, otherwise pass the settings object
1053
+ video_processor_settings = settings.settings if hasattr(settings, 'settings') else settings
1054
+ video_processor = VideoProcessor(message_manager, video_processor_settings)
1055
+ ffmpeg_exe = video_processor.ffmpeg_exe
1056
+
1057
+ if ffmpeg_exe:
1058
+ print(f"Using ffmpeg at: {ffmpeg_exe}")
1059
+ import subprocess
1060
+ temp_list_file = os.path.join(output_dir, f'{job_id}_filelist.txt')
1061
+ with open(temp_list_file, 'w') as f:
1062
+ f.write(f"file '{input_video_path}'\n")
1063
+ f.write(f"file '{final_video_path_for_combine}'\n") # Use the correct variable
1064
+
1065
+ ffmpeg_cmd = [
1066
+ ffmpeg_exe, "-y", "-f", "concat", "-safe", "0",
1067
+ "-i", temp_list_file, "-c", "copy", combined_output_filename
1068
+ ]
1069
+ print(f"Running ffmpeg command: {' '.join(ffmpeg_cmd)}")
1070
+ subprocess.run(ffmpeg_cmd, check=True, capture_output=True, text=True)
1071
+ if os.path.exists(temp_list_file):
1072
+ os.remove(temp_list_file)
1073
+ print(f"Combined video saved to: {combined_output_filename}")
1074
+ stream_to_use.output_queue.push(('file', combined_output_filename))
1075
+ else:
1076
+ print("FFmpeg executable not found. Cannot combine videos.")
1077
+ else:
1078
+ print(f"Final video not found for combining with source: {final_video_path_for_combine}")
1079
+ else:
1080
+ print(f"Input video path not found: {input_video_path}")
1081
+ except Exception as e_combine_outer: # Use a different exception variable name
1082
+ print(f"Error combining videos: {e_combine_outer}")
1083
+ traceback.print_exc()
1084
+
1085
+ # Combine input frames (resized and center cropped if needed) with final generated history_pixels tensor sequentially ---
1086
+ # This creates ID_combined.mp4
1087
+ # RT_BORG: Be sure to add this check if we decide to retain the processed input frames for "small" input videos
1088
+ # and job_params.get('input_frames_resized_np') is not None
1089
+ if (model_type == "Video" or model_type == "Video F1") and combine_with_source and history_pixels is not None:
1090
+ print(f"Creating combined video ({job_id}_combined.mp4) with processed input frames and generated history_pixels tensor...")
1091
+ try:
1092
+ # input_frames_resized_np = job_params.get('input_frames_resized_np')
1093
+
1094
+ # RT_BORG: I cringe calliing methods on BaseModelGenerator that only exist on VideoBaseGenerator, until we refactor
1095
+ input_frames_resized_np, fps, target_height, target_width = studio_module.current_generator.extract_video_frames(
1096
+ is_for_encode=False,
1097
+ video_path=job_params['input_image'],
1098
+ resolution=job_params['resolutionW'],
1099
+ no_resize=False,
1100
+ input_files_dir=job_params['input_files_dir']
1101
+ )
1102
+
1103
+ # history_pixels is (B, C, T, H, W), float32, [-1,1], on CPU
1104
+ if input_frames_resized_np is not None and history_pixels.numel() > 0 : # Check if history_pixels is not empty
1105
+ combined_sequential_output_filename = os.path.join(output_dir, f'{job_id}_combined.mp4')
1106
+
1107
+ # fps variable should be from the video_encode call earlier.
1108
+ input_video_fps_for_combine = fps
1109
+ current_crf = settings.get("mp4_crf", 16)
1110
+
1111
+ # Call the new function from video_tools.py
1112
+ combined_sequential_result_path = combine_videos_sequentially_from_tensors(
1113
+ processed_input_frames_np=input_frames_resized_np,
1114
+ generated_frames_pt=history_pixels,
1115
+ output_path=combined_sequential_output_filename,
1116
+ target_fps=input_video_fps_for_combine,
1117
+ crf_value=current_crf
1118
+ )
1119
+ if combined_sequential_result_path:
1120
+ stream_to_use.output_queue.push(('file', combined_sequential_result_path))
1121
+ except Exception as e:
1122
+ print(f"Error creating combined video ({job_id}_combined.mp4): {e}")
1123
+ traceback.print_exc()
1124
+
1125
+ # Final verification of LoRA state
1126
+ if studio_module.current_generator and studio_module.current_generator.transformer:
1127
+ # Verify LoRA state
1128
+ has_loras = False
1129
+ if hasattr(studio_module.current_generator.transformer, 'peft_config'):
1130
+ adapter_names = list(studio_module.current_generator.transformer.peft_config.keys()) if studio_module.current_generator.transformer.peft_config else []
1131
+ if adapter_names:
1132
+ has_loras = True
1133
+ print(f"Transformer has LoRAs: {', '.join(adapter_names)}")
1134
+ else:
1135
+ print(f"Transformer has no LoRAs in peft_config")
1136
+ else:
1137
+ print(f"Transformer has no peft_config attribute")
1138
+
1139
+ # Check for any LoRA modules
1140
+ for name, module in studio_module.current_generator.transformer.named_modules():
1141
+ if hasattr(module, 'lora_A') and module.lora_A:
1142
+ has_loras = True
1143
+ if hasattr(module, 'lora_B') and module.lora_B:
1144
+ has_loras = True
1145
+
1146
+ if not has_loras:
1147
+ print(f"No LoRA components found in transformer")
1148
+
1149
+ stream_to_use.output_queue.push(('end', None))
1150
+ return
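The prompt-blending branch in the generation loop above is a plain linear interpolation of the cached text embeddings. A minimal standalone sketch of that step, assuming encoded_prompts maps each prompt string to a (llama_vec, llama_attention_mask, clip_l_pooler) tuple of tensors as in the worker code:

def blend_prompt_embeddings(encoded_prompts, prev_prompt, next_prompt, blend_alpha):
    # Mirrors the in-loop logic: result = (1 - alpha) * prev + alpha * next.
    prev_vec, prev_mask, prev_pooler = encoded_prompts[prev_prompt]
    next_vec, next_mask, next_pooler = encoded_prompts[next_prompt]
    llama_vec = (1 - blend_alpha) * prev_vec + blend_alpha * next_vec
    clip_l_pooler = (1 - blend_alpha) * prev_pooler + blend_alpha * next_pooler
    # The attention mask is taken from the previous prompt; the worker assumes the masks match.
    return llama_vec, prev_mask, clip_l_pooler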
modules/prompt_handler.py ADDED
@@ -0,0 +1,164 @@
1
+ import re
2
+ from dataclasses import dataclass
3
+ from typing import List, Optional
4
+
5
+
6
+ @dataclass
7
+ class PromptSection:
8
+ """Represents a section of the prompt with specific timing information"""
9
+ prompt: str
10
+ start_time: float = 0 # in seconds
11
+ end_time: Optional[float] = None # in seconds, None means until the end
12
+
13
+
14
+ def snap_to_section_boundaries(prompt_sections: List[PromptSection], latent_window_size: int, fps: int = 30) -> List[PromptSection]:
15
+ """
16
+ Adjust timestamps to align with model's internal section boundaries
17
+
18
+ Args:
19
+ prompt_sections: List of PromptSection objects
20
+ latent_window_size: Size of the latent window used in the model
21
+ fps: Frames per second (default: 30)
22
+
23
+ Returns:
24
+ List of PromptSection objects with aligned timestamps
25
+ """
26
+ section_duration = (latent_window_size * 4 - 3) / fps # Duration of one section in seconds
27
+
28
+ aligned_sections = []
29
+ for section in prompt_sections:
30
+ # Snap start time to nearest section boundary
31
+ aligned_start = round(section.start_time / section_duration) * section_duration
32
+
33
+ # Snap end time to nearest section boundary
34
+ aligned_end = None
35
+ if section.end_time is not None:
36
+ aligned_end = round(section.end_time / section_duration) * section_duration
37
+
38
+ # Ensure minimum section length
39
+ if aligned_end is not None and aligned_end <= aligned_start:
40
+ aligned_end = aligned_start + section_duration
41
+
42
+ aligned_sections.append(PromptSection(
43
+ prompt=section.prompt,
44
+ start_time=aligned_start,
45
+ end_time=aligned_end
46
+ ))
47
+
48
+ return aligned_sections
49
+
50
+
51
+ def parse_timestamped_prompt(prompt_text: str, total_duration: float, latent_window_size: int = 9, generation_type: str = "Original") -> List[PromptSection]:
52
+ """
53
+ Parse a prompt with timestamps in the format [0s-2s: text] or [3s: text]
54
+
55
+ Args:
56
+ prompt_text: The input prompt text with optional timestamp sections
57
+ total_duration: Total duration of the video in seconds
58
+ latent_window_size: Size of the latent window used in the model
59
+ generation_type: Type of generation ("Original" or "F1")
60
+
61
+ Returns:
62
+ List of PromptSection objects with timestamps aligned to section boundaries
63
+ and reversed to account for reverse generation (only for Original type)
64
+ """
65
+ # Default prompt for the entire duration if no timestamps are found
66
+ if "[" not in prompt_text or "]" not in prompt_text:
67
+ return [PromptSection(prompt=prompt_text.strip())]
68
+
69
+ sections = []
70
+ # Find all timestamp sections [time: text]
71
+ timestamp_pattern = r'\[(\d+(?:\.\d+)?s)(?:-(\d+(?:\.\d+)?s))?\s*:\s*(.*?)\]'
72
+ regular_text = prompt_text
73
+
74
+ for match in re.finditer(timestamp_pattern, prompt_text):
75
+ start_time_str = match.group(1)
76
+ end_time_str = match.group(2)
77
+ section_text = match.group(3).strip()
78
+
79
+ # Convert time strings to seconds
80
+ start_time = float(start_time_str.rstrip('s'))
81
+ end_time = float(end_time_str.rstrip('s')) if end_time_str else None
82
+
83
+ sections.append(PromptSection(
84
+ prompt=section_text,
85
+ start_time=start_time,
86
+ end_time=end_time
87
+ ))
88
+
89
+ # Remove the processed section from regular_text
90
+ regular_text = regular_text.replace(match.group(0), "")
91
+
92
+ # If there's any text outside of timestamp sections, use it as a default for the entire duration
93
+ regular_text = regular_text.strip()
94
+ if regular_text:
95
+ sections.append(PromptSection(
96
+ prompt=regular_text,
97
+ start_time=0,
98
+ end_time=None
99
+ ))
100
+
101
+ # Sort sections by start time
102
+ sections.sort(key=lambda x: x.start_time)
103
+
104
+ # Fill in end times if not specified
105
+ for i in range(len(sections) - 1):
106
+ if sections[i].end_time is None:
107
+ sections[i].end_time = sections[i+1].start_time
108
+
109
+ # Set the last section's end time to the total duration if not specified
110
+ if sections and sections[-1].end_time is None:
111
+ sections[-1].end_time = total_duration
112
+
113
+ # Snap timestamps to section boundaries
114
+ sections = snap_to_section_boundaries(sections, latent_window_size)
115
+
116
+ # Only reverse timestamps for Original generation type
117
+ if generation_type in ("Original", "Original with Endframe", "Video"):
118
+ # Now reverse the timestamps to account for reverse generation
119
+ reversed_sections = []
120
+ for section in sections:
121
+ reversed_start = total_duration - section.end_time if section.end_time is not None else 0
122
+ reversed_end = total_duration - section.start_time
123
+ reversed_sections.append(PromptSection(
124
+ prompt=section.prompt,
125
+ start_time=reversed_start,
126
+ end_time=reversed_end
127
+ ))
128
+
129
+ # Sort the reversed sections by start time
130
+ reversed_sections.sort(key=lambda x: x.start_time)
131
+ return reversed_sections
132
+
133
+ return sections
134
+
135
+
136
+ def get_section_boundaries(latent_window_size: int = 9, count: int = 10) -> str:
137
+ """
138
+ Calculate and format section boundaries for UI display
139
+
140
+ Args:
141
+ latent_window_size: Size of the latent window used in the model
142
+ count: Number of boundaries to display
143
+
144
+ Returns:
145
+ Formatted string of section boundaries
146
+ """
147
+ section_duration = (latent_window_size * 4 - 3) / 30
148
+ return ", ".join([f"{i*section_duration:.1f}s" for i in range(count)])
149
+
150
+
151
+ def get_quick_prompts() -> List[List[str]]:
152
+ """
153
+ Get a list of example timestamped prompts
154
+
155
+ Returns:
156
+ List of example prompts formatted for Gradio Dataset
157
+ """
158
+ prompts = [
159
+ '[0s: The person waves hello] [2s: The person jumps up and down] [4s: The person does a spin]',
160
+ '[0s: The person raises both arms slowly] [2s: The person claps hands enthusiastically]',
161
+ '[0s: Person gives thumbs up] [1.1s: Person smiles and winks] [2.2s: Person shows two thumbs down]',
162
+ '[0s: Person looks surprised] [1.1s: Person raises arms above head] [2.2s-3.3s: Person puts hands on hips]'
163
+ ]
164
+ return [[x] for x in prompts]
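A short usage sketch for the parser above, assuming the default latent window size of 9, which gives a section length of (9 * 4 - 3) / 30 = 1.1 s; F1 is used here so the timeline is not reversed:

sections = parse_timestamped_prompt(
    "[0s: The person waves hello] [2s: The person jumps up and down]",
    total_duration=6.0,
    latent_window_size=9,
    generation_type="F1",
)
for s in sections:
    print(f"{s.start_time:.1f}s -> {s.end_time:.1f}s : {s.prompt}")
# Prints sections snapped to 1.1 s boundaries: 0.0s -> 2.2s and 2.2s -> 5.5s.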
modules/settings.py ADDED
@@ -0,0 +1,88 @@
1
+ import json
2
+ from pathlib import Path
3
+ from typing import Dict, Any, Optional
4
+ import os
5
+
6
+ class Settings:
7
+ def __init__(self):
8
+ # Get the project root directory (where settings.py is located)
9
+ project_root = Path(__file__).parent.parent
10
+
11
+ self.settings_file = project_root / ".framepack" / "settings.json"
12
+ self.settings_file.parent.mkdir(parents=True, exist_ok=True)
13
+
14
+ # Set default paths relative to project root
15
+ self.default_settings = {
16
+ "save_metadata": True,
17
+ "gpu_memory_preservation": 6,
18
+ "output_dir": str(project_root / "outputs"),
19
+ "metadata_dir": str(project_root / "outputs"),
20
+ "lora_dir": str(project_root / "loras"),
21
+ "gradio_temp_dir": str(project_root / "temp"),
22
+ "input_files_dir": str(project_root / "input_files"), # New setting for input files
23
+ "auto_save_settings": True,
24
+ "gradio_theme": "default",
25
+ "mp4_crf": 16,
26
+ "clean_up_videos": True,
27
+ "override_system_prompt": False,
28
+ "auto_cleanup_on_startup": False, # ADDED: New setting for startup cleanup
29
+ "latents_display_top": False, # NEW: Control latents preview position (False = right column, True = top of interface)
30
+ "system_prompt_template": "{\"template\": \"<|start_header_id|>system<|end_header_id|>\\n\\nDescribe the video by detailing the following aspects: 1. The main content and theme of the video.2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects.3. Actions, events, behaviors temporal relationships, physical movement changes of the objects.4. background environment, light, style and atmosphere.5. camera angles, movements, and transitions used in the video:<|eot_id|><|start_header_id|>user<|end_header_id|>\\n\\n{}<|eot_id|>\", \"crop_start\": 95}",
31
+ "startup_model_type": "None",
32
+ "startup_preset_name": None,
33
+ "enhancer_prompt_template": """You are a creative assistant for a text-to-video generator. Your task is to take a user's prompt and make it more descriptive, vivid, and detailed. Focus on visual elements. Do not change the core action, but embellish it.
34
+
35
+ User prompt: "{text_to_enhance}"
36
+
37
+ Enhanced prompt:"""
38
+ }
39
+ self.settings = self.load_settings()
40
+
41
+ def load_settings(self) -> Dict[str, Any]:
42
+ """Load settings from file or return defaults"""
43
+ if self.settings_file.exists():
44
+ try:
45
+ with open(self.settings_file, 'r') as f:
46
+ loaded_settings = json.load(f)
47
+ # Merge with defaults to ensure all settings exist
48
+ settings = self.default_settings.copy()
49
+ settings.update(loaded_settings)
50
+ return settings
51
+ except Exception as e:
52
+ print(f"Error loading settings: {e}")
53
+ return self.default_settings.copy()
54
+ return self.default_settings.copy()
55
+
56
+ def save_settings(self, **kwargs):
57
+ """Save settings to file. Accepts keyword arguments for any settings to update."""
58
+ # Update self.settings with any provided keyword arguments
59
+ self.settings.update(kwargs)
60
+ # Ensure all default fields are present
61
+ for k, v in self.default_settings.items():
62
+ self.settings.setdefault(k, v)
63
+
64
+ # Ensure directories exist for relevant fields
65
+ for dir_key in ["output_dir", "metadata_dir", "lora_dir", "gradio_temp_dir"]:
66
+ dir_path = self.settings.get(dir_key)
67
+ if dir_path:
68
+ os.makedirs(dir_path, exist_ok=True)
69
+
70
+ # Save to file
71
+ with open(self.settings_file, 'w') as f:
72
+ json.dump(self.settings, f, indent=4)
73
+
74
+ def get(self, key: str, default: Any = None) -> Any:
75
+ """Get a setting value"""
76
+ return self.settings.get(key, default)
77
+
78
+ def set(self, key: str, value: Any) -> None:
79
+ """Set a setting value"""
80
+ self.settings[key] = value
81
+ if self.settings.get("auto_save_settings", True):
82
+ self.save_settings()
83
+
84
+ def update(self, settings: Dict[str, Any]) -> None:
85
+ """Update multiple settings at once"""
86
+ self.settings.update(settings)
87
+ if self.settings.get("auto_save_settings", True):
88
+ self.save_settings()
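A minimal usage sketch for the Settings class above; the keys shown all exist in default_settings, and writes go to .framepack/settings.json under the project root:

settings = Settings()                  # loads saved settings or falls back to the defaults
print(settings.get("output_dir"))      # <project_root>/outputs by default
settings.set("mp4_crf", 18)            # persisted immediately while auto_save_settings is True
settings.update({"clean_up_videos": False, "gpu_memory_preservation": 8})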
modules/toolbox/RIFE/IFNet_HDv3.py ADDED
@@ -0,0 +1,136 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from .warplayer import warp
5
+
6
+
7
+ def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
8
+ return nn.Sequential(
9
+ nn.Conv2d(
10
+ in_planes,
11
+ out_planes,
12
+ kernel_size=kernel_size,
13
+ stride=stride,
14
+ padding=padding,
15
+ dilation=dilation,
16
+ bias=True,
17
+ ),
18
+ nn.PReLU(out_planes),
19
+ )
20
+
21
+
22
+ def conv_bn(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
23
+ return nn.Sequential(
24
+ nn.Conv2d(
25
+ in_planes,
26
+ out_planes,
27
+ kernel_size=kernel_size,
28
+ stride=stride,
29
+ padding=padding,
30
+ dilation=dilation,
31
+ bias=False,
32
+ ),
33
+ nn.BatchNorm2d(out_planes),
34
+ nn.PReLU(out_planes),
35
+ )
36
+
37
+
38
+ class IFBlock(nn.Module):
39
+ def __init__(self, in_planes, c=64):
40
+ super(IFBlock, self).__init__()
41
+ self.conv0 = nn.Sequential(
42
+ conv(in_planes, c // 2, 3, 2, 1),
43
+ conv(c // 2, c, 3, 2, 1),
44
+ )
45
+ self.convblock0 = nn.Sequential(conv(c, c), conv(c, c))
46
+ self.convblock1 = nn.Sequential(conv(c, c), conv(c, c))
47
+ self.convblock2 = nn.Sequential(conv(c, c), conv(c, c))
48
+ self.convblock3 = nn.Sequential(conv(c, c), conv(c, c))
49
+ self.conv1 = nn.Sequential(
50
+ nn.ConvTranspose2d(c, c // 2, 4, 2, 1),
51
+ nn.PReLU(c // 2),
52
+ nn.ConvTranspose2d(c // 2, 4, 4, 2, 1),
53
+ )
54
+ self.conv2 = nn.Sequential(
55
+ nn.ConvTranspose2d(c, c // 2, 4, 2, 1),
56
+ nn.PReLU(c // 2),
57
+ nn.ConvTranspose2d(c // 2, 1, 4, 2, 1),
58
+ )
59
+
60
+ def forward(self, x, flow, scale=1):
61
+ x = F.interpolate(
62
+ x, scale_factor=1.0 / scale, mode="bilinear", align_corners=False, recompute_scale_factor=False
63
+ )
64
+ flow = (
65
+ F.interpolate(
66
+ flow, scale_factor=1.0 / scale, mode="bilinear", align_corners=False, recompute_scale_factor=False
67
+ )
68
+ * 1.0
69
+ / scale
70
+ )
71
+ feat = self.conv0(torch.cat((x, flow), 1))
72
+ feat = self.convblock0(feat) + feat
73
+ feat = self.convblock1(feat) + feat
74
+ feat = self.convblock2(feat) + feat
75
+ feat = self.convblock3(feat) + feat
76
+ flow = self.conv1(feat)
77
+ mask = self.conv2(feat)
78
+ flow = (
79
+ F.interpolate(flow, scale_factor=scale, mode="bilinear", align_corners=False, recompute_scale_factor=False)
80
+ * scale
81
+ )
82
+ mask = F.interpolate(
83
+ mask, scale_factor=scale, mode="bilinear", align_corners=False, recompute_scale_factor=False
84
+ )
85
+ return flow, mask
86
+
87
+
88
+ class IFNet(nn.Module):
89
+ def __init__(self):
90
+ super(IFNet, self).__init__()
91
+ self.block0 = IFBlock(7 + 4, c=90)
92
+ self.block1 = IFBlock(7 + 4, c=90)
93
+ self.block2 = IFBlock(7 + 4, c=90)
94
+ self.block_tea = IFBlock(10 + 4, c=90)
95
+ # self.contextnet = Contextnet()
96
+ # self.unet = Unet()
97
+
98
+ def forward(self, x, scale_list=[4, 2, 1], training=False):
99
+ if training == False:
100
+ channel = x.shape[1] // 2
101
+ img0 = x[:, :channel]
102
+ img1 = x[:, channel:]
103
+ flow_list = []
104
+ merged = []
105
+ mask_list = []
106
+ warped_img0 = img0
107
+ warped_img1 = img1
108
+ flow = (x[:, :4]).detach() * 0
109
+ mask = (x[:, :1]).detach() * 0
110
+ loss_cons = 0
111
+ block = [self.block0, self.block1, self.block2]
112
+ for i in range(3):
113
+ f0, m0 = block[i](torch.cat((warped_img0[:, :3], warped_img1[:, :3], mask), 1), flow, scale=scale_list[i])
114
+ f1, m1 = block[i](
115
+ torch.cat((warped_img1[:, :3], warped_img0[:, :3], -mask), 1),
116
+ torch.cat((flow[:, 2:4], flow[:, :2]), 1),
117
+ scale=scale_list[i],
118
+ )
119
+ flow = flow + (f0 + torch.cat((f1[:, 2:4], f1[:, :2]), 1)) / 2
120
+ mask = mask + (m0 + (-m1)) / 2
121
+ mask_list.append(mask)
122
+ flow_list.append(flow)
123
+ warped_img0 = warp(img0, flow[:, :2])
124
+ warped_img1 = warp(img1, flow[:, 2:4])
125
+ merged.append((warped_img0, warped_img1))
126
+ """
127
+ c0 = self.contextnet(img0, flow[:, :2])
128
+ c1 = self.contextnet(img1, flow[:, 2:4])
129
+ tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1)
130
+ res = tmp[:, 1:4] * 2 - 1
131
+ """
132
+ for i in range(3):
133
+ mask_list[i] = torch.sigmoid(mask_list[i])
134
+ merged[i] = merged[i][0] * mask_list[i] + merged[i][1] * (1 - mask_list[i])
135
+ # merged[i] = torch.clamp(merged[i] + res, 0, 1)
136
+ return flow_list, mask_list[2], merged
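A minimal sketch of driving IFNet directly, with randomly initialised weights purely to illustrate the call signature and shapes; it assumes the two frames already live on the device picked by the bundled devicetorch helper (the warp layer pins its sampling grid there) and that height and width are padded to multiples of 32 so the coarse-to-fine scales divide evenly:

import torch
import devicetorch  # same helper RIFE_HDv3.py uses to pick cuda / mps / cpu
from modules.toolbox.RIFE.IFNet_HDv3 import IFNet

device = devicetorch.get(torch)
net = IFNet().eval().to(device)
img0 = torch.rand(1, 3, 256, 448, device=device)  # frame at t = 0, values in [0, 1]
img1 = torch.rand(1, 3, 256, 448, device=device)  # frame at t = 1
with torch.no_grad():
    flow_list, mask, merged = net(torch.cat((img0, img1), dim=1), scale_list=[4, 2, 1])
middle = merged[2]  # blended middle frame at t = 0.5, shape (1, 3, 256, 448)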
modules/toolbox/RIFE/RIFE_HDv3.py ADDED
@@ -0,0 +1,98 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torch.optim as optim
5
+ from torch.optim import AdamW
6
+ import numpy as np
7
+ import itertools
8
+ from .warplayer import warp
9
+ from torch.nn.parallel import DistributedDataParallel as DDP
10
+ from .IFNet_HDv3 import *
11
+ from .loss import *
12
+ import devicetorch
13
+ device = devicetorch.get(torch)
14
+
15
+
16
+ class Model:
17
+ def __init__(self, local_rank=-1):
18
+ self.flownet = IFNet()
19
+ self.device()
20
+ self.optimG = AdamW(self.flownet.parameters(), lr=1e-6, weight_decay=1e-4)
21
+ self.epe = EPE()
22
+ # self.vgg = VGGPerceptualLoss().to(device)
23
+ self.sobel = SOBEL()
24
+ if local_rank != -1:
25
+ self.flownet = DDP(self.flownet, device_ids=[local_rank], output_device=local_rank)
26
+
27
+ def train(self):
28
+ self.flownet.train()
29
+
30
+ def eval(self):
31
+ self.flownet.eval()
32
+
33
+ def device(self):
34
+ self.flownet.to(device)
35
+
36
+ def load_model(self, path, rank=0):
37
+ def convert(param):
38
+ if rank == -1:
39
+ return {k.replace("module.", ""): v for k, v in param.items() if "module." in k}
40
+ else:
41
+ return param
42
+
43
+ if rank <= 0:
44
+ model_path = "{}/flownet.pkl".format(path)
45
+ # Check PyTorch version to safely use weights_only
46
+ from packaging import version
47
+ use_weights_only = version.parse(torch.__version__) >= version.parse("1.13")
48
+
49
+ load_kwargs = {}
50
+ if not torch.cuda.is_available():
51
+ load_kwargs['map_location'] = "cpu"
52
+
53
+ if use_weights_only:
54
+ # For modern PyTorch, be explicit and safe
55
+ load_kwargs['weights_only'] = True
56
+ # print(f"PyTorch >= 1.13 detected. Loading RIFE model with weights_only=True.")
57
+ state_dict = torch.load(model_path, **load_kwargs)
58
+ else:
59
+ # For older PyTorch, load the old way
60
+ print(f"PyTorch < 1.13 detected. Loading RIFE model using legacy method.")
61
+ state_dict = torch.load(model_path, **load_kwargs)
62
+
63
+ self.flownet.load_state_dict(convert(state_dict))
64
+
65
+ def inference(self, img0, img1, scale=1.0):
66
+ imgs = torch.cat((img0, img1), 1)
67
+ scale_list = [4 / scale, 2 / scale, 1 / scale]
68
+ flow, mask, merged = self.flownet(imgs, scale_list)
69
+ return merged[2]
70
+
71
+ def update(self, imgs, gt, learning_rate=0, mul=1, training=True, flow_gt=None):
72
+ for param_group in self.optimG.param_groups:
73
+ param_group["lr"] = learning_rate
74
+ img0 = imgs[:, :3]
75
+ img1 = imgs[:, 3:]
76
+ if training:
77
+ self.train()
78
+ else:
79
+ self.eval()
80
+ scale = [4, 2, 1]
81
+ flow, mask, merged = self.flownet(torch.cat((imgs, gt), 1), scale=scale, training=training)
82
+ loss_l1 = (merged[2] - gt).abs().mean()
83
+ loss_smooth = self.sobel(flow[2], flow[2] * 0).mean()
84
+ # loss_vgg = self.vgg(merged[2], gt)
85
+ if training:
86
+ self.optimG.zero_grad()
87
+ loss_G = loss_cons + loss_smooth * 0.1
88
+ loss_G.backward()
89
+ self.optimG.step()
90
+ else:
91
+ flow_teacher = flow[2]
92
+ return merged[2], {
93
+ "mask": mask,
94
+ "flow": flow[2][:, :2],
95
+ "loss_l1": loss_l1,
96
+ "loss_cons": loss_cons,
97
+ "loss_smooth": loss_smooth,
98
+ }
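And a usage sketch for the Model wrapper itself, closer to how the toolbox presumably invokes RIFE: load flownet.pkl from a weights directory (the path below is hypothetical), then synthesise the frame halfway between two inputs:

import torch
import devicetorch
from modules.toolbox.RIFE.RIFE_HDv3 import Model

device = devicetorch.get(torch)
img0 = torch.rand(1, 3, 256, 448, device=device)  # placeholder frames in [0, 1];
img1 = torch.rand(1, 3, 256, 448, device=device)  # real use feeds consecutive video frames
model = Model()
model.load_model("modules/toolbox/model_rife", rank=-1)  # hypothetical dir containing flownet.pkl
model.eval()
with torch.no_grad():
    middle = model.inference(img0, img1, scale=1.0)       # interpolated frame at t = 0.5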
modules/toolbox/RIFE/__int__.py ADDED
File without changes