diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000000000000000000000000000000000000..fbfa58a349f300afb0da142df7d0e887d2d3a825 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,6 @@ +hf_download/ +.framepack/ +loras/ +outputs/ +modules/toolbox/model_esrgan/ +modules/toolbox/model_rife/ \ No newline at end of file diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..7cd5d57f289941481e1bfc7f6e58a36c9a565355 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,28 @@ +ARG CUDA_VERSION=12.4.1 + +FROM nvidia/cuda:${CUDA_VERSION}-runtime-ubuntu22.04 + +ARG CUDA_VERSION + +RUN apt-get update && apt-get install -y \ + python3 python3-pip git ffmpeg wget curl && \ + pip3 install --upgrade pip + +WORKDIR /app + +# This allows caching pip install if only code has changed +COPY requirements.txt . + +# Install dependencies +RUN pip3 install --no-cache-dir -r requirements.txt +RUN export CUDA_SHORT_VERSION=$(echo "${CUDA_VERSION}" | sed 's/\.//g' | cut -c 1-3) && \ + pip3 install --force-reinstall --no-cache-dir torch torchvision torchaudio --index-url "https://download.pytorch.org/whl/cu${CUDA_SHORT_VERSION}" + +# Copy the source code to /app +COPY . . + +VOLUME [ "/app/.framepack", "/app/outputs", "/app/loras", "/app/hf_download", "/app/modules/toolbox/model_esrgan", "/app/modules/toolbox/model_rife" ] + +EXPOSE 7860 + +CMD ["python3", "studio.py"] diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..261eeb9e9f8b2b4b0d119366dda99c6fd7d35c64 --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. 
For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/ORIGINAL_README.md b/ORIGINAL_README.md new file mode 100644 index 0000000000000000000000000000000000000000..1bdaf0fe1c3cbc0f5ee345400f7046e734a7717f --- /dev/null +++ b/ORIGINAL_README.md @@ -0,0 +1,83 @@ +

+# FramePack Studio
+ +[![Discord](https://img.shields.io/badge/Discord-%235865F2.svg?style=for-the-badge&logo=discord&logoColor=white)](https://discord.gg/MtuM7gFJ3V)[![Patreon](https://img.shields.io/badge/Patreon-F96854?style=for-the-badge&logo=patreon&logoColor=white)](https://www.patreon.com/ColinU) + +[![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/colinurbs/FramePack-Studio) + +FramePack Studio is an AI video generation application based on FramePack that strives to provide everything you need to create high quality video projects. + +![screencapture-127-0-0-1-7860-2025-06-12-19_50_37](https://github.com/user-attachments/assets/b86a8422-f4ce-452b-80eb-2ba91945f2ea) +![screencapture-127-0-0-1-7860-2025-06-12-19_52_33](https://github.com/user-attachments/assets/ebfb31ca-85b7-4354-87c6-aaab6d1c77b1) + +## Current Features + +- **F1, Original and Video Extension Generations**: Run all in a single queue +- **End Frame Control for 'Original' Model**: Provides greater control over generations +- **Upscaling and Post-processing** +- **Timestamped Prompts**: Define different prompts for specific time segments in your video +- **Prompt Blending**: Define the blending time between timestamped prompts +- **LoRA Support**: Works with most (all?) Hunyuan Video LoRAs +- **Queue System**: Process multiple generation jobs without blocking the interface. Import and export queues. +- **Metadata Saving/Import**: Prompt and seed are encoded into the output PNG, all other generation metadata is saved in a JSON file that can be imported later for similar generations. +- **Custom Presets**: Allow quick switching between named groups of parameters. A custom Startup Preset can also be set. +- **I2V and T2V**: Works with or without an input image to allow for more flexibility when working with standard Hunyuan Video LoRAs +- **Latent Image Options**: When using T2V you can generate based on a black, white, green screen, or pure noise image + +## Prerequisites + +- CUDA-compatible GPU with at least 8GB VRAM (16GB+ recommended) +- 16GB System Memory (32GB+ strongly recommended) +- 80GB+ of storage (including ~25GB for each model family: Original and F1) + +## Documentation + +For information on installation, configuration, and usage, please visit our [documentation site](https://docs.framepackstudio.com/). + +## Installation + +Please see [this guide](https://docs.framepackstudio.com/docs/get_started/) on our documentation site to get FP-Studio installed. + +## LoRAs + +Add LoRAs to the /loras/ folder at the root of the installation. Select the LoRAs you wish to load and set the weights for each generation. Most Hunyuan LoRAs were originally trained for T2V, it's often helpful to run a T2V generation to ensure they're working before using input images. + +NOTE: Slow lora loading is a known issue + +## Working with Timestamped Prompts + +You can create videos with changing prompts over time using the following syntax: + +``` +[0s: A serene forest with sunlight filtering through the trees ] +[5s: A deer appears in the clearing ] +[10s: The deer drinks from a small stream ] +``` + +Each timestamp defines when that prompt should start influencing the generation. The system will (hopefully) smoothly transition between prompts for a cohesive video. + +## Credits + +Many thanks to [Lvmin Zhang](https://github.com/lllyasviel) for the absolutely amazing work on the original [FramePack](https://github.com/lllyasviel/FramePack) code! 
+ +Thanks to [Rickard Edén](https://github.com/neph1) for the LoRA code and their general contributions to this growing FramePack scene! + +Thanks to [Zehong Ma](https://github.com/Zehong-Ma) for [MagCache](https://github.com/Zehong-Ma/MagCache): Fast Video Generation with Magnitude-Aware Cache! + +Thanks to everyone who has joined the Discord, reported a bug, sumbitted a PR, or helped with testing! + + @article{zhang2025framepack, + title={Packing Input Frame Contexts in Next-Frame Prediction Models for Video Generation}, + author={Lvmin Zhang and Maneesh Agrawala}, + journal={Arxiv}, + year={2025} + } + + @misc{zhang2025packinginputframecontext, + title={Packing Input Frame Context in Next-Frame Prediction Models for Video Generation}, + author={Lvmin Zhang and Maneesh Agrawala}, + year={2025}, + eprint={2504.12626}, + archivePrefix={arXiv}, + primaryClass={cs.CV}, + url={https://arxiv.org/abs/2504.12626} + } diff --git a/diffusers_helper/bucket_tools.py b/diffusers_helper/bucket_tools.py new file mode 100644 index 0000000000000000000000000000000000000000..b3d2596c0dba56b3bf0cf9c80c419aa23a71e0be --- /dev/null +++ b/diffusers_helper/bucket_tools.py @@ -0,0 +1,97 @@ +bucket_options = { + 128: [ + (96, 160), + (112, 144), + (128, 128), + (144, 112), + (160, 96), + ], + 256: [ + (192, 320), + (224, 288), + (256, 256), + (288, 224), + (320, 192), + ], + 384: [ + (256, 512), + (320, 448), + (384, 384), + (448, 320), + (512, 256), + ], + 512: [ + (352, 704), + (384, 640), + (448, 576), + (512, 512), + (576, 448), + (640, 384), + (704, 352), + ], + 640: [ + (416, 960), + (448, 864), + (480, 832), + (512, 768), + (544, 704), + (576, 672), + (608, 640), + (640, 640), + (640, 608), + (672, 576), + (704, 544), + (768, 512), + (832, 480), + (864, 448), + (960, 416), + ], + 768: [ + (512, 1024), + (576, 896), + (640, 832), + (704, 768), + (768, 768), + (768, 704), + (832, 640), + (896, 576), + (1024, 512), + ], +} + + +def find_nearest_bucket(h, w, resolution=640): + # Use the provided resolution or find the closest available bucket size + # print(f"find_nearest_bucket called with h={h}, w={w}, resolution={resolution}") + + # Convert resolution to int if it's not already + resolution = int(resolution) if not isinstance(resolution, int) else resolution + + if resolution not in bucket_options: + # Find the closest available resolution + available_resolutions = list(bucket_options.keys()) + closest_resolution = min(available_resolutions, key=lambda x: abs(x - resolution)) + # print(f"Resolution {resolution} not found in bucket options, using closest available: {closest_resolution}") + resolution = closest_resolution + # else: + # print(f"Resolution {resolution} found in bucket options") + + # Calculate the aspect ratio of the input image + input_aspect_ratio = w / h if h > 0 else 1.0 + # print(f"Input aspect ratio: {input_aspect_ratio:.4f}") + + min_diff = float('inf') + best_bucket = None + + # Find the bucket size with the closest aspect ratio to the input image + for (bucket_h, bucket_w) in bucket_options[resolution]: + bucket_aspect_ratio = bucket_w / bucket_h if bucket_h > 0 else 1.0 + # Calculate the difference in aspect ratios + diff = abs(bucket_aspect_ratio - input_aspect_ratio) + if diff < min_diff: + min_diff = diff + best_bucket = (bucket_h, bucket_w) + # print(f" Checking bucket ({bucket_h}, {bucket_w}), aspect ratio={bucket_aspect_ratio:.4f}, diff={diff:.4f}, current best={best_bucket}") + + # print(f"Using resolution {resolution}, selected bucket: {best_bucket}") + return best_bucket diff 
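As a quick reference for the bucket helper introduced in `diffusers_helper/bucket_tools.py` above, here is a minimal usage sketch; the 1280x720 input size is only an illustrative assumption:

```python
from diffusers_helper.bucket_tools import find_nearest_bucket

# Hypothetical 16:9 input at the default 640 bucket resolution.
# find_nearest_bucket snaps the requested resolution to the closest key in
# bucket_options, then returns the (height, width) pair whose aspect ratio
# best matches the input image.
bucket_h, bucket_w = find_nearest_bucket(h=720, w=1280, resolution=640)
print(bucket_h, bucket_w)  # -> 480 832 for this aspect ratio
```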
--git a/diffusers_helper/clip_vision.py b/diffusers_helper/clip_vision.py new file mode 100644 index 0000000000000000000000000000000000000000..aaf40dbf1b4ef975640e0ad0d5a7792652d79334 --- /dev/null +++ b/diffusers_helper/clip_vision.py @@ -0,0 +1,12 @@ +import numpy as np + + +def hf_clip_vision_encode(image, feature_extractor, image_encoder): + assert isinstance(image, np.ndarray) + assert image.ndim == 3 and image.shape[2] == 3 + assert image.dtype == np.uint8 + + preprocessed = feature_extractor.preprocess(images=image, return_tensors="pt").to(device=image_encoder.device, dtype=image_encoder.dtype) + image_encoder_output = image_encoder(**preprocessed) + + return image_encoder_output diff --git a/diffusers_helper/dit_common.py b/diffusers_helper/dit_common.py new file mode 100644 index 0000000000000000000000000000000000000000..f02e7b012bff0b3b0fce9136d29fee4a1d49e45e --- /dev/null +++ b/diffusers_helper/dit_common.py @@ -0,0 +1,53 @@ +import torch +import accelerate.accelerator + +from diffusers.models.normalization import RMSNorm, LayerNorm, FP32LayerNorm, AdaLayerNormContinuous + + +accelerate.accelerator.convert_outputs_to_fp32 = lambda x: x + + +def LayerNorm_forward(self, x): + return torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps).to(x) + + +LayerNorm.forward = LayerNorm_forward +torch.nn.LayerNorm.forward = LayerNorm_forward + + +def FP32LayerNorm_forward(self, x): + origin_dtype = x.dtype + return torch.nn.functional.layer_norm( + x.float(), + self.normalized_shape, + self.weight.float() if self.weight is not None else None, + self.bias.float() if self.bias is not None else None, + self.eps, + ).to(origin_dtype) + + +FP32LayerNorm.forward = FP32LayerNorm_forward + + +def RMSNorm_forward(self, hidden_states): + input_dtype = hidden_states.dtype + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.eps) + + if self.weight is None: + return hidden_states.to(input_dtype) + + return hidden_states.to(input_dtype) * self.weight.to(input_dtype) + + +RMSNorm.forward = RMSNorm_forward + + +def AdaLayerNormContinuous_forward(self, x, conditioning_embedding): + emb = self.linear(self.silu(conditioning_embedding)) + scale, shift = emb.chunk(2, dim=1) + x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :] + return x + + +AdaLayerNormContinuous.forward = AdaLayerNormContinuous_forward diff --git a/diffusers_helper/gradio/progress_bar.py b/diffusers_helper/gradio/progress_bar.py new file mode 100644 index 0000000000000000000000000000000000000000..2cc612163a171cef37d67d991d729ba9fec066db --- /dev/null +++ b/diffusers_helper/gradio/progress_bar.py @@ -0,0 +1,86 @@ +progress_html = ''' +
+<div class="loader-container">
+  <div class="loader"></div>
+  <div class="progress-container">
+    <progress value="*number*" max="100"></progress>
+  </div>
+  <span>*text*</span>
+</div>
+''' + +css = ''' +.loader-container { + display: flex; /* Use flex to align items horizontally */ + align-items: center; /* Center items vertically within the container */ + white-space: nowrap; /* Prevent line breaks within the container */ +} + +.loader { + border: 8px solid #f3f3f3; /* Light grey */ + border-top: 8px solid #3498db; /* Blue */ + border-radius: 50%; + width: 30px; + height: 30px; + animation: spin 2s linear infinite; +} + +@keyframes spin { + 0% { transform: rotate(0deg); } + 100% { transform: rotate(360deg); } +} + +/* Style the progress bar */ +progress { + appearance: none; /* Remove default styling */ + height: 20px; /* Set the height of the progress bar */ + border-radius: 5px; /* Round the corners of the progress bar */ + background-color: #f3f3f3; /* Light grey background */ + width: 100%; + vertical-align: middle !important; +} + +/* Style the progress bar container */ +.progress-container { + margin-left: 20px; + margin-right: 20px; + flex-grow: 1; /* Allow the progress container to take up remaining space */ +} + +/* Set the color of the progress bar fill */ +progress::-webkit-progress-value { + background-color: #3498db; /* Blue color for the fill */ +} + +progress::-moz-progress-bar { + background-color: #3498db; /* Blue color for the fill in Firefox */ +} + +/* Style the text on the progress bar */ +progress::after { + content: attr(value '%'); /* Display the progress value followed by '%' */ + position: absolute; + top: 50%; + left: 50%; + transform: translate(-50%, -50%); + color: white; /* Set text color */ + font-size: 14px; /* Set font size */ +} + +/* Style other texts */ +.loader-container > span { + margin-left: 5px; /* Add spacing between the progress bar and the text */ +} + +.no-generating-animation > .generating { + display: none !important; +} + +''' + + +def make_progress_bar_html(number, text): + return progress_html.replace('*number*', str(number)).replace('*text*', text) + + +def make_progress_bar_css(): + return css diff --git a/diffusers_helper/hf_login.py b/diffusers_helper/hf_login.py new file mode 100644 index 0000000000000000000000000000000000000000..b039db24378b0419e69ee97042f88e96460766ef --- /dev/null +++ b/diffusers_helper/hf_login.py @@ -0,0 +1,21 @@ +import os + + +def login(token): + from huggingface_hub import login + import time + + while True: + try: + login(token) + print('HF login ok.') + break + except Exception as e: + print(f'HF login failed: {e}. 
Retrying') + time.sleep(0.5) + + +hf_token = os.environ.get('HF_TOKEN', None) + +if hf_token is not None: + login(hf_token) diff --git a/diffusers_helper/hunyuan.py b/diffusers_helper/hunyuan.py new file mode 100644 index 0000000000000000000000000000000000000000..3cb1ee90a9c4b96da5f9d615ed7b763f1873fd16 --- /dev/null +++ b/diffusers_helper/hunyuan.py @@ -0,0 +1,163 @@ +import torch + +from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video import DEFAULT_PROMPT_TEMPLATE +from diffusers_helper.utils import crop_or_pad_yield_mask + + +@torch.no_grad() +def encode_prompt_conds(prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2, max_length=256): + assert isinstance(prompt, str) + + prompt = [prompt] + + # LLAMA + + # Check if there's a custom system prompt template in settings + custom_template = None + try: + from modules.settings import Settings + settings = Settings() + override_system_prompt = settings.get("override_system_prompt", False) + custom_template_str = settings.get("system_prompt_template") + + if override_system_prompt and custom_template_str: + try: + # Convert the string representation to a dictionary + # Extract template and crop_start directly from the string using regex + import re + + # Try to extract the template value + template_match = re.search(r"['\"]template['\"]\s*:\s*['\"](.+?)['\"](?=\s*,|\s*})", custom_template_str, re.DOTALL) + crop_start_match = re.search(r"['\"]crop_start['\"]\s*:\s*(\d+)", custom_template_str) + + if template_match and crop_start_match: + template_value = template_match.group(1) + crop_start_value = int(crop_start_match.group(1)) + + # Unescape any escaped characters in the template + template_value = template_value.replace("\\n", "\n").replace("\\\"", "\"").replace("\\'", "'") + + custom_template = { + "template": template_value, + "crop_start": crop_start_value + } + print(f"Using custom system prompt template from settings: {custom_template}") + else: + print(f"Could not extract template or crop_start from system prompt template string") + print(f"Falling back to default template") + custom_template = None + except Exception as e: + print(f"Error parsing custom system prompt template: {e}") + print(f"Falling back to default template") + custom_template = None + else: + if not override_system_prompt: + print(f"Override system prompt is disabled, using default template") + elif not custom_template_str: + print(f"No custom system prompt template found in settings") + custom_template = None + except Exception as e: + print(f"Error loading settings: {e}") + print(f"Falling back to default template") + custom_template = None + + # Use custom template if available, otherwise use default + template = custom_template if custom_template else DEFAULT_PROMPT_TEMPLATE + + prompt_llama = [template["template"].format(p) for p in prompt] + crop_start = template["crop_start"] + + llama_inputs = tokenizer( + prompt_llama, + padding="max_length", + max_length=max_length + crop_start, + truncation=True, + return_tensors="pt", + return_length=False, + return_overflowing_tokens=False, + return_attention_mask=True, + ) + + llama_input_ids = llama_inputs.input_ids.to(text_encoder.device) + llama_attention_mask = llama_inputs.attention_mask.to(text_encoder.device) + llama_attention_length = int(llama_attention_mask.sum()) + + llama_outputs = text_encoder( + input_ids=llama_input_ids, + attention_mask=llama_attention_mask, + output_hidden_states=True, + ) + + llama_vec = llama_outputs.hidden_states[-3][:, crop_start:llama_attention_length] + # 
llama_vec_remaining = llama_outputs.hidden_states[-3][:, llama_attention_length:] + llama_attention_mask = llama_attention_mask[:, crop_start:llama_attention_length] + + assert torch.all(llama_attention_mask.bool()) + + # CLIP + + clip_l_input_ids = tokenizer_2( + prompt, + padding="max_length", + max_length=77, + truncation=True, + return_overflowing_tokens=False, + return_length=False, + return_tensors="pt", + ).input_ids + clip_l_pooler = text_encoder_2(clip_l_input_ids.to(text_encoder_2.device), output_hidden_states=False).pooler_output + + return llama_vec, clip_l_pooler + + +@torch.no_grad() +def vae_decode_fake(latents): + latent_rgb_factors = [ + [-0.0395, -0.0331, 0.0445], + [0.0696, 0.0795, 0.0518], + [0.0135, -0.0945, -0.0282], + [0.0108, -0.0250, -0.0765], + [-0.0209, 0.0032, 0.0224], + [-0.0804, -0.0254, -0.0639], + [-0.0991, 0.0271, -0.0669], + [-0.0646, -0.0422, -0.0400], + [-0.0696, -0.0595, -0.0894], + [-0.0799, -0.0208, -0.0375], + [0.1166, 0.1627, 0.0962], + [0.1165, 0.0432, 0.0407], + [-0.2315, -0.1920, -0.1355], + [-0.0270, 0.0401, -0.0821], + [-0.0616, -0.0997, -0.0727], + [0.0249, -0.0469, -0.1703] + ] # From comfyui + + latent_rgb_factors_bias = [0.0259, -0.0192, -0.0761] + + weight = torch.tensor(latent_rgb_factors, device=latents.device, dtype=latents.dtype).transpose(0, 1)[:, :, None, None, None] + bias = torch.tensor(latent_rgb_factors_bias, device=latents.device, dtype=latents.dtype) + + images = torch.nn.functional.conv3d(latents, weight, bias=bias, stride=1, padding=0, dilation=1, groups=1) + images = images.clamp(0.0, 1.0) + + return images + + +@torch.no_grad() +def vae_decode(latents, vae, image_mode=False): + latents = latents / vae.config.scaling_factor + + if not image_mode: + image = vae.decode(latents.to(device=vae.device, dtype=vae.dtype)).sample + else: + latents = latents.to(device=vae.device, dtype=vae.dtype).unbind(2) + image = [vae.decode(l.unsqueeze(2)).sample for l in latents] + image = torch.cat(image, dim=2) + + return image + + +@torch.no_grad() +def vae_encode(image, vae): + latents = vae.encode(image.to(device=vae.device, dtype=vae.dtype)).latent_dist.sample() + latents = latents * vae.config.scaling_factor + return latents diff --git a/diffusers_helper/k_diffusion/uni_pc_fm.py b/diffusers_helper/k_diffusion/uni_pc_fm.py new file mode 100644 index 0000000000000000000000000000000000000000..73387c2196dda2f294fa9dec75490ce1ced39c38 --- /dev/null +++ b/diffusers_helper/k_diffusion/uni_pc_fm.py @@ -0,0 +1,144 @@ +# Better Flow Matching UniPC by Lvmin Zhang +# (c) 2025 +# CC BY-SA 4.0 +# Attribution-ShareAlike 4.0 International Licence + + +import torch + +from tqdm.auto import trange + + +def expand_dims(v, dims): + return v[(...,) + (None,) * (dims - 1)] + + +class FlowMatchUniPC: + def __init__(self, model, extra_args, variant='bh1'): + self.model = model + self.variant = variant + self.extra_args = extra_args + + def model_fn(self, x, t): + return self.model(x, t, **self.extra_args) + + def update_fn(self, x, model_prev_list, t_prev_list, t, order): + assert order <= len(model_prev_list) + dims = x.dim() + + t_prev_0 = t_prev_list[-1] + lambda_prev_0 = - torch.log(t_prev_0) + lambda_t = - torch.log(t) + model_prev_0 = model_prev_list[-1] + + h = lambda_t - lambda_prev_0 + + rks = [] + D1s = [] + for i in range(1, order): + t_prev_i = t_prev_list[-(i + 1)] + model_prev_i = model_prev_list[-(i + 1)] + lambda_prev_i = - torch.log(t_prev_i) + rk = ((lambda_prev_i - lambda_prev_0) / h)[0] + rks.append(rk) + D1s.append((model_prev_i - 
model_prev_0) / rk) + + rks.append(1.) + rks = torch.tensor(rks, device=x.device) + + R = [] + b = [] + + hh = -h[0] + h_phi_1 = torch.expm1(hh) + h_phi_k = h_phi_1 / hh - 1 + + factorial_i = 1 + + if self.variant == 'bh1': + B_h = hh + elif self.variant == 'bh2': + B_h = torch.expm1(hh) + else: + raise NotImplementedError('Bad variant!') + + for i in range(1, order + 1): + R.append(torch.pow(rks, i - 1)) + b.append(h_phi_k * factorial_i / B_h) + factorial_i *= (i + 1) + h_phi_k = h_phi_k / hh - 1 / factorial_i + + R = torch.stack(R) + b = torch.tensor(b, device=x.device) + + use_predictor = len(D1s) > 0 + + if use_predictor: + D1s = torch.stack(D1s, dim=1) + if order == 2: + rhos_p = torch.tensor([0.5], device=b.device) + else: + rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1]) + else: + D1s = None + rhos_p = None + + if order == 1: + rhos_c = torch.tensor([0.5], device=b.device) + else: + rhos_c = torch.linalg.solve(R, b) + + x_t_ = expand_dims(t / t_prev_0, dims) * x - expand_dims(h_phi_1, dims) * model_prev_0 + + if use_predictor: + pred_res = torch.tensordot(D1s, rhos_p, dims=([1], [0])) + else: + pred_res = 0 + + x_t = x_t_ - expand_dims(B_h, dims) * pred_res + model_t = self.model_fn(x_t, t) + + if D1s is not None: + corr_res = torch.tensordot(D1s, rhos_c[:-1], dims=([1], [0])) + else: + corr_res = 0 + + D1_t = (model_t - model_prev_0) + x_t = x_t_ - expand_dims(B_h, dims) * (corr_res + rhos_c[-1] * D1_t) + + return x_t, model_t + + def sample(self, x, sigmas, callback=None, disable_pbar=False): + order = min(3, len(sigmas) - 2) + model_prev_list, t_prev_list = [], [] + for i in trange(len(sigmas) - 1, disable=disable_pbar): + vec_t = sigmas[i].expand(x.shape[0]) + + if i == 0: + model_prev_list = [self.model_fn(x, vec_t)] + t_prev_list = [vec_t] + elif i < order: + init_order = i + x, model_x = self.update_fn(x, model_prev_list, t_prev_list, vec_t, init_order) + model_prev_list.append(model_x) + t_prev_list.append(vec_t) + else: + x, model_x = self.update_fn(x, model_prev_list, t_prev_list, vec_t, order) + model_prev_list.append(model_x) + t_prev_list.append(vec_t) + + model_prev_list = model_prev_list[-order:] + t_prev_list = t_prev_list[-order:] + + if callback is not None: + callback_result = callback({'x': x, 'i': i, 'denoised': model_prev_list[-1]}) + if callback_result == 'cancel': + print("Cancellation signal received in sample_unipc, stopping generation") + return model_prev_list[-1] # Return current denoised result + + return model_prev_list[-1] + + +def sample_unipc(model, noise, sigmas, extra_args=None, callback=None, disable=False, variant='bh1'): + assert variant in ['bh1', 'bh2'] + return FlowMatchUniPC(model, extra_args=extra_args, variant=variant).sample(noise, sigmas=sigmas, callback=callback, disable_pbar=disable) diff --git a/diffusers_helper/k_diffusion/wrapper.py b/diffusers_helper/k_diffusion/wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..cc420da4db1134deca30648077923021b35f82d1 --- /dev/null +++ b/diffusers_helper/k_diffusion/wrapper.py @@ -0,0 +1,51 @@ +import torch + + +def append_dims(x, target_dims): + return x[(...,) + (None,) * (target_dims - x.ndim)] + + +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=1.0): + if guidance_rescale == 0: + return noise_cfg + + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + noise_cfg = 
guidance_rescale * noise_pred_rescaled + (1.0 - guidance_rescale) * noise_cfg + return noise_cfg + + +def fm_wrapper(transformer, t_scale=1000.0): + def k_model(x, sigma, **extra_args): + dtype = extra_args['dtype'] + cfg_scale = extra_args['cfg_scale'] + cfg_rescale = extra_args['cfg_rescale'] + concat_latent = extra_args['concat_latent'] + + original_dtype = x.dtype + sigma = sigma.float() + + x = x.to(dtype) + timestep = (sigma * t_scale).to(dtype) + + if concat_latent is None: + hidden_states = x + else: + hidden_states = torch.cat([x, concat_latent.to(x)], dim=1) + + pred_positive = transformer(hidden_states=hidden_states, timestep=timestep, return_dict=False, **extra_args['positive'])[0].float() + + if cfg_scale == 1.0: + pred_negative = torch.zeros_like(pred_positive) + else: + pred_negative = transformer(hidden_states=hidden_states, timestep=timestep, return_dict=False, **extra_args['negative'])[0].float() + + pred_cfg = pred_negative + cfg_scale * (pred_positive - pred_negative) + pred = rescale_noise_cfg(pred_cfg, pred_positive, guidance_rescale=cfg_rescale) + + x0 = x.float() - pred.float() * append_dims(sigma, x.ndim) + + return x0.to(dtype=original_dtype) + + return k_model diff --git a/diffusers_helper/lora_utils.py b/diffusers_helper/lora_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d726d53b5c4190876c78e09ac373e1c48a83a65d --- /dev/null +++ b/diffusers_helper/lora_utils.py @@ -0,0 +1,194 @@ +from pathlib import Path, PurePath +from typing import Dict, List, Optional, Union, Tuple +from diffusers.loaders.lora_pipeline import _fetch_state_dict +from diffusers.loaders.lora_conversion_utils import _convert_hunyuan_video_lora_to_diffusers +from diffusers.utils.peft_utils import set_weights_and_activate_adapters +from diffusers.loaders.peft import _SET_ADAPTER_SCALE_FN_MAPPING +import torch + +FALLBACK_CLASS_ALIASES = { + "HunyuanVideoTransformer3DModelPacked": "HunyuanVideoTransformer3DModel", +} + +def load_lora(transformer: torch.nn.Module, lora_path: Path, weight_name: str) -> Tuple[torch.nn.Module, str]: + """ + Load LoRA weights into the transformer model. + + Args: + transformer: The transformer model to which LoRA weights will be applied. + lora_path: Path to the folder containing the LoRA weights file. + weight_name: Filename of the weight to load. + + Returns: + A tuple containing the modified transformer and the canonical adapter name. + """ + + state_dict = _fetch_state_dict( + lora_path, + weight_name, + True, + True, + None, + None, + None, + None, + None, + None, + None, + None) + + state_dict = _convert_hunyuan_video_lora_to_diffusers(state_dict) + + # should weight_name even be Optional[str] or just str? + # For now, we assume it is never None + # The module name in the state_dict must not include a . in the name + # See https://github.com/pytorch/pytorch/pull/6639/files#diff-4be56271f7bfe650e3521c81fd363da58f109cd23ee80d243156d2d6ccda6263R133-R134 + adapter_name = str(PurePath(weight_name).with_suffix('')).replace('.', '_DOT_') + if '_DOT_' in adapter_name: + print( + f"LoRA file '{weight_name}' contains a '.' in the name. " + + 'This may cause issues. Consider renaming the file.' + + f" Using '{adapter_name}' as the adapter name to be safe." + ) + + # Check if adapter already exists and delete it if it does + if hasattr(transformer, 'peft_config') and adapter_name in transformer.peft_config: + print(f"Adapter '{adapter_name}' already exists. 
Removing it before loading again.") + # Use delete_adapters (plural) instead of delete_adapter + transformer.delete_adapters([adapter_name]) + + # Load the adapter with the original name + transformer.load_lora_adapter(state_dict, network_alphas=None, adapter_name=adapter_name) + print(f"LoRA weights '{adapter_name}' loaded successfully.") + + return transformer, adapter_name + +def unload_all_loras(transformer: torch.nn.Module) -> torch.nn.Module: + """ + Completely unload all LoRA adapters from the transformer model. + + Args: + transformer: The transformer model from which LoRA adapters will be removed. + + Returns: + The transformer model after all LoRA adapters have been removed. + """ + if hasattr(transformer, 'peft_config') and transformer.peft_config: + # Get all adapter names + adapter_names = list(transformer.peft_config.keys()) + + if adapter_names: + print(f"Removing all LoRA adapters: {', '.join(adapter_names)}") + # Delete all adapters + transformer.delete_adapters(adapter_names) + + # Force cleanup of any remaining adapter references + if hasattr(transformer, 'active_adapter'): + transformer.active_adapter = None + + # Clear any cached states + for module in transformer.modules(): + if hasattr(module, 'lora_A'): + if isinstance(module.lora_A, dict): + module.lora_A.clear() + if hasattr(module, 'lora_B'): + if isinstance(module.lora_B, dict): + module.lora_B.clear() + if hasattr(module, 'scaling'): + if isinstance(module.scaling, dict): + module.scaling.clear() + + print("All LoRA adapters have been completely removed.") + else: + print("No LoRA adapters found to remove.") + else: + print("Model doesn't have any LoRA adapters or peft_config.") + + # Force garbage collection + import gc + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + return transformer + +def resolve_expansion_class_name( + transformer: torch.nn.Module, + fallback_aliases: Dict[str, str], + fn_mapping: Dict[str, callable] +) -> Optional[str]: + """ + Resolves the canonical class name for adapter scale expansion functions, + considering potential fallback aliases. + + Args: + transformer: The transformer model instance. + fallback_aliases: A dictionary mapping model class names to fallback class names. + fn_mapping: A dictionary mapping class names to their respective scale expansion functions. + + Returns: + The resolved class name as a string if a matching scale function is found, + otherwise None. + """ + class_name = transformer.__class__.__name__ + + if class_name in fn_mapping: + return class_name + + fallback_class = fallback_aliases.get(class_name) + if fallback_class in fn_mapping: + print(f"Warning: No scale function for '{class_name}'. Falling back to '{fallback_class}'") + return fallback_class + + return None + +def set_adapters( + transformer: torch.nn.Module, + adapter_names: Union[List[str], str], + weights: Optional[Union[float, List[float]]] = None, +): + """ + Activates and sets the weights for one or more LoRA adapters on the transformer model. + + Args: + transformer: The transformer model to which LoRA adapters are applied. + adapter_names: A single adapter name (str) or a list of adapter names (List[str]) to activate. + weights: Optional. A single float weight or a list of float weights + corresponding to each adapter name. If None, defaults to 1.0 for each adapter. + If a single float, it will be applied to all adapters. 
+ """ + adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names + + # Expand a single weight to apply to all adapters if needed + if not isinstance(weights, list): + weights = [weights] * len(adapter_names) + + if len(adapter_names) != len(weights): + raise ValueError( + f"The number of adapter names ({len(adapter_names)}) does not match the number of weights ({len(weights)})." + ) + + # Replace any None weights with a default value of 1.0 + sanitized_weights = [w if w is not None else 1.0 for w in weights] + + resolved_class_name = resolve_expansion_class_name( + transformer, + fallback_aliases=FALLBACK_CLASS_ALIASES, + fn_mapping=_SET_ADAPTER_SCALE_FN_MAPPING + ) + + transformer_class_name = transformer.__class__.__name__ + + if resolved_class_name: + scale_expansion_fn = _SET_ADAPTER_SCALE_FN_MAPPING[resolved_class_name] + print(f"Using scale expansion function for model class '{resolved_class_name}' (original: '{transformer_class_name}')") + final_weights = [ + scale_expansion_fn(transformer, [weight])[0] for weight in sanitized_weights + ] + else: + print(f"Warning: No scale expansion function found for '{transformer_class_name}'. Using raw weights.") + final_weights = sanitized_weights + + set_weights_and_activate_adapters(transformer, adapter_names, final_weights) + + print(f"Adapters {adapter_names} activated with weights {final_weights}.") \ No newline at end of file diff --git a/diffusers_helper/memory.py b/diffusers_helper/memory.py new file mode 100644 index 0000000000000000000000000000000000000000..3380c538a185b0cbd07657ea475d0f5a0aeb17d3 --- /dev/null +++ b/diffusers_helper/memory.py @@ -0,0 +1,134 @@ +# By lllyasviel + + +import torch + + +cpu = torch.device('cpu') +gpu = torch.device(f'cuda:{torch.cuda.current_device()}') +gpu_complete_modules = [] + + +class DynamicSwapInstaller: + @staticmethod + def _install_module(module: torch.nn.Module, **kwargs): + original_class = module.__class__ + module.__dict__['forge_backup_original_class'] = original_class + + def hacked_get_attr(self, name: str): + if '_parameters' in self.__dict__: + _parameters = self.__dict__['_parameters'] + if name in _parameters: + p = _parameters[name] + if p is None: + return None + if p.__class__ == torch.nn.Parameter: + return torch.nn.Parameter(p.to(**kwargs), requires_grad=p.requires_grad) + else: + return p.to(**kwargs) + if '_buffers' in self.__dict__: + _buffers = self.__dict__['_buffers'] + if name in _buffers: + return _buffers[name].to(**kwargs) + return super(original_class, self).__getattr__(name) + + module.__class__ = type('DynamicSwap_' + original_class.__name__, (original_class,), { + '__getattr__': hacked_get_attr, + }) + + return + + @staticmethod + def _uninstall_module(module: torch.nn.Module): + if 'forge_backup_original_class' in module.__dict__: + module.__class__ = module.__dict__.pop('forge_backup_original_class') + return + + @staticmethod + def install_model(model: torch.nn.Module, **kwargs): + for m in model.modules(): + DynamicSwapInstaller._install_module(m, **kwargs) + return + + @staticmethod + def uninstall_model(model: torch.nn.Module): + for m in model.modules(): + DynamicSwapInstaller._uninstall_module(m) + return + + +def fake_diffusers_current_device(model: torch.nn.Module, target_device: torch.device): + if hasattr(model, 'scale_shift_table'): + model.scale_shift_table.data = model.scale_shift_table.data.to(target_device) + return + + for k, p in model.named_modules(): + if hasattr(p, 'weight'): + p.to(target_device) + return + + +def 
get_cuda_free_memory_gb(device=None): + if device is None: + device = gpu + + memory_stats = torch.cuda.memory_stats(device) + bytes_active = memory_stats['active_bytes.all.current'] + bytes_reserved = memory_stats['reserved_bytes.all.current'] + bytes_free_cuda, _ = torch.cuda.mem_get_info(device) + bytes_inactive_reserved = bytes_reserved - bytes_active + bytes_total_available = bytes_free_cuda + bytes_inactive_reserved + return bytes_total_available / (1024 ** 3) + + +def move_model_to_device_with_memory_preservation(model, target_device, preserved_memory_gb=0): + print(f'Moving {model.__class__.__name__} to {target_device} with preserved memory: {preserved_memory_gb} GB') + + for m in model.modules(): + if get_cuda_free_memory_gb(target_device) <= preserved_memory_gb: + torch.cuda.empty_cache() + return + + if hasattr(m, 'weight'): + m.to(device=target_device) + + model.to(device=target_device) + torch.cuda.empty_cache() + return + + +def offload_model_from_device_for_memory_preservation(model, target_device, preserved_memory_gb=0): + print(f'Offloading {model.__class__.__name__} from {target_device} to preserve memory: {preserved_memory_gb} GB') + + for m in model.modules(): + if get_cuda_free_memory_gb(target_device) >= preserved_memory_gb: + torch.cuda.empty_cache() + return + + if hasattr(m, 'weight'): + m.to(device=cpu) + + model.to(device=cpu) + torch.cuda.empty_cache() + return + + +def unload_complete_models(*args): + for m in gpu_complete_modules + list(args): + m.to(device=cpu) + print(f'Unloaded {m.__class__.__name__} as complete.') + + gpu_complete_modules.clear() + torch.cuda.empty_cache() + return + + +def load_model_as_complete(model, target_device, unload=True): + if unload: + unload_complete_models() + + model.to(device=target_device) + print(f'Loaded {model.__class__.__name__} to {target_device} as complete.') + + gpu_complete_modules.append(model) + return diff --git a/diffusers_helper/models/hunyuan_video_packed.py b/diffusers_helper/models/hunyuan_video_packed.py new file mode 100644 index 0000000000000000000000000000000000000000..2969562034b3e0b3311ae41fc1c53d2feb96af3e --- /dev/null +++ b/diffusers_helper/models/hunyuan_video_packed.py @@ -0,0 +1,1062 @@ +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import einops +import torch.nn as nn +import numpy as np + +from diffusers.loaders import FromOriginalModelMixin +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders import PeftAdapterMixin +from diffusers.utils import logging +from diffusers.models.attention import FeedForward +from diffusers.models.attention_processor import Attention +from diffusers.models.embeddings import TimestepEmbedding, Timesteps, PixArtAlphaTextProjection +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.models.modeling_utils import ModelMixin +from diffusers_helper.dit_common import LayerNorm +from diffusers_helper.models.mag_cache import MagCache +from diffusers_helper.utils import zero_module + + +enabled_backends = [] + +if torch.backends.cuda.flash_sdp_enabled(): + enabled_backends.append("flash") +if torch.backends.cuda.math_sdp_enabled(): + enabled_backends.append("math") +if torch.backends.cuda.mem_efficient_sdp_enabled(): + enabled_backends.append("mem_efficient") +if torch.backends.cuda.cudnn_sdp_enabled(): + enabled_backends.append("cudnn") + +print("Currently enabled native sdp backends:", enabled_backends) + +xformers_attn_func = None +flash_attn_varlen_func = None 
+flash_attn_func = None +sageattn_varlen = None +sageattn = None + +try: + # raise NotImplementedError + from xformers.ops import memory_efficient_attention as xformers_attn_func +except: + pass + +try: + # raise NotImplementedError + from flash_attn import flash_attn_varlen_func, flash_attn_func +except: + pass + +try: + # raise NotImplementedError + from sageattention import sageattn_varlen, sageattn +except: + pass + +# --- Attention Summary --- +print("\n--- Attention Configuration ---") +has_sage = sageattn is not None and sageattn_varlen is not None +has_flash = flash_attn_func is not None and flash_attn_varlen_func is not None +has_xformers = xformers_attn_func is not None + +if has_sage: + print("✅ Using SAGE Attention (highest performance).") + ignored = [] + if has_flash: + ignored.append("Flash Attention") + if has_xformers: + ignored.append("xFormers") + if ignored: + print(f" - Ignoring other installed attention libraries: {', '.join(ignored)}") +elif has_flash: + print("✅ Using Flash Attention (high performance).") + if has_xformers: + print(" - Consider installing SAGE Attention for highest performance.") + print(" - Ignoring other installed attention library: xFormers") +elif has_xformers: + print("✅ Using xFormers.") + print(" - Consider installing SAGE Attention for highest performance.") + print(" - or Consider installing Flash Attention for high performance.") +else: + print("⚠️ No attention library found. Using native PyTorch Scaled Dot Product Attention.") + print(" - For better performance, consider installing one of:") + print(" SAGE Attention (highest performance), Flash Attention (high performance), or xFormers.") +print("-------------------------------\n") + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def pad_for_3d_conv(x, kernel_size): + b, c, t, h, w = x.shape + pt, ph, pw = kernel_size + pad_t = (pt - (t % pt)) % pt + pad_h = (ph - (h % ph)) % ph + pad_w = (pw - (w % pw)) % pw + return torch.nn.functional.pad(x, (0, pad_w, 0, pad_h, 0, pad_t), mode='replicate') + + +def center_down_sample_3d(x, kernel_size): + # pt, ph, pw = kernel_size + # cp = (pt * ph * pw) // 2 + # xp = einops.rearrange(x, 'b c (t pt) (h ph) (w pw) -> (pt ph pw) b c t h w', pt=pt, ph=ph, pw=pw) + # xc = xp[cp] + # return xc + return torch.nn.functional.avg_pool3d(x, kernel_size, stride=kernel_size) + + +def get_cu_seqlens(text_mask, img_len): + batch_size = text_mask.shape[0] + text_len = text_mask.sum(dim=1) + max_len = text_mask.shape[1] + img_len + + cu_seqlens = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device="cuda") + + for i in range(batch_size): + s = text_len[i] + img_len + s1 = i * max_len + s + s2 = (i + 1) * max_len + cu_seqlens[2 * i + 1] = s1 + cu_seqlens[2 * i + 2] = s2 + + return cu_seqlens + + +def apply_rotary_emb_transposed(x, freqs_cis): + cos, sin = freqs_cis.unsqueeze(-2).chunk(2, dim=-1) + x_real, x_imag = x.unflatten(-1, (-1, 2)).unbind(-1) + x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) + out = x.float() * cos + x_rotated.float() * sin + out = out.to(x) + return out + + +def attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv): + if cu_seqlens_q is None and cu_seqlens_kv is None and max_seqlen_q is None and max_seqlen_kv is None: + if sageattn is not None: + x = sageattn(q, k, v, tensor_layout='NHD') + return x + + if flash_attn_func is not None: + x = flash_attn_func(q, k, v) + return x + + if xformers_attn_func is not None: + x = xformers_attn_func(q, k, v) + return x + + 
x = torch.nn.functional.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)).transpose(1, 2) + return x + + batch_size = q.shape[0] + q = q.view(q.shape[0] * q.shape[1], *q.shape[2:]) + k = k.view(k.shape[0] * k.shape[1], *k.shape[2:]) + v = v.view(v.shape[0] * v.shape[1], *v.shape[2:]) + if sageattn_varlen is not None: + x = sageattn_varlen(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv) + elif flash_attn_varlen_func is not None: + x = flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv) + else: + raise NotImplementedError('No Attn Installed!') + x = x.view(batch_size, max_seqlen_q, *x.shape[2:]) + return x + + +class HunyuanAttnProcessorFlashAttnDouble: + def __call__(self, attn, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb): + cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv = attention_mask + + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + query = query.unflatten(2, (attn.heads, -1)) + key = key.unflatten(2, (attn.heads, -1)) + value = value.unflatten(2, (attn.heads, -1)) + + query = attn.norm_q(query) + key = attn.norm_k(key) + + query = apply_rotary_emb_transposed(query, image_rotary_emb) + key = apply_rotary_emb_transposed(key, image_rotary_emb) + + encoder_query = attn.add_q_proj(encoder_hidden_states) + encoder_key = attn.add_k_proj(encoder_hidden_states) + encoder_value = attn.add_v_proj(encoder_hidden_states) + + encoder_query = encoder_query.unflatten(2, (attn.heads, -1)) + encoder_key = encoder_key.unflatten(2, (attn.heads, -1)) + encoder_value = encoder_value.unflatten(2, (attn.heads, -1)) + + encoder_query = attn.norm_added_q(encoder_query) + encoder_key = attn.norm_added_k(encoder_key) + + query = torch.cat([query, encoder_query], dim=1) + key = torch.cat([key, encoder_key], dim=1) + value = torch.cat([value, encoder_value], dim=1) + + hidden_states = attn_varlen_func(query, key, value, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv) + hidden_states = hidden_states.flatten(-2) + + txt_length = encoder_hidden_states.shape[1] + hidden_states, encoder_hidden_states = hidden_states[:, :-txt_length], hidden_states[:, -txt_length:] + + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + return hidden_states, encoder_hidden_states + + +class HunyuanAttnProcessorFlashAttnSingle: + def __call__(self, attn, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb): + cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv = attention_mask + + hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1) + + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + query = query.unflatten(2, (attn.heads, -1)) + key = key.unflatten(2, (attn.heads, -1)) + value = value.unflatten(2, (attn.heads, -1)) + + query = attn.norm_q(query) + key = attn.norm_k(key) + + txt_length = encoder_hidden_states.shape[1] + + query = torch.cat([apply_rotary_emb_transposed(query[:, :-txt_length], image_rotary_emb), query[:, -txt_length:]], dim=1) + key = torch.cat([apply_rotary_emb_transposed(key[:, :-txt_length], image_rotary_emb), key[:, -txt_length:]], dim=1) + + hidden_states = attn_varlen_func(query, key, value, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv) + hidden_states = hidden_states.flatten(-2) + + hidden_states, 
encoder_hidden_states = hidden_states[:, :-txt_length], hidden_states[:, -txt_length:] + + return hidden_states, encoder_hidden_states + + +class CombinedTimestepGuidanceTextProjEmbeddings(nn.Module): + def __init__(self, embedding_dim, pooled_projection_dim): + super().__init__() + + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu") + + def forward(self, timestep, guidance, pooled_projection): + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) + + guidance_proj = self.time_proj(guidance) + guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=pooled_projection.dtype)) + + time_guidance_emb = timesteps_emb + guidance_emb + + pooled_projections = self.text_embedder(pooled_projection) + conditioning = time_guidance_emb + pooled_projections + + return conditioning + + +class CombinedTimestepTextProjEmbeddings(nn.Module): + def __init__(self, embedding_dim, pooled_projection_dim): + super().__init__() + + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu") + + def forward(self, timestep, pooled_projection): + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) + + pooled_projections = self.text_embedder(pooled_projection) + + conditioning = timesteps_emb + pooled_projections + + return conditioning + + +class HunyuanVideoAdaNorm(nn.Module): + def __init__(self, in_features: int, out_features: Optional[int] = None) -> None: + super().__init__() + + out_features = out_features or 2 * in_features + self.linear = nn.Linear(in_features, out_features) + self.nonlinearity = nn.SiLU() + + def forward( + self, temb: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + temb = self.linear(self.nonlinearity(temb)) + gate_msa, gate_mlp = temb.chunk(2, dim=-1) + gate_msa, gate_mlp = gate_msa.unsqueeze(1), gate_mlp.unsqueeze(1) + return gate_msa, gate_mlp + + +class HunyuanVideoIndividualTokenRefinerBlock(nn.Module): + def __init__( + self, + num_attention_heads: int, + attention_head_dim: int, + mlp_width_ratio: str = 4.0, + mlp_drop_rate: float = 0.0, + attention_bias: bool = True, + ) -> None: + super().__init__() + + hidden_size = num_attention_heads * attention_head_dim + + self.norm1 = LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6) + self.attn = Attention( + query_dim=hidden_size, + cross_attention_dim=None, + heads=num_attention_heads, + dim_head=attention_head_dim, + bias=attention_bias, + ) + + self.norm2 = LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6) + self.ff = FeedForward(hidden_size, mult=mlp_width_ratio, activation_fn="linear-silu", dropout=mlp_drop_rate) + + self.norm_out = HunyuanVideoAdaNorm(hidden_size, 2 * hidden_size) + + def forward( + self, + hidden_states: torch.Tensor, + temb: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + norm_hidden_states = 
self.norm1(hidden_states) + + attn_output = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=None, + attention_mask=attention_mask, + ) + + gate_msa, gate_mlp = self.norm_out(temb) + hidden_states = hidden_states + attn_output * gate_msa + + ff_output = self.ff(self.norm2(hidden_states)) + hidden_states = hidden_states + ff_output * gate_mlp + + return hidden_states + + +class HunyuanVideoIndividualTokenRefiner(nn.Module): + def __init__( + self, + num_attention_heads: int, + attention_head_dim: int, + num_layers: int, + mlp_width_ratio: float = 4.0, + mlp_drop_rate: float = 0.0, + attention_bias: bool = True, + ) -> None: + super().__init__() + + self.refiner_blocks = nn.ModuleList( + [ + HunyuanVideoIndividualTokenRefinerBlock( + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + mlp_width_ratio=mlp_width_ratio, + mlp_drop_rate=mlp_drop_rate, + attention_bias=attention_bias, + ) + for _ in range(num_layers) + ] + ) + + def forward( + self, + hidden_states: torch.Tensor, + temb: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ) -> None: + self_attn_mask = None + if attention_mask is not None: + batch_size = attention_mask.shape[0] + seq_len = attention_mask.shape[1] + attention_mask = attention_mask.to(hidden_states.device).bool() + self_attn_mask_1 = attention_mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1) + self_attn_mask_2 = self_attn_mask_1.transpose(2, 3) + self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool() + self_attn_mask[:, :, :, 0] = True + + for block in self.refiner_blocks: + hidden_states = block(hidden_states, temb, self_attn_mask) + + return hidden_states + + +class HunyuanVideoTokenRefiner(nn.Module): + def __init__( + self, + in_channels: int, + num_attention_heads: int, + attention_head_dim: int, + num_layers: int, + mlp_ratio: float = 4.0, + mlp_drop_rate: float = 0.0, + attention_bias: bool = True, + ) -> None: + super().__init__() + + hidden_size = num_attention_heads * attention_head_dim + + self.time_text_embed = CombinedTimestepTextProjEmbeddings( + embedding_dim=hidden_size, pooled_projection_dim=in_channels + ) + self.proj_in = nn.Linear(in_channels, hidden_size, bias=True) + self.token_refiner = HunyuanVideoIndividualTokenRefiner( + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + num_layers=num_layers, + mlp_width_ratio=mlp_ratio, + mlp_drop_rate=mlp_drop_rate, + attention_bias=attention_bias, + ) + + def forward( + self, + hidden_states: torch.Tensor, + timestep: torch.LongTensor, + attention_mask: Optional[torch.LongTensor] = None, + ) -> torch.Tensor: + if attention_mask is None: + pooled_projections = hidden_states.mean(dim=1) + else: + original_dtype = hidden_states.dtype + mask_float = attention_mask.float().unsqueeze(-1) + pooled_projections = (hidden_states * mask_float).sum(dim=1) / mask_float.sum(dim=1) + pooled_projections = pooled_projections.to(original_dtype) + + temb = self.time_text_embed(timestep, pooled_projections) + hidden_states = self.proj_in(hidden_states) + hidden_states = self.token_refiner(hidden_states, temb, attention_mask) + + return hidden_states + + +class HunyuanVideoRotaryPosEmbed(nn.Module): + def __init__(self, rope_dim, theta): + super().__init__() + self.DT, self.DY, self.DX = rope_dim + self.theta = theta + + @torch.no_grad() + def get_frequency(self, dim, pos): + T, H, W = pos.shape + freqs = 1.0 / (self.theta ** (torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device)[: (dim // 2)] / 
dim)) + freqs = torch.outer(freqs, pos.reshape(-1)).unflatten(-1, (T, H, W)).repeat_interleave(2, dim=0) + return freqs.cos(), freqs.sin() + + @torch.no_grad() + def forward_inner(self, frame_indices, height, width, device): + GT, GY, GX = torch.meshgrid( + frame_indices.to(device=device, dtype=torch.float32), + torch.arange(0, height, device=device, dtype=torch.float32), + torch.arange(0, width, device=device, dtype=torch.float32), + indexing="ij" + ) + + FCT, FST = self.get_frequency(self.DT, GT) + FCY, FSY = self.get_frequency(self.DY, GY) + FCX, FSX = self.get_frequency(self.DX, GX) + + result = torch.cat([FCT, FCY, FCX, FST, FSY, FSX], dim=0) + + return result.to(device) + + @torch.no_grad() + def forward(self, frame_indices, height, width, device): + frame_indices = frame_indices.unbind(0) + results = [self.forward_inner(f, height, width, device) for f in frame_indices] + results = torch.stack(results, dim=0) + return results + + +class AdaLayerNormZero(nn.Module): + def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True): + super().__init__() + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=bias) + if norm_type == "layer_norm": + self.norm = LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6) + else: + raise ValueError(f"unknown norm_type {norm_type}") + + def forward( + self, + x: torch.Tensor, + emb: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + emb = emb.unsqueeze(-2) + emb = self.linear(self.silu(emb)) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=-1) + x = self.norm(x) * (1 + scale_msa) + shift_msa + return x, gate_msa, shift_mlp, scale_mlp, gate_mlp + + +class AdaLayerNormZeroSingle(nn.Module): + def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True): + super().__init__() + + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, 3 * embedding_dim, bias=bias) + if norm_type == "layer_norm": + self.norm = LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6) + else: + raise ValueError(f"unknown norm_type {norm_type}") + + def forward( + self, + x: torch.Tensor, + emb: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + emb = emb.unsqueeze(-2) + emb = self.linear(self.silu(emb)) + shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=-1) + x = self.norm(x) * (1 + scale_msa) + shift_msa + return x, gate_msa + + +class AdaLayerNormContinuous(nn.Module): + def __init__( + self, + embedding_dim: int, + conditioning_embedding_dim: int, + elementwise_affine=True, + eps=1e-5, + bias=True, + norm_type="layer_norm", + ): + super().__init__() + self.silu = nn.SiLU() + self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias) + if norm_type == "layer_norm": + self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias) + else: + raise ValueError(f"unknown norm_type {norm_type}") + + def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor: + emb = emb.unsqueeze(-2) + emb = self.linear(self.silu(emb)) + scale, shift = emb.chunk(2, dim=-1) + x = self.norm(x) * (1 + scale) + shift + return x + + +class HunyuanVideoSingleTransformerBlock(nn.Module): + def __init__( + self, + num_attention_heads: int, + attention_head_dim: int, + mlp_ratio: float = 4.0, + qk_norm: str = "rms_norm", + ) -> None: + super().__init__() + + hidden_size = num_attention_heads * attention_head_dim + mlp_dim = 
int(hidden_size * mlp_ratio) + + self.attn = Attention( + query_dim=hidden_size, + cross_attention_dim=None, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=hidden_size, + bias=True, + processor=HunyuanAttnProcessorFlashAttnSingle(), + qk_norm=qk_norm, + eps=1e-6, + pre_only=True, + ) + + self.norm = AdaLayerNormZeroSingle(hidden_size, norm_type="layer_norm") + self.proj_mlp = nn.Linear(hidden_size, mlp_dim) + self.act_mlp = nn.GELU(approximate="tanh") + self.proj_out = nn.Linear(hidden_size + mlp_dim, hidden_size) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> torch.Tensor: + text_seq_length = encoder_hidden_states.shape[1] + hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1) + + residual = hidden_states + + # 1. Input normalization + norm_hidden_states, gate = self.norm(hidden_states, emb=temb) + mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states)) + + norm_hidden_states, norm_encoder_hidden_states = ( + norm_hidden_states[:, :-text_seq_length, :], + norm_hidden_states[:, -text_seq_length:, :], + ) + + # 2. Attention + attn_output, context_attn_output = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + attention_mask=attention_mask, + image_rotary_emb=image_rotary_emb, + ) + attn_output = torch.cat([attn_output, context_attn_output], dim=1) + + # 3. Modulation and residual connection + hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2) + hidden_states = gate * self.proj_out(hidden_states) + hidden_states = hidden_states + residual + + hidden_states, encoder_hidden_states = ( + hidden_states[:, :-text_seq_length, :], + hidden_states[:, -text_seq_length:, :], + ) + return hidden_states, encoder_hidden_states + + +class HunyuanVideoTransformerBlock(nn.Module): + def __init__( + self, + num_attention_heads: int, + attention_head_dim: int, + mlp_ratio: float, + qk_norm: str = "rms_norm", + ) -> None: + super().__init__() + + hidden_size = num_attention_heads * attention_head_dim + + self.norm1 = AdaLayerNormZero(hidden_size, norm_type="layer_norm") + self.norm1_context = AdaLayerNormZero(hidden_size, norm_type="layer_norm") + + self.attn = Attention( + query_dim=hidden_size, + cross_attention_dim=None, + added_kv_proj_dim=hidden_size, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=hidden_size, + context_pre_only=False, + bias=True, + processor=HunyuanAttnProcessorFlashAttnDouble(), + qk_norm=qk_norm, + eps=1e-6, + ) + + self.norm2 = LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate") + + self.norm2_context = LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.ff_context = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate") + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # 1. 
Input normalization + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) + norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(encoder_hidden_states, emb=temb) + + # 2. Joint attention + attn_output, context_attn_output = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + attention_mask=attention_mask, + image_rotary_emb=freqs_cis, + ) + + # 3. Modulation and residual connection + hidden_states = hidden_states + attn_output * gate_msa + encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa + + norm_hidden_states = self.norm2(hidden_states) + norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) + + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp + norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp) + c_shift_mlp + + # 4. Feed-forward + ff_output = self.ff(norm_hidden_states) + context_ff_output = self.ff_context(norm_encoder_hidden_states) + + hidden_states = hidden_states + gate_mlp * ff_output + encoder_hidden_states = encoder_hidden_states + c_gate_mlp * context_ff_output + + return hidden_states, encoder_hidden_states + + +class ClipVisionProjection(nn.Module): + def __init__(self, in_channels, out_channels): + super().__init__() + self.up = nn.Linear(in_channels, out_channels * 3) + self.down = nn.Linear(out_channels * 3, out_channels) + + def forward(self, x): + projected_x = self.down(nn.functional.silu(self.up(x))) + return projected_x + + +class HunyuanVideoPatchEmbed(nn.Module): + def __init__(self, patch_size, in_chans, embed_dim): + super().__init__() + self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + + +class HunyuanVideoPatchEmbedForCleanLatents(nn.Module): + def __init__(self, inner_dim): + super().__init__() + self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2)) + self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4)) + self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8)) + + @torch.no_grad() + def initialize_weight_from_another_conv3d(self, another_layer): + weight = another_layer.weight.detach().clone() + bias = another_layer.bias.detach().clone() + + sd = { + 'proj.weight': weight.clone(), + 'proj.bias': bias.clone(), + 'proj_2x.weight': einops.repeat(weight, 'b c t h w -> b c (t tk) (h hk) (w wk)', tk=2, hk=2, wk=2) / 8.0, + 'proj_2x.bias': bias.clone(), + 'proj_4x.weight': einops.repeat(weight, 'b c t h w -> b c (t tk) (h hk) (w wk)', tk=4, hk=4, wk=4) / 64.0, + 'proj_4x.bias': bias.clone(), + } + + sd = {k: v.clone() for k, v in sd.items()} + + self.load_state_dict(sd) + return + + +class HunyuanVideoTransformer3DModelPacked(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): + @register_to_config + def __init__( + self, + in_channels: int = 16, + out_channels: int = 16, + num_attention_heads: int = 24, + attention_head_dim: int = 128, + num_layers: int = 20, + num_single_layers: int = 40, + num_refiner_layers: int = 2, + mlp_ratio: float = 4.0, + patch_size: int = 2, + patch_size_t: int = 1, + qk_norm: str = "rms_norm", + guidance_embeds: bool = True, + text_embed_dim: int = 4096, + pooled_projection_dim: int = 768, + rope_theta: float = 256.0, + rope_axes_dim: Tuple[int] = (16, 56, 56), + has_image_proj=False, + image_proj_dim=1152, + has_clean_x_embedder=False, + ) -> None: + super().__init__() + + 
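The weight-initialization trick in HunyuanVideoPatchEmbedForCleanLatents above is worth unpacking: repeating the base conv3d kernel over a (2, 2, 2) neighbourhood and dividing by 8 (the kernel-volume ratio) makes proj_2x applied to full-resolution clean latents numerically equivalent to running the original proj on 2x average-pooled latents, so the coarser context branches start out consistent with the main patch embedder (and likewise /64 for proj_4x). A minimal sketch of that equivalence, an editor's illustration using arbitrary small channel counts and shapes rather than the model's real configuration:

import torch
import torch.nn.functional as F
import einops

# Base projection: kernel (1, 2, 2), mirroring HunyuanVideoPatchEmbedForCleanLatents.proj
base = torch.nn.Conv3d(16, 32, kernel_size=(1, 2, 2), stride=(1, 2, 2))

# Enlarged kernel for the "proj_2x" path: replicate each weight over a 2x2x2 block, divide by 8
w2 = einops.repeat(base.weight.detach(), 'o i t h w -> o i (t tk) (h hk) (w wk)',
                   tk=2, hk=2, wk=2) / 8.0

x = torch.randn(1, 16, 4, 8, 8)  # toy latent tensor
a = F.conv3d(x, w2, base.bias, stride=(2, 4, 4))                    # enlarged-kernel path
b = F.conv3d(F.avg_pool3d(x, kernel_size=2, stride=2),
             base.weight, base.bias, stride=(1, 2, 2))              # pool-then-project path
print(torch.allclose(a, b, atol=1e-4))  # expected: True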
inner_dim = num_attention_heads * attention_head_dim + out_channels = out_channels or in_channels + + # 1. Latent and condition embedders + self.x_embedder = HunyuanVideoPatchEmbed((patch_size_t, patch_size, patch_size), in_channels, inner_dim) + self.context_embedder = HunyuanVideoTokenRefiner( + text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers + ) + self.time_text_embed = CombinedTimestepGuidanceTextProjEmbeddings(inner_dim, pooled_projection_dim) + + self.clean_x_embedder = None + self.image_projection = None + + # 2. RoPE + self.rope = HunyuanVideoRotaryPosEmbed(rope_axes_dim, rope_theta) + + # 3. Dual stream transformer blocks + self.transformer_blocks = nn.ModuleList( + [ + HunyuanVideoTransformerBlock( + num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm + ) + for _ in range(num_layers) + ] + ) + + # 4. Single stream transformer blocks + self.single_transformer_blocks = nn.ModuleList( + [ + HunyuanVideoSingleTransformerBlock( + num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm + ) + for _ in range(num_single_layers) + ] + ) + + # 5. Output projection + self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out = nn.Linear(inner_dim, patch_size_t * patch_size * patch_size * out_channels) + + self.inner_dim = inner_dim + self.use_gradient_checkpointing = False + self.enable_teacache = False + self.magcache: MagCache = None + + if has_image_proj: + self.install_image_projection(image_proj_dim) + + if has_clean_x_embedder: + self.install_clean_x_embedder() + + self.high_quality_fp32_output_for_inference = False + + def install_image_projection(self, in_channels): + self.image_projection = ClipVisionProjection(in_channels=in_channels, out_channels=self.inner_dim) + self.config['has_image_proj'] = True + self.config['image_proj_dim'] = in_channels + + def install_clean_x_embedder(self): + self.clean_x_embedder = HunyuanVideoPatchEmbedForCleanLatents(self.inner_dim) + self.config['has_clean_x_embedder'] = True + + def enable_gradient_checkpointing(self): + self.use_gradient_checkpointing = True + print('self.use_gradient_checkpointing = True') + + def disable_gradient_checkpointing(self): + self.use_gradient_checkpointing = False + print('self.use_gradient_checkpointing = False') + + def initialize_teacache(self, enable_teacache=True, num_steps=25, rel_l1_thresh=0.15): + self.enable_teacache = enable_teacache + self.cnt = 0 + self.num_steps = num_steps + self.rel_l1_thresh = rel_l1_thresh # 0.1 for 1.6x speedup, 0.15 for 2.1x speedup + self.accumulated_rel_l1_distance = 0 + self.previous_modulated_input = None + self.previous_residual = None + self.teacache_rescale_func = np.poly1d([7.33226126e+02, -4.01131952e+02, 6.75869174e+01, -3.14987800e+00, 9.61237896e-02]) + + def install_magcache(self, magcache: MagCache): + self.magcache = magcache + + def uninstall_magcache(self): + self.magcache = None + + def gradient_checkpointing_method(self, block, *args): + if self.use_gradient_checkpointing: + result = torch.utils.checkpoint.checkpoint(block, *args, use_reentrant=False) + else: + result = block(*args) + return result + + def process_input_hidden_states( + self, + latents, latent_indices=None, + clean_latents=None, clean_latent_indices=None, + clean_latents_2x=None, clean_latent_2x_indices=None, + clean_latents_4x=None, clean_latent_4x_indices=None + ): + hidden_states = self.gradient_checkpointing_method(self.x_embedder.proj, latents) + B, C, 
T, H, W = hidden_states.shape + + if latent_indices is None: + latent_indices = torch.arange(0, T).unsqueeze(0).expand(B, -1) + + hidden_states = hidden_states.flatten(2).transpose(1, 2) + + rope_freqs = self.rope(frame_indices=latent_indices, height=H, width=W, device=hidden_states.device) + rope_freqs = rope_freqs.flatten(2).transpose(1, 2) + + if clean_latents is not None and clean_latent_indices is not None: + clean_latents = clean_latents.to(hidden_states) + clean_latents = self.gradient_checkpointing_method(self.clean_x_embedder.proj, clean_latents) + clean_latents = clean_latents.flatten(2).transpose(1, 2) + + clean_latent_rope_freqs = self.rope(frame_indices=clean_latent_indices, height=H, width=W, device=clean_latents.device) + clean_latent_rope_freqs = clean_latent_rope_freqs.flatten(2).transpose(1, 2) + + hidden_states = torch.cat([clean_latents, hidden_states], dim=1) + rope_freqs = torch.cat([clean_latent_rope_freqs, rope_freqs], dim=1) + + if clean_latents_2x is not None and clean_latent_2x_indices is not None: + clean_latents_2x = clean_latents_2x.to(hidden_states) + clean_latents_2x = pad_for_3d_conv(clean_latents_2x, (2, 4, 4)) + clean_latents_2x = self.gradient_checkpointing_method(self.clean_x_embedder.proj_2x, clean_latents_2x) + clean_latents_2x = clean_latents_2x.flatten(2).transpose(1, 2) + + clean_latent_2x_rope_freqs = self.rope(frame_indices=clean_latent_2x_indices, height=H, width=W, device=clean_latents_2x.device) + clean_latent_2x_rope_freqs = pad_for_3d_conv(clean_latent_2x_rope_freqs, (2, 2, 2)) + clean_latent_2x_rope_freqs = center_down_sample_3d(clean_latent_2x_rope_freqs, (2, 2, 2)) + clean_latent_2x_rope_freqs = clean_latent_2x_rope_freqs.flatten(2).transpose(1, 2) + + hidden_states = torch.cat([clean_latents_2x, hidden_states], dim=1) + rope_freqs = torch.cat([clean_latent_2x_rope_freqs, rope_freqs], dim=1) + + if clean_latents_4x is not None and clean_latent_4x_indices is not None: + clean_latents_4x = clean_latents_4x.to(hidden_states) + clean_latents_4x = pad_for_3d_conv(clean_latents_4x, (4, 8, 8)) + clean_latents_4x = self.gradient_checkpointing_method(self.clean_x_embedder.proj_4x, clean_latents_4x) + clean_latents_4x = clean_latents_4x.flatten(2).transpose(1, 2) + + clean_latent_4x_rope_freqs = self.rope(frame_indices=clean_latent_4x_indices, height=H, width=W, device=clean_latents_4x.device) + clean_latent_4x_rope_freqs = pad_for_3d_conv(clean_latent_4x_rope_freqs, (4, 4, 4)) + clean_latent_4x_rope_freqs = center_down_sample_3d(clean_latent_4x_rope_freqs, (4, 4, 4)) + clean_latent_4x_rope_freqs = clean_latent_4x_rope_freqs.flatten(2).transpose(1, 2) + + hidden_states = torch.cat([clean_latents_4x, hidden_states], dim=1) + rope_freqs = torch.cat([clean_latent_4x_rope_freqs, rope_freqs], dim=1) + + return hidden_states, rope_freqs + + def forward( + self, + hidden_states, timestep, encoder_hidden_states, encoder_attention_mask, pooled_projections, guidance, + latent_indices=None, + clean_latents=None, clean_latent_indices=None, + clean_latents_2x=None, clean_latent_2x_indices=None, + clean_latents_4x=None, clean_latent_4x_indices=None, + image_embeddings=None, + attention_kwargs=None, return_dict=True + ): + + if attention_kwargs is None: + attention_kwargs = {} + + batch_size, num_channels, num_frames, height, width = hidden_states.shape + p, p_t = self.config['patch_size'], self.config['patch_size_t'] + post_patch_num_frames = num_frames // p_t + post_patch_height = height // p + post_patch_width = width // p + original_context_length = 
post_patch_num_frames * post_patch_height * post_patch_width + + hidden_states, rope_freqs = self.process_input_hidden_states(hidden_states, latent_indices, clean_latents, clean_latent_indices, clean_latents_2x, clean_latent_2x_indices, clean_latents_4x, clean_latent_4x_indices) + + temb = self.gradient_checkpointing_method(self.time_text_embed, timestep, guidance, pooled_projections) + encoder_hidden_states = self.gradient_checkpointing_method(self.context_embedder, encoder_hidden_states, timestep, encoder_attention_mask) + + if self.image_projection is not None: + assert image_embeddings is not None, 'You must use image embeddings!' + extra_encoder_hidden_states = self.gradient_checkpointing_method(self.image_projection, image_embeddings) + extra_attention_mask = torch.ones((batch_size, extra_encoder_hidden_states.shape[1]), dtype=encoder_attention_mask.dtype, device=encoder_attention_mask.device) + + # must cat before (not after) encoder_hidden_states, due to attn masking + encoder_hidden_states = torch.cat([extra_encoder_hidden_states, encoder_hidden_states], dim=1) + encoder_attention_mask = torch.cat([extra_attention_mask, encoder_attention_mask], dim=1) + + with torch.no_grad(): + if batch_size == 1: + # When batch size is 1, we do not need any masks or var-len funcs since cropping is mathematically same to what we want + # If they are not same, then their impls are wrong. Ours are always the correct one. + text_len = encoder_attention_mask.sum().item() + encoder_hidden_states = encoder_hidden_states[:, :text_len] + attention_mask = None, None, None, None + else: + img_seq_len = hidden_states.shape[1] + txt_seq_len = encoder_hidden_states.shape[1] + + cu_seqlens_q = get_cu_seqlens(encoder_attention_mask, img_seq_len) + cu_seqlens_kv = cu_seqlens_q + max_seqlen_q = img_seq_len + txt_seq_len + max_seqlen_kv = max_seqlen_q + + attention_mask = cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv + + if self.enable_teacache: + modulated_inp = self.transformer_blocks[0].norm1(hidden_states, emb=temb)[0] + + if self.cnt == 0 or self.cnt == self.num_steps-1: + should_calc = True + self.accumulated_rel_l1_distance = 0 + else: + curr_rel_l1 = ((modulated_inp - self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item() + self.accumulated_rel_l1_distance += self.teacache_rescale_func(curr_rel_l1) + should_calc = self.accumulated_rel_l1_distance >= self.rel_l1_thresh + + if should_calc: + self.accumulated_rel_l1_distance = 0 + + self.previous_modulated_input = modulated_inp + self.cnt += 1 + + if self.cnt == self.num_steps: + self.cnt = 0 + + if not should_calc: + hidden_states = hidden_states + self.previous_residual + else: + ori_hidden_states = hidden_states.clone() + + hidden_states, encoder_hidden_states = self._run_denoising_layers(hidden_states, encoder_hidden_states, temb, attention_mask, rope_freqs) + + self.previous_residual = hidden_states - ori_hidden_states + + elif self.magcache and self.magcache.is_enabled: + if self.magcache.should_skip(hidden_states): + hidden_states = self.magcache.estimate_predicted_hidden_states() + else: + hidden_states, encoder_hidden_states = self._run_denoising_layers(hidden_states, encoder_hidden_states, temb, attention_mask, rope_freqs) + self.magcache.update_hidden_states(model_prediction_hidden_states=hidden_states) + + else: + hidden_states, encoder_hidden_states = self._run_denoising_layers(hidden_states, encoder_hidden_states, temb, attention_mask, rope_freqs) + + hidden_states = 
self.gradient_checkpointing_method(self.norm_out, hidden_states, temb) + + hidden_states = hidden_states[:, -original_context_length:, :] + + if self.high_quality_fp32_output_for_inference: + hidden_states = hidden_states.to(dtype=torch.float32) + if self.proj_out.weight.dtype != torch.float32: + self.proj_out.to(dtype=torch.float32) + + hidden_states = self.gradient_checkpointing_method(self.proj_out, hidden_states) + + hidden_states = einops.rearrange(hidden_states, 'b (t h w) (c pt ph pw) -> b c (t pt) (h ph) (w pw)', + t=post_patch_num_frames, h=post_patch_height, w=post_patch_width, + pt=p_t, ph=p, pw=p) + + if return_dict: + return Transformer2DModelOutput(sample=hidden_states) + + return hidden_states, + + def _run_denoising_layers( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + attention_mask: Optional[Tuple], + rope_freqs: Optional[torch.Tensor] + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Applies the dual-stream and single-stream transformer blocks. + """ + for block_id, block in enumerate(self.transformer_blocks): + hidden_states, encoder_hidden_states = self.gradient_checkpointing_method( + block, hidden_states, encoder_hidden_states, temb, attention_mask, rope_freqs + ) + + for block_id, block in enumerate(self.single_transformer_blocks): + hidden_states, encoder_hidden_states = self.gradient_checkpointing_method( + block, hidden_states, encoder_hidden_states, temb, attention_mask, rope_freqs + ) + return hidden_states, encoder_hidden_states \ No newline at end of file diff --git a/diffusers_helper/models/mag_cache.py b/diffusers_helper/models/mag_cache.py new file mode 100644 index 0000000000000000000000000000000000000000..281a598582876f46b0c5c158064d6b3fc21248d4 --- /dev/null +++ b/diffusers_helper/models/mag_cache.py @@ -0,0 +1,219 @@ +import numpy as np +import torch +import os + +from diffusers_helper.models.mag_cache_ratios import MAG_RATIOS_DB + + +class MagCache: + """ + Implements the MagCache algorithm for skipping transformer steps during video generation. + MagCache: Fast Video Generation with Magnitude-Aware Cache + Zehong Ma, Longhui Wei, Feng Wang, Shiliang Zhang, Qi Tian + https://arxiv.org/abs/2506.09045 + https://github.com/Zehong-Ma/MagCache + PR Demo defaults were threshold=0.1, max_consectutive_skips=3, retention_ratio=0.2 + Changing defauults to threshold=0.1, max_consectutive_skips=2, retention_ratio=0.25 for quality vs speed tradeoff. 
+ """ + + def __init__(self, model_family, height, width, num_steps, is_enabled=True, is_calibrating = False, threshold=0.1, max_consectutive_skips=2, retention_ratio=0.25): + self.model_family = model_family + self.height = height + self.width = width + self.num_steps = num_steps + + self.is_enabled = is_enabled + self.is_calibrating = is_calibrating + + self.threshold = threshold + self.max_consectutive_skips = max_consectutive_skips + self.retention_ratio = retention_ratio + + # total cache statistics for all sections in the entire generation + self.total_cache_requests = 0 + self.total_cache_hits = 0 + + self.mag_ratios = self._determine_mag_ratios() + + self._init_for_every_section() + + + def _init_for_every_section(self): + self.step_index = 0 + self.steps_skipped_list = [] + #Error accumulation state + self.accumulated_ratio = 1.0 + self.accumulated_steps = 0 + self.accumulated_err = 0 + # Statistics for calibration + self.norm_ratio, self.norm_std, self.cos_dis = [], [], [] + + self.hidden_states = None + self.previous_residual = None + + if self.is_calibrating and self.total_cache_requests > 0: + print('WARNING: Resetting MagCache calibration stats for new section. Typically you only want one section per calibration job. Discarding calibration from previsou section.') + + def should_skip(self, hidden_states): + """ + Expected to be called once per step during the forward pass, for the numer of initialized steps. + Determines if the current step should be skipped based on estimated accumulated error. + If the step is skipped, the hidden_states should be replaced with the output of estimate_predicted_hidden_states(). + + Args: + hidden_states: The current hidden states tensor from the transformer model. + Returns: + True if the step should be skipped, False otherwise + """ + if self.step_index == 0 or self.step_index >= self.num_steps: + self._init_for_every_section() + self.total_cache_requests += 1 + self.hidden_states = hidden_states.clone() # Is clone needed? + + if self.is_calibrating: + print('######################### Calibrating MagCache #########################') + return False + + should_skip_forward = False + if self.step_index>=int(self.retention_ratio*self.num_steps) and self.step_index>=1: # keep first retention_ratio steps + cur_mag_ratio = self.mag_ratios[self.step_index] + self.accumulated_ratio = self.accumulated_ratio*cur_mag_ratio + cur_skip_err = np.abs(1-self.accumulated_ratio) + self.accumulated_err += cur_skip_err + self.accumulated_steps += 1 + # RT_BORG: Per my conversation with Zehong Ma, this 0.06 could potentially be exposed as another tunable param. + if self.accumulated_err<=self.threshold and self.accumulated_steps<=self.max_consectutive_skips and np.abs(1-cur_mag_ratio)<=0.06: + should_skip_forward = True + else: + self.accumulated_ratio = 1.0 + self.accumulated_steps = 0 + self.accumulated_err = 0 + + if should_skip_forward: + self.total_cache_hits += 1 + self.steps_skipped_list.append(self.step_index) + # Increment for next step + self.step_index += 1 + if self.step_index == self.num_steps: + self.step_index = 0 + + return should_skip_forward + + def estimate_predicted_hidden_states(self): + """ + Should be called if and only if should_skip() returned True for the current step. + Estimates the hidden states for the current step based on the previous hidden states and residual. + + Returns: + The estimated hidden states tensor. 
+ """ + return self.hidden_states + self.previous_residual + + def update_hidden_states(self, model_prediction_hidden_states): + """ + If and only if should_skip() returned False for the current step, the denoising layers should have been run, + and this function should be called to compute and store the residual for future steps. + + Args: + model_prediction_hidden_states: The hidden states tensor output from running the denoising layers. + """ + + current_residual = model_prediction_hidden_states - self.hidden_states + if self.is_calibrating: + self._update_calibration_stats(current_residual) + + self.previous_residual = current_residual + + def _update_calibration_stats(self, current_residual): + if self.step_index >= 1: + norm_ratio = ((current_residual.norm(dim=-1)/self.previous_residual.norm(dim=-1)).mean()).item() + norm_std = (current_residual.norm(dim=-1)/self.previous_residual.norm(dim=-1)).std().item() + cos_dis = (1-torch.nn.functional.cosine_similarity(current_residual, self.previous_residual, dim=-1, eps=1e-8)).mean().item() + self.norm_ratio.append(round(norm_ratio, 5)) + self.norm_std.append(round(norm_std, 5)) + self.cos_dis.append(round(cos_dis, 5)) + # print(f"time: {self.step_index}, norm_ratio: {norm_ratio}, norm_std: {norm_std}, cos_dis: {cos_dis}") + + self.step_index += 1 + if self.step_index == self.num_steps: + print("norm ratio") + print(self.norm_ratio) + print("norm std") + print(self.norm_std) + print("cos_dis") + print(self.cos_dis) + self.step_index = 0 + + def _determine_mag_ratios(self): + """ + Determines the magnitude ratios by finding the closest resolution and step count + in the pre-calibrated database. + + Returns: + A numpy array of magnitude ratios for the specified configuration, or None if not found. + """ + if self.is_calibrating: + return None + try: + # Find the closest available resolution group for the given model family + resolution_groups = MAG_RATIOS_DB[self.model_family] + available_resolutions = list(resolution_groups.keys()) + if not available_resolutions: + raise ValueError("No resolutions defined for this model family.") + + avg_resolution = (self.height + self.width) / 2.0 + closest_resolution_key = min(available_resolutions, key=lambda r: abs(r - avg_resolution)) + + # Find the closest available step count for the given model/resolution + steps_group = resolution_groups[closest_resolution_key] + available_steps = list(steps_group.keys()) + if not available_steps: + raise ValueError(f"No step counts defined for resolution {closest_resolution_key}.") + closest_steps = min(available_steps, key=lambda x: abs(x - self.num_steps)) + base_ratios = steps_group[closest_steps] + if closest_steps == self.num_steps: + print(f"MagCache: Found ratios for {self.model_family}, resolution group {closest_resolution_key} ({self.width}x{self.height}), {self.num_steps} steps.") + return base_ratios + print(f"MagCache: Using ratios from {self.model_family}, resolution group {closest_resolution_key} ({self.width}x{self.height}), {closest_steps} steps and interpolating to {self.num_steps} steps.") + return self._nearest_step_interpolation(base_ratios, self.num_steps) + except KeyError: + # This will catch if model_family is not in MAG_RATIOS_DB + print(f"Warning: MagCache not calibrated for model family '{self.model_family}'. MagCache will not be used.") + self.is_enabled = False + except (ValueError, TypeError) as e: + # This will catch errors if resolution keys or step keys are not numbers, or if groups are empty. 
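As an editor's aside on the skip rule implemented in should_skip() above: a step only becomes eligible for skipping after the first retention_ratio fraction of steps, and it is actually skipped only while the accumulated magnitude error stays at or below threshold, no more than max_consectutive_skips steps have been skipped in a row, and the current ratio is within 6% of 1.0; otherwise the accumulators reset and the step runs. A small self-contained simulation of that rule, using made-up ratios rather than calibrated values:

import numpy as np

def simulate_skips(mag_ratios, num_steps, threshold=0.1, max_consecutive_skips=2, retention_ratio=0.25):
    acc_ratio, acc_steps, acc_err = 1.0, 0, 0.0
    skipped = []
    for step in range(num_steps):
        if step < int(retention_ratio * num_steps) or step < 1:
            continue  # the early (retention) steps always run
        acc_ratio *= mag_ratios[step]
        acc_err += abs(1 - acc_ratio)
        acc_steps += 1
        if acc_err <= threshold and acc_steps <= max_consecutive_skips and abs(1 - mag_ratios[step]) <= 0.06:
            skipped.append(step)  # would reuse the cached residual for this step
        else:
            acc_ratio, acc_steps, acc_err = 1.0, 0, 0.0  # reset accumulators and run the model
    return skipped

ratios = np.array([1.0, 1.3, 1.05, 1.02, 1.01, 1.0, 1.0, 0.99, 0.98, 0.95])
print(simulate_skips(ratios, num_steps=len(ratios)))  # expected: [2, 4, 5, 7, 8]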
+ print(f"Warning: Error processing MagCache DB for model family '{self.model_family}': {e}. MagCache will not be used.") + self.is_enabled = False + return None + + # Nearest interpolation function for MagCache mag_ratios + @staticmethod + def _nearest_step_interpolation(src_array, target_length): + src_length = len(src_array) + if target_length == 1: + return np.array([src_array[-1]]) + + scale = (src_length - 1) / (target_length - 1) + mapped_indices = np.round(np.arange(target_length) * scale).astype(int) + return src_array[mapped_indices] + + def append_calibration_to_file(self, output_file): + """ + Appends tab delimited calibration data (model_family,width,height,norm_ratio) to output_file. + """ + if not self.is_calibrating or not self.norm_ratio: + print("Calibration data can only be appended after calibration.") + return False + try: + with open(output_file, "a") as f: + # Format the data as a string + calibration_set = f"{self.model_family}\t{self.width}\t{self.height}\t{self.num_steps}" + # data_string = f"{calibration_set}\t{self.norm_ratio}" + entry_string = f"{calibration_set}\t{self.num_steps}: np.array([1.0] + {self.norm_ratio})," + # Append the data to the file + f.write(entry_string + "\n") + print(f"Calibration data appended to {output_file}") + return True + except Exception as e: + print(f"Error appending calibration data: {e}") + return False diff --git a/diffusers_helper/models/mag_cache_ratios.py b/diffusers_helper/models/mag_cache_ratios.py new file mode 100644 index 0000000000000000000000000000000000000000..60e0ce88311e425b5f9bb7d34e1d2b52f6565163 --- /dev/null +++ b/diffusers_helper/models/mag_cache_ratios.py @@ -0,0 +1,71 @@ +import numpy as np + +# Pre-calibrated magnitude ratios for different model families, resolutions, and step counts +# Format: MAG_RATIOS_DB[model_family][resolution][step_count] = np.array([...]) +# All calibrations performed with FramePackStudio v0.4 with default settings and seed 31337 +MAG_RATIOS_DB = { + "Original": { + 768: { + 25: np.array([1.0] + [1.30469, 1.22656, 1.03906, 1.02344, 1.03906, 1.01562, 1.03906, 1.05469, 1.02344, 1.03906, 0.99609, 1.02344, 1.03125, 1.01562, 1.02344, 1.00781, 1.04688, 0.99219, 1.00781, 1.00781, 0.98828, 0.94141, 0.93359, 0.78906]), + 50: np.array([1.0] + [1.30469, 0.99609, 1.16406, 1.0625, 1.01562, 1.02344, 1.01562, 1.0, 1.04688, 0.99219, 1.00781, 1.00781, 1.00781, 1.03125, 1.02344, 1.03125, 1.00781, 1.01562, 0.99219, 1.04688, 1.00781, 0.98828, 1.01562, 1.00781, 1.01562, 1.01562, 1.00781, 1.00781, 1.00781, 1.01562, 1.00781, 1.0, 1.03125, 1.01562, 0.97266, 1.01562, 1.01562, 0.98828, 1.01562, 0.98828, 0.98828, 1.0, 0.98438, 0.95703, 0.95312, 0.98047, 0.91016, 0.85938, 0.78125]), + 75: np.array([1.0] + [1.01562, 1.27344, 1.0, 1.15625, 1.0625, 1.00781, 1.02344, 1.0, 1.02344, 1.0, 1.02344, 1.0, 1.0, 1.04688, 0.99609, 1.0, 1.00781, 1.01562, 1.00781, 1.0, 1.03125, 1.02344, 1.00781, 1.02344, 1.00781, 1.00781, 1.00781, 0.99219, 1.05469, 0.99609, 1.00781, 0.99219, 0.99609, 1.00781, 1.03125, 0.98438, 1.01562, 1.02344, 0.98828, 1.03906, 0.98828, 0.99219, 1.02344, 0.99219, 1.00781, 1.01562, 1.01562, 0.98438, 1.03125, 0.99609, 1.01562, 0.98828, 1.01562, 0.98828, 1.04688, 0.96484, 0.99219, 1.03125, 0.98047, 0.99219, 1.0, 0.99609, 0.98828, 0.98438, 0.98828, 0.97266, 0.99609, 0.96484, 0.97266, 0.94531, 0.91406, 0.90234, 0.85938, 0.76172]), + }, + 640: { + 25: np.array([1.0] + [1.30469, 1.22656, 1.05469, 1.02344, 1.03906, 1.02344, 1.03906, 1.05469, 1.02344, 1.03906, 0.99609, 1.02344, 1.03125, 1.01562, 1.02344, 1.00781, 
1.03906, 0.98828, 1.0, 1.0, 0.98828, 0.94531, 0.93359, 0.78516]), + 50: np.array([1.0] + [1.28906, 1.0, 1.17188, 1.0625, 1.02344, 1.02344, 1.02344, 1.0, 1.04688, 0.99219, 1.00781, 1.01562, 1.00781, 1.03125, 1.02344, 1.03125, 1.00781, 1.01562, 0.99219, 1.04688, 1.00781, 0.98828, 1.01562, 1.00781, 1.01562, 1.01562, 1.00781, 1.00781, 1.00781, 1.01562, 1.00781, 1.0, 1.03125, 1.01562, 0.97266, 1.01562, 1.01562, 0.98828, 1.01562, 0.98828, 0.98828, 1.0, 0.98438, 0.95703, 0.95312, 0.97656, 0.91016, 0.85547, 0.76953]), + 75: np.array([1.0] + [1.00781, 1.30469, 1.0, 1.15625, 1.05469, 1.01562, 1.01562, 1.0, 1.01562, 1.0, 1.02344, 0.99609, 1.0, 1.04688, 0.99219, 1.0, 1.00781, 1.01562, 1.00781, 1.0, 1.03125, 1.02344, 1.00781, 1.02344, 1.00781, 1.00781, 1.00781, 0.99219, 1.05469, 0.99219, 1.00781, 0.99219, 0.99609, 1.00781, 1.03906, 0.98828, 1.01562, 1.02344, 0.98828, 1.03906, 0.98828, 0.99219, 1.02344, 0.99219, 1.00781, 1.01562, 1.01562, 0.98828, 1.03906, 0.99609, 1.01562, 0.98828, 1.01562, 0.98828, 1.04688, 0.96484, 0.99219, 1.03125, 0.98047, 0.99219, 1.00781, 0.99609, 0.98828, 0.98438, 0.98828, 0.97266, 0.99609, 0.96484, 0.97266, 0.94922, 0.91797, 0.90625, 0.86328, 0.75781]), + }, + 512: { + 25: np.array([1.0] + [1.32031, 1.21875, 1.03906, 1.02344, 1.03906, 1.01562, 1.03906, 1.05469, 1.02344, 1.03906, 0.99609, 1.02344, 1.03125, 1.01562, 1.02344, 1.00781, 1.04688, 0.98828, 1.0, 1.00781, 0.98828, 0.94141, 0.9375, 0.78516]), + 50: np.array([1.0] + [1.32031, 0.99609, 1.15625, 1.0625, 1.01562, 1.02344, 1.02344, 1.0, 1.04688, 0.99219, 1.00781, 1.00781, 1.00781, 1.03125, 1.02344, 1.03125, 1.00781, 1.01562, 0.98828, 1.04688, 1.00781, 0.98828, 1.01562, 1.00781, 1.01562, 1.01562, 1.00781, 1.00781, 1.00781, 1.01562, 1.00781, 1.0, 1.03125, 1.01562, 0.97266, 1.01562, 1.01562, 0.98828, 1.01562, 0.98828, 0.98828, 1.0, 0.98438, 0.95703, 0.95312, 0.98047, 0.91016, 0.85938, 0.77734]), + 75: np.array([1.0] + [1.02344, 1.28906, 1.0, 1.15625, 1.0625, 1.01562, 1.01562, 1.0, 1.02344, 1.0, 1.02344, 1.0, 1.0, 1.04688, 0.99609, 1.0, 1.00781, 1.01562, 1.00781, 1.0, 1.03125, 1.02344, 1.00781, 1.02344, 1.00781, 1.00781, 1.00781, 0.99219, 1.05469, 0.99609, 1.00781, 0.99219, 0.99609, 1.00781, 1.03125, 0.98828, 1.01562, 1.02344, 0.99219, 1.03906, 0.98828, 0.99219, 1.02344, 0.99219, 1.00781, 1.01562, 1.01562, 0.98828, 1.03125, 0.99609, 1.01562, 0.98828, 1.01562, 0.98828, 1.04688, 0.96484, 0.99219, 1.03125, 0.98047, 0.99219, 1.00781, 0.99609, 0.98828, 0.98438, 0.98828, 0.97266, 1.0, 0.96484, 0.97266, 0.94922, 0.91797, 0.90625, 0.86328, 0.75781]), + }, + 384: { + 25: np.array([1.0] + [1.58594, 1.0625, 1.03906, 1.02344, 1.0625, 1.04688, 1.04688, 1.03125, 1.02344, 1.01562, 1.00781, 1.01562, 1.01562, 0.99219, 1.07031, 0.96094, 0.96484, 1.03125, 0.96875, 0.94141, 0.97266, 0.92188, 0.88672, 0.75]), + 50: np.array([1.0] + [1.29688, 1.21875, 1.02344, 1.03906, 1.0, 1.03906, 1.00781, 1.02344, 1.05469, 1.00781, 1.03125, 1.01562, 1.04688, 1.00781, 0.98828, 1.03906, 0.99609, 1.03125, 1.03125, 0.98438, 1.01562, 0.99609, 1.01562, 1.0, 1.01562, 1.0, 1.01562, 0.98047, 1.02344, 1.04688, 0.97266, 0.98828, 0.97656, 0.98828, 1.05469, 0.97656, 0.98828, 0.98047, 0.98438, 0.95703, 1.00781, 0.96484, 0.97656, 0.94531, 0.94141, 0.94531, 0.875, 0.85547, 0.79688]), + 75: np.array([1.0] + [1.29688, 1.14844, 1.07031, 1.01562, 1.02344, 1.02344, 0.99609, 1.04688, 0.99219, 1.00781, 1.00781, 1.00781, 1.03125, 1.02344, 1.00781, 1.03125, 1.00781, 1.00781, 1.04688, 0.99219, 1.00781, 0.99219, 1.00781, 1.03125, 0.98438, 1.01562, 1.02344, 0.98828, 1.01562, 1.01562, 1.0, 
1.0, 1.00781, 0.99219, 1.01562, 1.0, 0.99609, 1.03125, 0.98828, 1.02344, 0.99609, 0.97656, 1.01562, 1.00781, 1.03906, 0.96484, 1.03125, 0.96484, 1.02344, 0.98438, 0.96094, 1.03125, 0.98828, 1.01562, 0.97266, 1.02344, 0.97656, 1.0, 0.98438, 0.95703, 1.02344, 0.96094, 0.99609, 0.99609, 0.9375, 0.98438, 0.94141, 0.97266, 0.96875, 0.89844, 0.95703, 0.87109, 0.86328, 0.85547]), + }, + 256: { + 25: np.array([1.0] + [1.59375, 1.10156, 1.08594, 1.05469, 1.03906, 1.03125, 1.03125, 1.02344, 1.01562, 1.02344, 0.98438, 1.0625, 0.96875, 1.00781, 0.98438, 1.00781, 0.92969, 0.97656, 0.99609, 0.91406, 0.94922, 0.88672, 0.86328, 0.75391]), + 50: np.array([1.0] + [1.46875, 1.10156, 1.04688, 1.03906, 1.02344, 1.0625, 1.03125, 1.02344, 1.03906, 1.0, 1.01562, 1.01562, 1.03125, 0.99609, 1.01562, 1.00781, 0.99609, 1.02344, 1.01562, 1.00781, 1.00781, 0.98047, 1.02344, 1.04688, 0.97266, 0.99609, 1.0, 1.00781, 0.98047, 1.00781, 0.98047, 1.02344, 0.96094, 0.96875, 1.03125, 0.94531, 0.98047, 1.01562, 0.96484, 0.94531, 0.99609, 0.95312, 0.96484, 0.91406, 0.92969, 0.92969, 0.88672, 0.85156, 0.89062]), + 75: np.array([1.0] + [1.25781, 1.23438, 1.04688, 1.03906, 1.0, 1.03906, 1.02344, 1.00781, 1.05469, 1.00781, 1.03125, 1.01562, 1.04688, 0.99219, 1.00781, 1.0, 1.03906, 0.99609, 1.03125, 0.98828, 1.00781, 1.00781, 1.02344, 0.99219, 1.00781, 1.00781, 1.0, 0.99609, 1.03125, 0.99609, 1.01562, 0.99609, 0.97656, 1.03906, 0.98438, 1.03906, 0.96484, 1.01562, 0.98438, 1.02344, 0.95312, 1.03906, 0.98047, 1.02344, 0.98047, 1.0, 0.97656, 1.03906, 0.94922, 1.01562, 0.97266, 1.01562, 0.98828, 0.97266, 1.01562, 0.97656, 1.00781, 0.9375, 1.0, 0.97266, 1.00781, 0.96875, 0.97266, 0.96875, 0.93359, 0.98047, 0.95703, 0.94922, 0.94922, 0.92578, 0.94531, 0.85938, 0.91797, 0.94531]), + }, + 128: { + 25: np.array([1.0] + [1.63281, 1.0625, 1.14062, 1.04688, 1.03906, 1.03125, 1.03125, 1.02344, 0.99219, 1.03125, 0.96484, 1.02344, 0.97266, 0.98438, 0.97656, 0.96875, 0.95312, 0.95312, 0.97656, 0.92188, 0.9375, 0.87891, 0.85156, 0.82812]), + 50: np.array([1.0] + [1.5625, 1.05469, 1.03906, 1.03125, 1.07031, 1.03906, 1.05469, 0.99609, 1.03906, 1.0, 1.05469, 0.97656, 1.01562, 1.01562, 1.0, 1.03125, 1.01562, 0.97656, 0.99609, 1.03906, 0.98828, 0.98047, 1.03125, 0.99609, 0.97266, 1.0, 0.99609, 0.98438, 0.97266, 1.00781, 0.98828, 0.98438, 0.97656, 0.98047, 0.99609, 0.95703, 1.00781, 0.96484, 0.99219, 0.92969, 0.98828, 0.94922, 0.94531, 0.92969, 0.92578, 0.92188, 0.91797, 0.90625, 1.00781]), + 75: np.array([1.0] + [1.47656, 1.08594, 1.02344, 1.0625, 1.0, 1.02344, 1.0625, 1.03906, 1.01562, 1.0, 1.04688, 1.0, 1.01562, 1.00781, 1.01562, 1.01562, 1.01562, 1.00781, 1.00781, 1.00781, 1.01562, 1.01562, 0.98438, 1.03125, 0.99609, 1.01562, 0.98828, 1.01562, 1.00781, 1.01562, 0.97656, 1.00781, 0.98438, 1.03906, 0.97656, 1.01562, 0.94531, 1.03125, 1.00781, 0.98438, 1.02344, 0.97656, 1.00781, 0.95703, 1.01562, 0.97656, 1.02344, 0.97266, 0.96484, 1.02344, 0.96094, 0.99609, 0.99609, 0.96094, 1.0, 1.00781, 0.97266, 0.98828, 0.96875, 0.96484, 0.98828, 0.95703, 0.99219, 0.97266, 0.89844, 1.0, 0.96094, 0.92578, 0.95703, 0.9375, 0.91016, 0.97266, 0.96875, 1.07812]), + }, + }, + "F1": { + 768: { + 25: np.array([1.0] + [1.27344, 1.08594, 1.03125, 1.00781, 1.00781, 1.00781, 1.03125, 1.03906, 1.00781, 1.03125, 0.98828, 1.01562, 1.00781, 1.01562, 1.00781, 0.98438, 1.04688, 0.98438, 0.96875, 1.03125, 0.97266, 0.92188, 0.95703, 0.77734]), + 50: np.array([1.0] + [1.27344, 1.0, 1.07031, 1.01562, 1.0, 1.02344, 1.0, 1.01562, 1.02344, 0.98828, 1.00781, 1.00781, 1.0, 1.02344, 1.00781, 
1.03125, 1.0, 1.00781, 0.97656, 1.05469, 1.00781, 0.98047, 1.01562, 1.0, 1.00781, 1.00781, 0.99609, 1.02344, 1.00781, 0.99609, 1.00781, 0.98047, 1.04688, 1.00781, 0.95312, 1.03125, 0.99609, 0.97266, 1.02344, 1.00781, 0.95312, 1.02344, 0.98828, 0.93359, 0.97656, 0.97656, 0.92188, 0.84375, 0.76562]), + 75: np.array([1.0] + [1.0, 1.26562, 1.00781, 1.07812, 1.0, 1.00781, 1.01562, 1.0, 1.00781, 1.0, 1.00781, 1.00781, 1.0, 1.02344, 0.98438, 1.0, 1.00781, 1.00781, 1.00781, 1.0, 1.03125, 1.01562, 1.00781, 1.02344, 1.0, 0.99219, 1.01562, 0.98047, 1.05469, 1.0, 1.01562, 0.99219, 0.98828, 1.01562, 1.03125, 0.97266, 1.00781, 1.03125, 0.97266, 1.04688, 0.97656, 0.99609, 1.03125, 0.98047, 0.99609, 1.02344, 1.0, 0.96875, 1.0625, 0.98828, 1.00781, 0.99609, 1.0, 0.98828, 1.04688, 0.94922, 0.97266, 1.0625, 0.96094, 1.00781, 0.96875, 1.01562, 0.98828, 0.99609, 0.95703, 0.96875, 1.02344, 0.96875, 0.96484, 0.95312, 0.89844, 0.90234, 0.86719, 0.76562]), + }, + 640: { + 25: np.array([1.0] + [1.27344, 1.07031, 1.03906, 1.00781, 1.00781, 1.00781, 1.03125, 1.04688, 1.00781, 1.03125, 0.99219, 1.01562, 1.01562, 1.01562, 1.00781, 0.98438, 1.05469, 0.98438, 0.96875, 1.03125, 0.97266, 0.92578, 0.95703, 0.77734]), + 50: np.array([1.0] + [1.27344, 1.0, 1.07812, 1.01562, 1.0, 1.01562, 1.00781, 1.00781, 1.02344, 0.98828, 1.00781, 1.00781, 1.00781, 1.02344, 1.01562, 1.03125, 1.0, 1.00781, 0.98047, 1.05469, 1.00781, 0.98047, 1.01562, 1.0, 1.00781, 1.00781, 0.99609, 1.02344, 1.01562, 0.99609, 1.00781, 0.98047, 1.04688, 1.00781, 0.95312, 1.03125, 0.99609, 0.97266, 1.02344, 1.00781, 0.95312, 1.02344, 0.98828, 0.93359, 0.97656, 0.97656, 0.92188, 0.84375, 0.76953]), + 75: np.array([1.0] + [1.0, 1.27344, 1.0, 1.07031, 1.01562, 0.99609, 1.00781, 1.0, 1.00781, 1.0, 1.00781, 1.00781, 1.0, 1.02344, 0.98438, 1.0, 1.00781, 1.00781, 1.00781, 1.0, 1.02344, 1.01562, 1.00781, 1.02344, 1.0, 0.99219, 1.01562, 0.98047, 1.05469, 1.0, 1.01562, 0.99219, 0.98438, 1.01562, 1.03125, 0.96875, 1.00781, 1.03125, 0.97266, 1.04688, 0.97656, 0.99609, 1.03125, 0.98047, 0.99609, 1.02344, 1.0, 0.96875, 1.0625, 0.98828, 1.00781, 1.0, 1.0, 0.98828, 1.04688, 0.94922, 0.97266, 1.0625, 0.96094, 1.00781, 0.96875, 1.01562, 0.98828, 0.99609, 0.95703, 0.96875, 1.02344, 0.96484, 0.96484, 0.95312, 0.89844, 0.90234, 0.87109, 0.76953]), + }, + 512: { + 25: np.array([1.0] + [1.28125, 1.08594, 1.02344, 1.01562, 1.00781, 1.00781, 1.03125, 1.03906, 1.00781, 1.03125, 0.98828, 1.01562, 1.00781, 1.01562, 1.00781, 0.98438, 1.04688, 0.98438, 0.96875, 1.03125, 0.97656, 0.92188, 0.96094, 0.77734]), + 50: np.array([1.0] + [1.28125, 1.00781, 1.08594, 1.0, 1.01562, 1.01562, 1.00781, 1.00781, 1.02344, 0.98438, 1.00781, 1.00781, 1.00781, 1.02344, 1.01562, 1.03125, 1.0, 1.00781, 0.97656, 1.05469, 1.00781, 0.98047, 1.01562, 1.0, 1.00781, 1.00781, 0.99609, 1.02344, 1.00781, 0.99609, 1.00781, 0.98047, 1.04688, 1.00781, 0.95312, 1.03125, 0.99609, 0.97266, 1.02344, 1.00781, 0.94922, 1.02344, 0.98828, 0.93359, 0.97656, 0.97656, 0.92188, 0.83984, 0.76953]), + 75: np.array([1.0] + [1.00781, 1.27344, 1.00781, 1.07812, 1.0, 1.00781, 1.00781, 0.99609, 1.01562, 0.99609, 1.00781, 1.00781, 1.0, 1.02344, 0.98438, 1.0, 1.00781, 1.00781, 1.00781, 1.0, 1.02344, 1.01562, 1.00781, 1.02344, 1.0, 0.98828, 1.01562, 0.98047, 1.05469, 1.0, 1.01562, 0.99219, 0.98438, 1.00781, 1.03125, 0.96875, 1.00781, 1.03125, 0.97266, 1.04688, 0.97656, 0.99609, 1.03125, 0.98047, 0.99609, 1.02344, 1.0, 0.96484, 1.0625, 0.98828, 1.0, 0.99609, 1.0, 0.98828, 1.04688, 0.94922, 0.97266, 1.0625, 0.95703, 1.00781, 0.96484, 1.02344, 
0.98828, 0.99609, 0.95703, 0.96875, 1.03125, 0.96875, 0.96484, 0.95703, 0.89844, 0.90234, 0.87109, 0.76953]), + }, + 384: { + 25: np.array([1.0] + [1.36719, 1.03125, 1.02344, 1.01562, 1.04688, 1.03125, 1.04688, 1.02344, 1.00781, 0.99609, 1.01562, 0.99219, 1.00781, 0.97266, 1.07812, 0.95703, 0.9375, 1.04688, 0.98828, 0.89844, 1.00781, 0.92188, 0.89844, 0.72656]), + 50: np.array([1.0] + [1.27344, 1.08594, 1.01562, 1.01562, 1.00781, 1.00781, 1.00781, 1.00781, 1.03906, 1.00781, 1.02344, 1.00781, 1.03125, 1.01562, 0.98047, 1.04688, 0.98438, 1.02344, 1.02344, 0.97656, 1.03125, 0.98828, 1.00781, 0.98828, 1.00781, 1.0, 0.99609, 0.97656, 1.03125, 1.04688, 0.97656, 0.98047, 0.97266, 0.96094, 1.09375, 0.95703, 0.98438, 1.0, 0.96094, 0.93359, 1.03125, 0.97266, 1.0, 0.92188, 0.93359, 0.96484, 0.85156, 0.84375, 0.79688]), + 75: np.array([1.0] + [1.26562, 1.08594, 1.00781, 1.00781, 1.01562, 1.00781, 1.00781, 1.03125, 0.98438, 1.00781, 1.00781, 1.00781, 1.02344, 1.01562, 1.00781, 1.02344, 0.98828, 1.01562, 1.03125, 1.0, 1.01562, 0.98047, 1.01562, 1.03125, 0.96875, 1.00781, 1.03125, 0.97266, 1.0, 1.02344, 1.00781, 1.0, 1.00781, 0.97266, 1.01562, 1.00781, 0.97656, 1.0625, 0.97656, 1.02344, 0.98828, 0.96484, 1.02344, 1.01562, 1.04688, 0.95312, 1.03125, 0.98047, 1.01562, 0.97266, 0.94922, 1.0625, 0.96484, 1.02344, 0.98438, 1.02344, 0.98047, 0.97266, 0.99219, 0.92969, 1.07031, 0.96094, 0.98047, 0.98438, 0.94531, 0.98828, 0.9375, 0.97266, 0.98828, 0.86719, 0.98047, 0.84766, 0.86328, 0.86719]), + }, + 256: { + 25: np.array([1.0] + [1.38281, 1.04688, 1.05469, 1.03906, 1.03906, 1.01562, 1.0, 1.03906, 1.00781, 1.03125, 0.96094, 1.08594, 0.96094, 1.00781, 0.98438, 1.02344, 0.91016, 0.99609, 1.0, 0.90234, 0.97266, 0.87109, 0.85547, 0.71875]), + 50: np.array([1.0] + [1.375, 1.02344, 1.02344, 1.01562, 1.00781, 1.04688, 1.03125, 1.00781, 1.03125, 1.00781, 1.0, 1.01562, 1.01562, 0.98438, 1.03125, 1.00781, 0.98438, 1.02344, 1.01562, 1.02344, 0.98047, 0.97656, 1.03125, 1.04688, 0.97656, 0.98047, 0.99219, 1.01562, 0.99609, 0.98828, 0.99219, 1.03125, 0.96875, 0.94141, 1.05469, 0.94531, 0.95312, 1.04688, 0.94141, 0.96094, 1.00781, 0.96094, 0.95312, 0.91016, 0.91797, 0.93359, 0.89062, 0.8125, 0.83984]), + 75: np.array([1.0] + [1.27344, 1.09375, 1.00781, 1.01562, 1.00781, 1.00781, 1.00781, 1.00781, 1.03906, 1.00781, 1.02344, 1.00781, 1.03125, 1.0, 1.00781, 1.0, 1.03125, 0.98828, 1.02344, 0.97266, 1.0, 1.02344, 1.03125, 0.98047, 0.99609, 1.02344, 0.99219, 0.98047, 1.0625, 0.99219, 1.00781, 0.98828, 0.96875, 1.0625, 0.98047, 1.04688, 0.95312, 1.03125, 0.98047, 1.03125, 0.94922, 1.03125, 0.99609, 1.03125, 0.95703, 1.0, 0.98438, 1.03906, 0.9375, 1.01562, 0.96094, 1.03125, 0.98828, 0.98047, 1.00781, 0.99219, 1.0, 0.91016, 1.01562, 0.97266, 1.00781, 0.98047, 0.98438, 0.97266, 0.89062, 1.00781, 0.95703, 0.95312, 0.94141, 0.9375, 0.93359, 0.82031, 0.91016, 0.87891]), + }, + 128: { + 25: np.array([1.0] + [1.42188, 1.03125, 1.0625, 1.04688, 1.02344, 1.03125, 1.03125, 1.02344, 0.99219, 1.02344, 0.94531, 1.04688, 0.94922, 1.01562, 0.98047, 0.96484, 0.93359, 0.96484, 0.98438, 0.91406, 0.97266, 0.87891, 0.85547, 0.83203]), + 50: np.array([1.0] + [1.375, 1.03125, 1.02344, 1.01562, 1.04688, 1.01562, 1.04688, 1.0, 1.04688, 0.98438, 1.05469, 0.97266, 1.01562, 1.01562, 0.98828, 1.03906, 1.01562, 0.97656, 0.98047, 1.03906, 0.97266, 0.97266, 1.0625, 0.99219, 0.97656, 0.97266, 1.02344, 0.99609, 0.94141, 1.03906, 0.97266, 0.99219, 0.96875, 0.96484, 1.00781, 0.95312, 1.03906, 0.94922, 0.99609, 0.91797, 1.00781, 0.96484, 0.9375, 0.93359, 0.91797, 
0.92969, 0.91406, 0.90625, 0.92188]), + 75: np.array([1.0] + [1.36719, 1.02344, 1.01562, 1.03906, 0.99219, 1.00781, 1.03906, 1.03125, 0.99219, 0.99609, 1.05469, 0.99609, 1.01562, 1.00781, 1.00781, 1.00781, 1.0, 1.02344, 1.01562, 1.01562, 1.00781, 1.0, 0.96875, 1.05469, 0.99219, 1.00781, 1.0, 1.0, 1.01562, 1.00781, 0.96875, 1.02344, 0.94922, 1.07812, 0.94922, 1.03125, 0.92578, 1.05469, 0.97266, 0.98438, 1.04688, 0.98438, 1.00781, 0.95312, 1.02344, 0.94922, 1.04688, 0.96875, 0.96484, 1.03906, 0.93359, 1.00781, 1.00781, 0.95312, 1.00781, 1.0, 0.97656, 1.0, 0.94922, 0.96094, 1.02344, 0.92969, 1.02344, 0.96094, 0.88281, 1.03125, 0.94141, 0.91797, 0.98438, 0.92578, 0.90234, 0.99219, 0.92188, 0.98438]), + }, + } +} \ No newline at end of file diff --git a/diffusers_helper/pipelines/k_diffusion_hunyuan.py b/diffusers_helper/pipelines/k_diffusion_hunyuan.py new file mode 100644 index 0000000000000000000000000000000000000000..d72b44b859c0042af1e227612edd76fa85880548 --- /dev/null +++ b/diffusers_helper/pipelines/k_diffusion_hunyuan.py @@ -0,0 +1,120 @@ +import torch +import math + +from diffusers_helper.k_diffusion.uni_pc_fm import sample_unipc +from diffusers_helper.k_diffusion.wrapper import fm_wrapper +from diffusers_helper.utils import repeat_to_batch_size + + +def flux_time_shift(t, mu=1.15, sigma=1.0): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + +def calculate_flux_mu(context_length, x1=256, y1=0.5, x2=4096, y2=1.15, exp_max=7.0): + k = (y2 - y1) / (x2 - x1) + b = y1 - k * x1 + mu = k * context_length + b + mu = min(mu, math.log(exp_max)) + return mu + + +def get_flux_sigmas_from_mu(n, mu): + sigmas = torch.linspace(1, 0, steps=n + 1) + sigmas = flux_time_shift(sigmas, mu=mu) + return sigmas + + +@torch.inference_mode() +def sample_hunyuan( + transformer, + sampler='unipc', + initial_latent=None, + concat_latent=None, + strength=1.0, + width=512, + height=512, + frames=16, + real_guidance_scale=1.0, + distilled_guidance_scale=6.0, + guidance_rescale=0.0, + shift=None, + num_inference_steps=25, + batch_size=None, + generator=None, + prompt_embeds=None, + prompt_embeds_mask=None, + prompt_poolers=None, + negative_prompt_embeds=None, + negative_prompt_embeds_mask=None, + negative_prompt_poolers=None, + dtype=torch.bfloat16, + device=None, + negative_kwargs=None, + callback=None, + **kwargs, +): + device = device or transformer.device + + if batch_size is None: + batch_size = int(prompt_embeds.shape[0]) + + latents = torch.randn((batch_size, 16, (frames + 3) // 4, height // 8, width // 8), generator=generator, device=generator.device).to(device=device, dtype=torch.float32) + + B, C, T, H, W = latents.shape + seq_length = T * H * W // 4 + + if shift is None: + mu = calculate_flux_mu(seq_length, exp_max=7.0) + else: + mu = math.log(shift) + + sigmas = get_flux_sigmas_from_mu(num_inference_steps, mu).to(device) + + k_model = fm_wrapper(transformer) + + if initial_latent is not None: + sigmas = sigmas * strength + first_sigma = sigmas[0].to(device=device, dtype=torch.float32) + initial_latent = initial_latent.to(device=device, dtype=torch.float32) + latents = initial_latent.float() * (1.0 - first_sigma) + latents.float() * first_sigma + + if concat_latent is not None: + concat_latent = concat_latent.to(latents) + + distilled_guidance = torch.tensor([distilled_guidance_scale * 1000.0] * batch_size).to(device=device, dtype=dtype) + + prompt_embeds = repeat_to_batch_size(prompt_embeds, batch_size) + prompt_embeds_mask = repeat_to_batch_size(prompt_embeds_mask, batch_size) + 
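As a usage note on the schedule helpers defined at the top of this file: for a 512x512, 16-frame call, the latent grid is 4 x 64 x 64, the token count after 2x2 spatial patching is 4096, and calculate_flux_mu lands exactly on its y2=1.15 anchor, which flux_time_shift then uses to bend the linear sigma ramp toward the high-noise end. A quick sketch (editor's illustration; the import path is the module introduced by this diff):

import torch
from diffusers_helper.pipelines.k_diffusion_hunyuan import calculate_flux_mu, get_flux_sigmas_from_mu

frames, height, width, steps = 16, 512, 512, 25
T, H, W = (frames + 3) // 4, height // 8, width // 8   # latent grid: 4 x 64 x 64
seq_length = T * H * W // 4                             # 4096 tokens after 2x2 spatial patching
mu = calculate_flux_mu(seq_length)                      # ~1.15, capped at log(7) for very long contexts
sigmas = get_flux_sigmas_from_mu(steps, mu)             # 26 values, from sigma=1.0 down to sigma=0.0
print(round(mu, 4), float(sigmas[0]), float(sigmas[-1]))  # 1.15 1.0 0.0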
prompt_poolers = repeat_to_batch_size(prompt_poolers, batch_size) + negative_prompt_embeds = repeat_to_batch_size(negative_prompt_embeds, batch_size) + negative_prompt_embeds_mask = repeat_to_batch_size(negative_prompt_embeds_mask, batch_size) + negative_prompt_poolers = repeat_to_batch_size(negative_prompt_poolers, batch_size) + concat_latent = repeat_to_batch_size(concat_latent, batch_size) + + sampler_kwargs = dict( + dtype=dtype, + cfg_scale=real_guidance_scale, + cfg_rescale=guidance_rescale, + concat_latent=concat_latent, + positive=dict( + pooled_projections=prompt_poolers, + encoder_hidden_states=prompt_embeds, + encoder_attention_mask=prompt_embeds_mask, + guidance=distilled_guidance, + **kwargs, + ), + negative=dict( + pooled_projections=negative_prompt_poolers, + encoder_hidden_states=negative_prompt_embeds, + encoder_attention_mask=negative_prompt_embeds_mask, + guidance=distilled_guidance, + **(kwargs if negative_kwargs is None else {**kwargs, **negative_kwargs}), + ) + ) + + if sampler == 'unipc': + results = sample_unipc(k_model, latents, sigmas, extra_args=sampler_kwargs, disable=False, callback=callback) + else: + raise NotImplementedError(f'Sampler {sampler} is not supported.') + + return results diff --git a/diffusers_helper/thread_utils.py b/diffusers_helper/thread_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..144fdad6a218b10e77944e927ea350bb84b559a1 --- /dev/null +++ b/diffusers_helper/thread_utils.py @@ -0,0 +1,76 @@ +import time + +from threading import Thread, Lock + + +class Listener: + task_queue = [] + lock = Lock() + thread = None + + @classmethod + def _process_tasks(cls): + while True: + task = None + with cls.lock: + if cls.task_queue: + task = cls.task_queue.pop(0) + + if task is None: + time.sleep(0.001) + continue + + func, args, kwargs = task + try: + func(*args, **kwargs) + except Exception as e: + print(f"Error in listener thread: {e}") + + @classmethod + def add_task(cls, func, *args, **kwargs): + with cls.lock: + cls.task_queue.append((func, args, kwargs)) + + if cls.thread is None: + cls.thread = Thread(target=cls._process_tasks, daemon=True) + cls.thread.start() + + +def async_run(func, *args, **kwargs): + Listener.add_task(func, *args, **kwargs) + + +class FIFOQueue: + def __init__(self): + self.queue = [] + self.lock = Lock() + + def push(self, item): + with self.lock: + self.queue.append(item) + + def pop(self): + with self.lock: + if self.queue: + return self.queue.pop(0) + return None + + def top(self): + with self.lock: + if self.queue: + return self.queue[0] + return None + + def next(self): + while True: + with self.lock: + if self.queue: + return self.queue.pop(0) + + time.sleep(0.001) + + +class AsyncStream: + def __init__(self): + self.input_queue = FIFOQueue() + self.output_queue = FIFOQueue() diff --git a/diffusers_helper/utils.py b/diffusers_helper/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8cd7a0c5f8f04960476e893321c52318ea079e14 --- /dev/null +++ b/diffusers_helper/utils.py @@ -0,0 +1,613 @@ +import os +import cv2 +import json +import random +import glob +import torch +import einops +import numpy as np +import datetime +import torchvision + +import safetensors.torch as sf +from PIL import Image + + +def min_resize(x, m): + if x.shape[0] < x.shape[1]: + s0 = m + s1 = int(float(m) / float(x.shape[0]) * float(x.shape[1])) + else: + s0 = int(float(m) / float(x.shape[1]) * float(x.shape[0])) + s1 = m + new_max = max(s1, s0) + raw_max = max(x.shape[0], x.shape[1]) + if 
new_max < raw_max: + interpolation = cv2.INTER_AREA + else: + interpolation = cv2.INTER_LANCZOS4 + y = cv2.resize(x, (s1, s0), interpolation=interpolation) + return y + + +def d_resize(x, y): + H, W, C = y.shape + new_min = min(H, W) + raw_min = min(x.shape[0], x.shape[1]) + if new_min < raw_min: + interpolation = cv2.INTER_AREA + else: + interpolation = cv2.INTER_LANCZOS4 + y = cv2.resize(x, (W, H), interpolation=interpolation) + return y + + +def resize_and_center_crop(image, target_width, target_height): + if target_height == image.shape[0] and target_width == image.shape[1]: + return image + + pil_image = Image.fromarray(image) + original_width, original_height = pil_image.size + scale_factor = max(target_width / original_width, target_height / original_height) + resized_width = int(round(original_width * scale_factor)) + resized_height = int(round(original_height * scale_factor)) + resized_image = pil_image.resize((resized_width, resized_height), Image.LANCZOS) + left = (resized_width - target_width) / 2 + top = (resized_height - target_height) / 2 + right = (resized_width + target_width) / 2 + bottom = (resized_height + target_height) / 2 + cropped_image = resized_image.crop((left, top, right, bottom)) + return np.array(cropped_image) + + +def resize_and_center_crop_pytorch(image, target_width, target_height): + B, C, H, W = image.shape + + if H == target_height and W == target_width: + return image + + scale_factor = max(target_width / W, target_height / H) + resized_width = int(round(W * scale_factor)) + resized_height = int(round(H * scale_factor)) + + resized = torch.nn.functional.interpolate(image, size=(resized_height, resized_width), mode='bilinear', align_corners=False) + + top = (resized_height - target_height) // 2 + left = (resized_width - target_width) // 2 + cropped = resized[:, :, top:top + target_height, left:left + target_width] + + return cropped + + +def resize_without_crop(image, target_width, target_height): + if target_height == image.shape[0] and target_width == image.shape[1]: + return image + + pil_image = Image.fromarray(image) + resized_image = pil_image.resize((target_width, target_height), Image.LANCZOS) + return np.array(resized_image) + + +def just_crop(image, w, h): + if h == image.shape[0] and w == image.shape[1]: + return image + + original_height, original_width = image.shape[:2] + k = min(original_height / h, original_width / w) + new_width = int(round(w * k)) + new_height = int(round(h * k)) + x_start = (original_width - new_width) // 2 + y_start = (original_height - new_height) // 2 + cropped_image = image[y_start:y_start + new_height, x_start:x_start + new_width] + return cropped_image + + +def write_to_json(data, file_path): + temp_file_path = file_path + ".tmp" + with open(temp_file_path, 'wt', encoding='utf-8') as temp_file: + json.dump(data, temp_file, indent=4) + os.replace(temp_file_path, file_path) + return + + +def read_from_json(file_path): + with open(file_path, 'rt', encoding='utf-8') as file: + data = json.load(file) + return data + + +def get_active_parameters(m): + return {k: v for k, v in m.named_parameters() if v.requires_grad} + + +def cast_training_params(m, dtype=torch.float32): + result = {} + for n, param in m.named_parameters(): + if param.requires_grad: + param.data = param.to(dtype) + result[n] = param + return result + + +def separate_lora_AB(parameters, B_patterns=None): + parameters_normal = {} + parameters_B = {} + + if B_patterns is None: + B_patterns = ['.lora_B.', '__zero__'] + + for k, v in parameters.items(): + if 
any(B_pattern in k for B_pattern in B_patterns): + parameters_B[k] = v + else: + parameters_normal[k] = v + + return parameters_normal, parameters_B + + +def set_attr_recursive(obj, attr, value): + attrs = attr.split(".") + for name in attrs[:-1]: + obj = getattr(obj, name) + setattr(obj, attrs[-1], value) + return + + +def print_tensor_list_size(tensors): + total_size = 0 + total_elements = 0 + + if isinstance(tensors, dict): + tensors = tensors.values() + + for tensor in tensors: + total_size += tensor.nelement() * tensor.element_size() + total_elements += tensor.nelement() + + total_size_MB = total_size / (1024 ** 2) + total_elements_B = total_elements / 1e9 + + print(f"Total number of tensors: {len(tensors)}") + print(f"Total size of tensors: {total_size_MB:.2f} MB") + print(f"Total number of parameters: {total_elements_B:.3f} billion") + return + + +@torch.no_grad() +def batch_mixture(a, b=None, probability_a=0.5, mask_a=None): + batch_size = a.size(0) + + if b is None: + b = torch.zeros_like(a) + + if mask_a is None: + mask_a = torch.rand(batch_size) < probability_a + + mask_a = mask_a.to(a.device) + mask_a = mask_a.reshape((batch_size,) + (1,) * (a.dim() - 1)) + result = torch.where(mask_a, a, b) + return result + + +@torch.no_grad() +def zero_module(module): + for p in module.parameters(): + p.detach().zero_() + return module + + +@torch.no_grad() +def supress_lower_channels(m, k, alpha=0.01): + data = m.weight.data.clone() + + assert int(data.shape[1]) >= k + + data[:, :k] = data[:, :k] * alpha + m.weight.data = data.contiguous().clone() + return m + + +def freeze_module(m): + if not hasattr(m, '_forward_inside_frozen_module'): + m._forward_inside_frozen_module = m.forward + m.requires_grad_(False) + m.forward = torch.no_grad()(m.forward) + return m + + +def get_latest_safetensors(folder_path): + safetensors_files = glob.glob(os.path.join(folder_path, '*.safetensors')) + + if not safetensors_files: + raise ValueError('No file to resume!') + + latest_file = max(safetensors_files, key=os.path.getmtime) + latest_file = os.path.abspath(os.path.realpath(latest_file)) + return latest_file + + +def generate_random_prompt_from_tags(tags_str, min_length=3, max_length=32): + tags = tags_str.split(', ') + tags = random.sample(tags, k=min(random.randint(min_length, max_length), len(tags))) + prompt = ', '.join(tags) + return prompt + + +def interpolate_numbers(a, b, n, round_to_int=False, gamma=1.0): + numbers = a + (b - a) * (np.linspace(0, 1, n) ** gamma) + if round_to_int: + numbers = np.round(numbers).astype(int) + return numbers.tolist() + + +def uniform_random_by_intervals(inclusive, exclusive, n, round_to_int=False): + edges = np.linspace(0, 1, n + 1) + points = np.random.uniform(edges[:-1], edges[1:]) + numbers = inclusive + (exclusive - inclusive) * points + if round_to_int: + numbers = np.round(numbers).astype(int) + return numbers.tolist() + + +def soft_append_bcthw(history, current, overlap=0): + if overlap <= 0: + return torch.cat([history, current], dim=2) + + assert history.shape[2] >= overlap, f"History length ({history.shape[2]}) must be >= overlap ({overlap})" + assert current.shape[2] >= overlap, f"Current length ({current.shape[2]}) must be >= overlap ({overlap})" + + weights = torch.linspace(1, 0, overlap, dtype=history.dtype, device=history.device).view(1, 1, -1, 1, 1) + blended = weights * history[:, :, -overlap:] + (1 - weights) * current[:, :, :overlap] + output = torch.cat([history[:, :, :-overlap], blended, current[:, :, overlap:]], dim=2) + + return 
output.to(history) + + +def save_bcthw_as_mp4(x, output_filename, fps=10, crf=0): + b, c, t, h, w = x.shape + + per_row = b + for p in [6, 5, 4, 3, 2]: + if b % p == 0: + per_row = p + break + + os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True) + x = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5 + x = x.detach().cpu().to(torch.uint8) + x = einops.rearrange(x, '(m n) c t h w -> t (m h) (n w) c', n=per_row) + torchvision.io.write_video(output_filename, x, fps=fps, video_codec='libx264', options={'crf': str(int(crf))}) + return x + + +def save_bcthw_as_png(x, output_filename): + os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True) + x = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5 + x = x.detach().cpu().to(torch.uint8) + x = einops.rearrange(x, 'b c t h w -> c (b h) (t w)') + torchvision.io.write_png(x, output_filename) + return output_filename + + +def save_bchw_as_png(x, output_filename): + os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True) + x = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5 + x = x.detach().cpu().to(torch.uint8) + x = einops.rearrange(x, 'b c h w -> c h (b w)') + torchvision.io.write_png(x, output_filename) + return output_filename + + +def add_tensors_with_padding(tensor1, tensor2): + if tensor1.shape == tensor2.shape: + return tensor1 + tensor2 + + shape1 = tensor1.shape + shape2 = tensor2.shape + + new_shape = tuple(max(s1, s2) for s1, s2 in zip(shape1, shape2)) + + padded_tensor1 = torch.zeros(new_shape) + padded_tensor2 = torch.zeros(new_shape) + + padded_tensor1[tuple(slice(0, s) for s in shape1)] = tensor1 + padded_tensor2[tuple(slice(0, s) for s in shape2)] = tensor2 + + result = padded_tensor1 + padded_tensor2 + return result + + +def print_free_mem(): + torch.cuda.empty_cache() + free_mem, total_mem = torch.cuda.mem_get_info(0) + free_mem_mb = free_mem / (1024 ** 2) + total_mem_mb = total_mem / (1024 ** 2) + print(f"Free memory: {free_mem_mb:.2f} MB") + print(f"Total memory: {total_mem_mb:.2f} MB") + return + + +def print_gpu_parameters(device, state_dict, log_count=1): + summary = {"device": device, "keys_count": len(state_dict)} + + logged_params = {} + for i, (key, tensor) in enumerate(state_dict.items()): + if i >= log_count: + break + logged_params[key] = tensor.flatten()[:3].tolist() + + summary["params"] = logged_params + + print(str(summary)) + return + + +def visualize_txt_as_img(width, height, text, font_path='font/DejaVuSans.ttf', size=18): + from PIL import Image, ImageDraw, ImageFont + + txt = Image.new("RGB", (width, height), color="white") + draw = ImageDraw.Draw(txt) + font = ImageFont.truetype(font_path, size=size) + + if text == '': + return np.array(txt) + + # Split text into lines that fit within the image width + lines = [] + words = text.split() + current_line = words[0] + + for word in words[1:]: + line_with_word = f"{current_line} {word}" + if draw.textbbox((0, 0), line_with_word, font=font)[2] <= width: + current_line = line_with_word + else: + lines.append(current_line) + current_line = word + + lines.append(current_line) + + # Draw the text line by line + y = 0 + line_height = draw.textbbox((0, 0), "A", font=font)[3] + + for line in lines: + if y + line_height > height: + break # stop drawing if the next line will be outside the image + draw.text((0, y), line, fill="black", font=font) + y += line_height + + return np.array(txt) + + +def blue_mark(x): + x = x.copy() + c = x[:, :, 2] + b = 
cv2.blur(c, (9, 9)) + x[:, :, 2] = ((c - b) * 16.0 + b).clip(-1, 1) + return x + + +def green_mark(x): + x = x.copy() + x[:, :, 2] = -1 + x[:, :, 0] = -1 + return x + + +def frame_mark(x): + x = x.copy() + x[:64] = -1 + x[-64:] = -1 + x[:, :8] = 1 + x[:, -8:] = 1 + return x + + +@torch.inference_mode() +def pytorch2numpy(imgs): + results = [] + for x in imgs: + y = x.movedim(0, -1) + y = y * 127.5 + 127.5 + y = y.detach().float().cpu().numpy().clip(0, 255).astype(np.uint8) + results.append(y) + return results + + +@torch.inference_mode() +def numpy2pytorch(imgs): + h = torch.from_numpy(np.stack(imgs, axis=0)).float() / 127.5 - 1.0 + h = h.movedim(-1, 1) + return h + + +@torch.no_grad() +def duplicate_prefix_to_suffix(x, count, zero_out=False): + if zero_out: + return torch.cat([x, torch.zeros_like(x[:count])], dim=0) + else: + return torch.cat([x, x[:count]], dim=0) + + +def weighted_mse(a, b, weight): + return torch.mean(weight.float() * (a.float() - b.float()) ** 2) + + +def clamped_linear_interpolation(x, x_min, y_min, x_max, y_max, sigma=1.0): + x = (x - x_min) / (x_max - x_min) + x = max(0.0, min(x, 1.0)) + x = x ** sigma + return y_min + x * (y_max - y_min) + + +def expand_to_dims(x, target_dims): + return x.view(*x.shape, *([1] * max(0, target_dims - x.dim()))) + + +def repeat_to_batch_size(tensor: torch.Tensor, batch_size: int): + if tensor is None: + return None + + first_dim = tensor.shape[0] + + if first_dim == batch_size: + return tensor + + if batch_size % first_dim != 0: + raise ValueError(f"Cannot evenly repeat first dim {first_dim} to match batch_size {batch_size}.") + + repeat_times = batch_size // first_dim + + return tensor.repeat(repeat_times, *[1] * (tensor.dim() - 1)) + + +def dim5(x): + return expand_to_dims(x, 5) + + +def dim4(x): + return expand_to_dims(x, 4) + + +def dim3(x): + return expand_to_dims(x, 3) + + +def crop_or_pad_yield_mask(x, length): + B, F, C = x.shape + device = x.device + dtype = x.dtype + + if F < length: + y = torch.zeros((B, length, C), dtype=dtype, device=device) + mask = torch.zeros((B, length), dtype=torch.bool, device=device) + y[:, :F, :] = x + mask[:, :F] = True + return y, mask + + return x[:, :length, :], torch.ones((B, length), dtype=torch.bool, device=device) + + +def extend_dim(x, dim, minimal_length, zero_pad=False): + original_length = int(x.shape[dim]) + + if original_length >= minimal_length: + return x + + if zero_pad: + padding_shape = list(x.shape) + padding_shape[dim] = minimal_length - original_length + padding = torch.zeros(padding_shape, dtype=x.dtype, device=x.device) + else: + idx = (slice(None),) * dim + (slice(-1, None),) + (slice(None),) * (len(x.shape) - dim - 1) + last_element = x[idx] + padding = last_element.repeat_interleave(minimal_length - original_length, dim=dim) + + return torch.cat([x, padding], dim=dim) + + +def lazy_positional_encoding(t, repeats=None): + if not isinstance(t, list): + t = [t] + + from diffusers.models.embeddings import get_timestep_embedding + + te = torch.tensor(t) + te = get_timestep_embedding(timesteps=te, embedding_dim=256, flip_sin_to_cos=True, downscale_freq_shift=0.0, scale=1.0) + + if repeats is None: + return te + + te = te[:, None, :].expand(-1, repeats, -1) + + return te + + +def state_dict_offset_merge(A, B, C=None): + result = {} + keys = A.keys() + + for key in keys: + A_value = A[key] + B_value = B[key].to(A_value) + + if C is None: + result[key] = A_value + B_value + else: + C_value = C[key].to(A_value) + result[key] = A_value + B_value - C_value + + return result + + 
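+# Illustrative note (not part of the original code): state_dict_offset_merge above adds a
+# per-key delta onto A, i.e. result[k] = A[k] + B[k] - C[k] when a base C is given
+# (or A[k] + B[k] without one), while state_dict_weighted_merge below blends checkpoints
+# with weights normalized to sum to 1. A hypothetical call such as
+#     state_dict_weighted_merge([sd_a, sd_b], [3.0, 1.0])
+# is the key-by-key equivalent of 0.75 * sd_a[k] + 0.25 * sd_b[k].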
+def state_dict_weighted_merge(state_dicts, weights): + if len(state_dicts) != len(weights): + raise ValueError("Number of state dictionaries must match number of weights") + + if not state_dicts: + return {} + + total_weight = sum(weights) + + if total_weight == 0: + raise ValueError("Sum of weights cannot be zero") + + normalized_weights = [w / total_weight for w in weights] + + keys = state_dicts[0].keys() + result = {} + + for key in keys: + result[key] = state_dicts[0][key] * normalized_weights[0] + + for i in range(1, len(state_dicts)): + state_dict_value = state_dicts[i][key].to(result[key]) + result[key] += state_dict_value * normalized_weights[i] + + return result + + +def group_files_by_folder(all_files): + grouped_files = {} + + for file in all_files: + folder_name = os.path.basename(os.path.dirname(file)) + if folder_name not in grouped_files: + grouped_files[folder_name] = [] + grouped_files[folder_name].append(file) + + list_of_lists = list(grouped_files.values()) + return list_of_lists + + +def generate_timestamp(): + now = datetime.datetime.now() + timestamp = now.strftime('%y%m%d_%H%M%S') + milliseconds = f"{int(now.microsecond / 1000):03d}" + random_number = random.randint(0, 9999) + return f"{timestamp}_{milliseconds}_{random_number}" + + +def write_PIL_image_with_png_info(image, metadata, path): + from PIL.PngImagePlugin import PngInfo + + png_info = PngInfo() + for key, value in metadata.items(): + png_info.add_text(key, value) + + image.save(path, "PNG", pnginfo=png_info) + return image + + +def torch_safe_save(content, path): + torch.save(content, path + '_tmp') + os.replace(path + '_tmp', path) + return path + + +def move_optimizer_to_device(optimizer, device): + for state in optimizer.state.values(): + for k, v in state.items(): + if isinstance(v, torch.Tensor): + state[k] = v.to(device) diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000000000000000000000000000000000000..5b1af4e5355990f28a64e3b83b3f476dfb6b1517 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,25 @@ +services: + studio: + build: + # modify this if you are building the image locally and need a different CUDA version + args: + - CUDA_VERSION=12.4.1 + # modify the tag here if you need a different CUDA version or branch + image: colinurbs/fp-studio:cuda12.4-latest-develop + restart: unless-stopped + ports: + - "7860:7860" + volumes: + - "./loras:/app/loras" + - "./outputs:/app/outputs" + - "./.framepack:/app/.framepack" + - "./modules/toolbox/model_esrgan:/app/modules/toolbox/model_esrgan" + - "./modules/toolbox/model_rife:/app/modules/toolbox/model_rife" + - "$HOME/.cache/huggingface:/app/hf_download" + deploy: + resources: + reservations: + devices: + - driver: nvidia + count: 1 + capabilities: [gpu] diff --git a/install.bat b/install.bat new file mode 100644 index 0000000000000000000000000000000000000000..e4936ba72364f2557f4ae0adf7c40537e8de1cff --- /dev/null +++ b/install.bat @@ -0,0 +1,208 @@ +@echo off +echo FramePack-Studio Setup Script +setlocal enabledelayedexpansion + +REM Check if Python is installed (basic check) +where python >nul 2>&1 +if %errorlevel% neq 0 ( + echo Error: Python is not installed or not in your PATH. Please install Python and try again. + goto end +) + +if exist "%cd%/venv" ( +echo Virtual Environment already exists. +set /p choice= "Do you want to reinstall packages?[Y/N]: " + +if "!choice!" 
== "y" (goto checkgpu) +if "!choice!"=="Y" (goto checkgpu) + +goto end +) + +REM Check the python version +echo Python versions 3.10-3.12 have been confirmed to work. Other versions are currently not supported. You currently have: +python -V +set choice= +set /p choice= "Do you want to continue?[Y/N]: " + + +if "!choice!" == "y" (goto makevenv) +if "!choice!"=="Y" (goto makevenv) + +goto end + +:makevenv +REM This creates a virtual environment in the folder +echo Creating a Virtual Environment... +python -m venv venv +echo Upgrading pip in Virtual Environment to lower chance of error... +"%cd%/venv/Scripts/python.exe" -m pip install --upgrade pip + +:checkgpu +REM ask Windows for GPU +where nvidia-smi >nul 2>&1 +if %errorlevel% neq 0 ( + echo Error: Nvidia GPU doesn't exist or drivers installed incorrectly. Please confirm your drivers are installed. + goto end +) + +echo Checking your GPU... + +for /F "tokens=* skip=1" %%n in ('nvidia-smi --query-gpu=name') do set GPU_NAME=%%n && goto gpuchecked + +:gpuchecked +echo Detected %GPU_NAME% +set "GPU_SERIES=%GPU_NAME:*RTX =%" +set "GPU_SERIES=%GPU_SERIES:~0,2%00" + +REM This gets the shortened Python version for later use. e.g. 3.10.13 becomes 310. +for /f "delims=" %%A in ('python -V') do set "pyv=%%A" +for /f "tokens=2 delims= " %%A in ("%pyv%") do ( + set pyv=%%A +) +set pyv=%pyv:.=% +set pyv=%pyv:~0,3% + +echo Installing torch... + +if !GPU_SERIES! geq 5000 ( + goto torch270 +) else ( + goto torch260 +) + +REM RTX 5000 Series +:torch270 +"%cd%/venv/Scripts/pip.exe" install torch==2.7.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128 --force-reinstall +REM Check if pip installation was successful +if %errorlevel% neq 0 ( + echo Warning: Failed to install dependencies. You may need to install them manually. + goto end +) + +REM Ask if user wants Sage Attention +set choice= +echo Do you want to install any of the following? They speed up generation. +echo 1) Sage Attention +echo 2) Flash Attention +echo 3) BOTH! +echo 4) No +set /p choice= "Input Selection: " + +set both="N" + +if "!choice!" == "1" (goto triton270) +if "!choice!"== "2" (goto flash270) +if "!choice!"== "3" (set both="Y" +goto triton270 +) + +goto requirements + +:triton270 +REM Sage Attention and Triton for Torch 2.7.0 +"%cd%/venv/Scripts/pip.exe" install "triton-windows<3.4" --force-reinstall +"%cd%/venv/Scripts/pip.exe" install "https://github.com/woct0rdho/SageAttention/releases/download/v2.1.1-windows/sageattention-2.1.1+cu128torch2.7.0-cp%pyv%-cp%pyv%-win_amd64.whl" --force-reinstall +echo Finishing up installing triton-windows. This requires extraction of libraries into Python Folder... 
+ +REM Check for python version and download the triton-windows required libs accordingly +if %pyv% == 310 ( + powershell -Command "(New-Object System.Net.WebClient).DownloadFile('https://github.com/woct0rdho/triton-windows/releases/download/v3.0.0-windows.post1/python_3.10.11_include_libs.zip', 'triton-lib.zip')" +) + +if %pyv% == 311 ( + powershell -Command "(New-Object System.Net.WebClient).DownloadFile('https://github.com/woct0rdho/triton-windows/releases/download/v3.0.0-windows.post1/python_3.11.9_include_libs.zip', 'triton-lib.zip')" +) + +if %pyv% == 312 ( + powershell -Command "(New-Object System.Net.WebClient).DownloadFile('https://github.com/woct0rdho/triton-windows/releases/download/v3.0.0-windows.post1/python_3.12.7_include_libs.zip', 'triton-lib.zip')" +) + +REM Extract the zip into the Python Folder and Delete zip +powershell Expand-Archive -Path '%cd%\triton-lib.zip' -DestinationPath '%cd%\venv\Scripts\' -force +del triton-lib.zip +if %both% == "Y" (goto flash270) + +goto requirements + +:flash270 +REM Install flash-attn. +"%cd%/venv/Scripts/pip.exe" install "https://huggingface.co/lldacing/flash-attention-windows-wheel/resolve/main/flash_attn-2.7.4.post1%%2Bcu128torch2.7.0cxx11abiFALSE-cp%pyv%-cp%pyv%-win_amd64.whl?download=true" +goto requirements + + +REM RTX 4000 Series and below +:torch260 +"%cd%/venv/Scripts/pip.exe" install torch==2.6.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126 --force-reinstall +REM Check if pip installation was successful +if %errorlevel% neq 0 ( + echo Warning: Failed to install dependencies. You may need to install them manually. + goto end +) + +REM Ask if user wants Sage Attention +set choice= +echo Do you want to install any of the following? They speed up generation. +echo 1) Sage Attention +echo 2) Flash Attention +echo 3) BOTH! +echo 4) No +set /p choice= "Input Selection: " + +set both="N" + +if "!choice!" == "1" (goto triton260) +if "!choice!"== "2" (goto flash260) +if "!choice!"== "3" (set both="Y" +goto triton260) + +goto requirements + +:triton260 +REM Sage Attention and Triton for Torch 2.6.0 +"%cd%/venv/Scripts/pip.exe" install "triton-windows<3.3.0" --force-reinstall +"%cd%/venv/Scripts/pip.exe" install https://github.com/woct0rdho/SageAttention/releases/download/v2.1.1-windows/sageattention-2.1.1+cu126torch2.6.0-cp%pyv%-cp%pyv%-win_amd64.whl --force-reinstall + +echo Finishing up installing triton-windows. This requires extraction of libraries into Python Folder... + +REM Check for python version and download the triton-windows required libs accordingly +if %pyv% == 310 ( + powershell -Command "(New-Object System.Net.WebClient).DownloadFile('https://github.com/woct0rdho/triton-windows/releases/download/v3.0.0-windows.post1/python_3.10.11_include_libs.zip', 'triton-lib.zip')" +) + +if %pyv% == 311 ( + powershell -Command "(New-Object System.Net.WebClient).DownloadFile('https://github.com/woct0rdho/triton-windows/releases/download/v3.0.0-windows.post1/python_3.11.9_include_libs.zip', 'triton-lib.zip')" +) + +if %pyv% == 312 ( + powershell -Command "(New-Object System.Net.WebClient).DownloadFile('https://github.com/woct0rdho/triton-windows/releases/download/v3.0.0-windows.post1/python_3.12.7_include_libs.zip', 'triton-lib.zip')" +) + +REM Extract the zip into the Python Folder and Delete zip +powershell Expand-Archive -Path '%cd%\triton-lib.zip' -DestinationPath '%cd%\venv\Scripts\' -force +del triton-lib.zip + +if %both% == "Y" (goto flash260) + +goto requirements + +:flash260 +REM Install flash-attn. 
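+REM The prebuilt wheel below targets cu126 / torch 2.6.0; %pyv% fills in the CPython tag (e.g. cp310).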
+"%cd%/venv/Scripts/pip.exe" install "https://huggingface.co/lldacing/flash-attention-windows-wheel/resolve/main/flash_attn-2.7.4%%2Bcu126torch2.6.0cxx11abiFALSE-cp%pyv%-cp%pyv%-win_amd64.whl?download=true" + +:requirements +echo Installing remaining required packages through pip... +REM This assumes there's a requirements.txt file in the root +"%cd%/venv/Scripts/pip.exe" install -r requirements.txt + +REM Check if pip installation was successful +if %errorlevel% neq 0 ( + echo Warning: Failed to install dependencies. You may need to install them manually. + goto end +) + +echo Setup complete. + +:end +echo Exiting setup script. +pause diff --git a/modules/__init__.py b/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4285601106e4a78ed920840ce2c9ecc5b2d2adb5 --- /dev/null +++ b/modules/__init__.py @@ -0,0 +1,4 @@ +# modules/__init__.py + +# Workaround for the single lora bug. Must not be an empty string. +DUMMY_LORA_NAME = " " diff --git a/modules/generators/__init__.py b/modules/generators/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..20801fcdd8e98bf7504143b5e7457c1d86167ebb --- /dev/null +++ b/modules/generators/__init__.py @@ -0,0 +1,32 @@ +from .original_generator import OriginalModelGenerator +from .f1_generator import F1ModelGenerator +from .video_generator import VideoModelGenerator +from .video_f1_generator import VideoF1ModelGenerator +from .original_with_endframe_generator import OriginalWithEndframeModelGenerator + +def create_model_generator(model_type, **kwargs): + """ + Create a model generator based on the model type. + + Args: + model_type: The type of model to create ("Original", "Original with Endframe", "F1", "Video", or "Video F1") + **kwargs: Additional arguments to pass to the model generator constructor + + Returns: + A model generator instance + + Raises: + ValueError: If the model type is not supported + """ + if model_type == "Original": + return OriginalModelGenerator(**kwargs) + elif model_type == "Original with Endframe": + return OriginalWithEndframeModelGenerator(**kwargs) + elif model_type == "F1": + return F1ModelGenerator(**kwargs) + elif model_type == "Video": + return VideoModelGenerator(**kwargs) + elif model_type == "Video F1": + return VideoF1ModelGenerator(**kwargs) + else: + raise ValueError(f"Unsupported model type: {model_type}") diff --git a/modules/generators/base_generator.py b/modules/generators/base_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..175dd08f17a359cfe47f8cd0c125424c9a56167e --- /dev/null +++ b/modules/generators/base_generator.py @@ -0,0 +1,281 @@ +import torch +import os # required for os.path +from abc import ABC, abstractmethod +from diffusers_helper import lora_utils +from typing import List, Optional +from pathlib import Path + +class BaseModelGenerator(ABC): + """ + Base class for model generators. + This defines the common interface that all model generators must implement. + """ + + def __init__(self, + text_encoder, + text_encoder_2, + tokenizer, + tokenizer_2, + vae, + image_encoder, + feature_extractor, + high_vram=False, + prompt_embedding_cache=None, + settings=None, + offline=False): # NEW: offline flag + """ + Initialize the base model generator. 
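+        The transformer itself is not loaded here; subclasses create it later via load_model().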
+ + Args: + text_encoder: The text encoder model + text_encoder_2: The second text encoder model + tokenizer: The tokenizer for the first text encoder + tokenizer_2: The tokenizer for the second text encoder + vae: The VAE model + image_encoder: The image encoder model + feature_extractor: The feature extractor + high_vram: Whether high VRAM mode is enabled + prompt_embedding_cache: Cache for prompt embeddings + settings: Application settings + offline: Whether to run in offline mode for model loading + """ + self.text_encoder = text_encoder + self.text_encoder_2 = text_encoder_2 + self.tokenizer = tokenizer + self.tokenizer_2 = tokenizer_2 + self.vae = vae + self.image_encoder = image_encoder + self.feature_extractor = feature_extractor + self.high_vram = high_vram + self.prompt_embedding_cache = prompt_embedding_cache or {} + self.settings = settings + self.offline = offline + self.transformer = None + self.gpu = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.cpu = torch.device("cpu") + + + @abstractmethod + def load_model(self): + """ + Load the transformer model. + This method should be implemented by each specific model generator. + """ + pass + + @abstractmethod + def get_model_name(self): + """ + Get the name of the model. + This method should be implemented by each specific model generator. + """ + pass + + @staticmethod + def _get_snapshot_hash_from_refs(model_repo_id_for_cache: str) -> str | None: + """ + Reads the commit hash from the refs/main file for a given model in the HF cache. + Args: + model_repo_id_for_cache (str): The model ID formatted for cache directory names + (e.g., "models--lllyasviel--FramePackI2V_HY"). + Returns: + str: The commit hash if found, otherwise None. + """ + hf_home_dir = os.environ.get('HF_HOME') + if not hf_home_dir: + print("Warning: HF_HOME environment variable not set. Cannot determine snapshot hash.") + return None + + refs_main_path = os.path.join(hf_home_dir, 'hub', model_repo_id_for_cache, 'refs', 'main') + if os.path.exists(refs_main_path): + try: + with open(refs_main_path, 'r') as f: + print(f"Offline mode: Reading snapshot hash from: {refs_main_path}") + return f.read().strip() + except Exception as e: + print(f"Warning: Could not read snapshot hash from {refs_main_path}: {e}") + return None + else: + print(f"Warning: refs/main file not found at {refs_main_path}. Cannot determine snapshot hash.") + return None + + def _get_offline_load_path(self) -> str: + """ + Returns the local snapshot path for offline loading if available. + Falls back to the default self.model_path if local snapshot can't be found. + Relies on self.model_repo_id_for_cache and self.model_path being set by subclasses. + """ + # Ensure necessary attributes are set by the subclass + if not hasattr(self, 'model_repo_id_for_cache') or not self.model_repo_id_for_cache: + print(f"Warning: model_repo_id_for_cache not set in {self.__class__.__name__}. Cannot determine offline path.") + # Fallback to model_path if it exists, otherwise None + return getattr(self, 'model_path', None) + + if not hasattr(self, 'model_path') or not self.model_path: + print(f"Warning: model_path not set in {self.__class__.__name__}. 
Cannot determine fallback for offline path.") + return None + + snapshot_hash = self._get_snapshot_hash_from_refs(self.model_repo_id_for_cache) + hf_home = os.environ.get('HF_HOME') + + if snapshot_hash and hf_home: + specific_snapshot_path = os.path.join( + hf_home, 'hub', self.model_repo_id_for_cache, 'snapshots', snapshot_hash + ) + if os.path.isdir(specific_snapshot_path): + return specific_snapshot_path + + # If snapshot logic fails or path is not a dir, fallback to the default model path + return self.model_path + + def unload_loras(self): + """ + Unload all LoRAs from the transformer model. + """ + if self.transformer is not None: + print(f"Unloading all LoRAs from {self.get_model_name()} model") + self.transformer = lora_utils.unload_all_loras(self.transformer) + self.verify_lora_state("After unloading LoRAs") + import gc + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + def verify_lora_state(self, label=""): + """ + Debug function to verify the state of LoRAs in the transformer model. + """ + if self.transformer is None: + print(f"[{label}] Transformer is None, cannot verify LoRA state") + return + + has_loras = False + if hasattr(self.transformer, 'peft_config'): + adapter_names = list(self.transformer.peft_config.keys()) if self.transformer.peft_config else [] + if adapter_names: + has_loras = True + print(f"[{label}] Transformer has LoRAs: {', '.join(adapter_names)}") + else: + print(f"[{label}] Transformer has no LoRAs in peft_config") + else: + print(f"[{label}] Transformer has no peft_config attribute") + + # Check for any LoRA modules + for name, module in self.transformer.named_modules(): + if hasattr(module, 'lora_A') and module.lora_A: + has_loras = True + # print(f"[{label}] Found lora_A in module {name}") + if hasattr(module, 'lora_B') and module.lora_B: + has_loras = True + # print(f"[{label}] Found lora_B in module {name}") + + if not has_loras: + print(f"[{label}] No LoRA components found in transformer") + + def move_lora_adapters_to_device(self, target_device): + """ + Move all LoRA adapters in the transformer model to the specified device. + This handles the PEFT implementation of LoRA. 
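+        Both the ModuleDict-style adapters used by PEFT (lora_A/lora_B keyed by adapter name,
+        plus any per-adapter scaling tensors) and plain lora_A/lora_B attributes are moved.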
+ """ + if self.transformer is None: + return + + print(f"Moving all LoRA adapters to {target_device}") + + # First, find all modules with LoRA adapters + lora_modules = [] + for name, module in self.transformer.named_modules(): + if hasattr(module, 'active_adapter') and hasattr(module, 'lora_A') and hasattr(module, 'lora_B'): + lora_modules.append((name, module)) + + # Now move all LoRA components to the target device + for name, module in lora_modules: + # Get the active adapter name + active_adapter = module.active_adapter + + # Move the LoRA layers to the target device + if active_adapter is not None: + if isinstance(module.lora_A, torch.nn.ModuleDict): + # Handle ModuleDict case (PEFT implementation) + for adapter_name in list(module.lora_A.keys()): + # Move lora_A + if adapter_name in module.lora_A: + module.lora_A[adapter_name] = module.lora_A[adapter_name].to(target_device) + + # Move lora_B + if adapter_name in module.lora_B: + module.lora_B[adapter_name] = module.lora_B[adapter_name].to(target_device) + + # Move scaling + if hasattr(module, 'scaling') and isinstance(module.scaling, dict) and adapter_name in module.scaling: + if isinstance(module.scaling[adapter_name], torch.Tensor): + module.scaling[adapter_name] = module.scaling[adapter_name].to(target_device) + else: + # Handle direct attribute case + if hasattr(module, 'lora_A') and module.lora_A is not None: + module.lora_A = module.lora_A.to(target_device) + if hasattr(module, 'lora_B') and module.lora_B is not None: + module.lora_B = module.lora_B.to(target_device) + if hasattr(module, 'scaling') and module.scaling is not None: + if isinstance(module.scaling, torch.Tensor): + module.scaling = module.scaling.to(target_device) + + print(f"Moved all LoRA adapters to {target_device}") + + def load_loras(self, selected_loras: List[str], lora_folder: str, lora_loaded_names: List[str], lora_values: Optional[List[float]] = None): + """ + Load LoRAs into the transformer model and applies their weights. + + Args: + selected_loras: List of LoRA base names to load (e.g., ["lora_A", "lora_B"]). + lora_folder: Path to the folder containing the LoRA files. + lora_loaded_names: The master list of ALL available LoRA names, used for correct weight indexing. + lora_values: A list of strength values corresponding to lora_loaded_names. + """ + self.unload_loras() + + if not selected_loras: + print("No LoRAs selected, skipping loading.") + return + + lora_dir = Path(lora_folder) + + adapter_names = [] + strengths = [] + + for idx, lora_base_name in enumerate(selected_loras): + lora_file = None + for ext in (".safetensors", ".pt"): + candidate_path_relative = f"{lora_base_name}{ext}" + candidate_path_full = lora_dir / candidate_path_relative + if candidate_path_full.is_file(): + lora_file = candidate_path_relative + break + + if not lora_file: + print(f"Warning: LoRA file for base name '{lora_base_name}' not found; skipping.") + continue + + print(f"Loading LoRA from '{lora_file}'...") + + self.transformer, adapter_name = lora_utils.load_lora(self.transformer, lora_dir, lora_file) + adapter_names.append(adapter_name) + + weight = 1.0 + if lora_values: + try: + master_list_idx = lora_loaded_names.index(lora_base_name) + if master_list_idx < len(lora_values): + weight = float(lora_values[master_list_idx]) + else: + print(f"Warning: Index mismatch for '{lora_base_name}'. Defaulting to 1.0.") + except ValueError: + print(f"Warning: LoRA '{lora_base_name}' not found in master list. 
Defaulting to 1.0.") + + strengths.append(weight) + + if adapter_names: + print(f"Activating adapters: {adapter_names} with strengths: {strengths}") + lora_utils.set_adapters(self.transformer, adapter_names, strengths) + + self.verify_lora_state("After completing load_loras") \ No newline at end of file diff --git a/modules/generators/f1_generator.py b/modules/generators/f1_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..e9b4224aa78e104930328c43e4ecd1318646e649 --- /dev/null +++ b/modules/generators/f1_generator.py @@ -0,0 +1,235 @@ +import torch +import os # for offline loading path +from diffusers_helper.models.hunyuan_video_packed import HunyuanVideoTransformer3DModelPacked +from diffusers_helper.memory import DynamicSwapInstaller +from .base_generator import BaseModelGenerator + +class F1ModelGenerator(BaseModelGenerator): + """ + Model generator for the F1 HunyuanVideo model. + """ + + def __init__(self, **kwargs): + """ + Initialize the F1 model generator. + """ + super().__init__(**kwargs) + self.model_name = "F1" + self.model_path = 'lllyasviel/FramePack_F1_I2V_HY_20250503' + self.model_repo_id_for_cache = "models--lllyasviel--FramePack_F1_I2V_HY_20250503" + + def get_model_name(self): + """ + Get the name of the model. + """ + return self.model_name + + def load_model(self): + """ + Load the F1 transformer model. + If offline mode is True, attempts to load from a local snapshot. + """ + print(f"Loading {self.model_name} Transformer...") + + path_to_load = self.model_path # Initialize with the default path + + if self.offline: + path_to_load = self._get_offline_load_path() # Calls the method in BaseModelGenerator + + # Create the transformer model + self.transformer = HunyuanVideoTransformer3DModelPacked.from_pretrained( + path_to_load, + torch_dtype=torch.bfloat16 + ).cpu() + + # Configure the model + self.transformer.eval() + self.transformer.to(dtype=torch.bfloat16) + self.transformer.requires_grad_(False) + + # Set up dynamic swap if not in high VRAM mode + if not self.high_vram: + DynamicSwapInstaller.install_model(self.transformer, device=self.gpu) + else: + # In high VRAM mode, move the entire model to GPU + self.transformer.to(device=self.gpu) + + print(f"{self.model_name} Transformer Loaded from {path_to_load}.") + return self.transformer + + def prepare_history_latents(self, height, width): + """ + Prepare the history latents tensor for the F1 model. + + Args: + height: The height of the image + width: The width of the image + + Returns: + The initialized history latents tensor + """ + return torch.zeros( + size=(1, 16, 16 + 2 + 1, height // 8, width // 8), + dtype=torch.float32 + ).cpu() + + def initialize_with_start_latent(self, history_latents, start_latent, is_real_image_latent): + """ + Initialize the history latents with the start latent for the F1 model. + + Args: + history_latents: The history latents + start_latent: The start latent + is_real_image_latent: Whether the start latent came from a real input image or is a synthetic noise + + Returns: + The initialized history latents + """ + # Add the start frame to history_latents + if is_real_image_latent: + return torch.cat([history_latents, start_latent.to(history_latents)], dim=2) + # After prepare_history_latents, history_latents (initialized with zeros) + # already has the required 19 entries for initial clean latents + return history_latents + + def get_latent_paddings(self, total_latent_sections): + """ + Get the latent paddings for the F1 model. 
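+        For example, total_latent_sections=4 yields [1, 1, 1, 0].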
+ + Args: + total_latent_sections: The total number of latent sections + + Returns: + A list of latent paddings + """ + # F1 model uses a fixed approach with just 0 for last section and 1 for others + return [1] * (total_latent_sections - 1) + [0] + + def prepare_indices(self, latent_padding_size, latent_window_size): + """ + Prepare the indices for the F1 model. + + Args: + latent_padding_size: The size of the latent padding + latent_window_size: The size of the latent window + + Returns: + A tuple of (clean_latent_indices, latent_indices, clean_latent_2x_indices, clean_latent_4x_indices) + """ + # F1 model uses a different indices approach + # latent_window_sizeが4.5の場合は特別に5を使用 + effective_window_size = 5 if latent_window_size == 4.5 else int(latent_window_size) + indices = torch.arange(0, sum([1, 16, 2, 1, latent_window_size])).unsqueeze(0) + clean_latent_indices_start, clean_latent_4x_indices, clean_latent_2x_indices, clean_latent_1x_indices, latent_indices = indices.split([1, 16, 2, 1, latent_window_size], dim=1) + clean_latent_indices = torch.cat([clean_latent_indices_start, clean_latent_1x_indices], dim=1) + + return clean_latent_indices, latent_indices, clean_latent_2x_indices, clean_latent_4x_indices + + def prepare_clean_latents(self, start_latent, history_latents): + """ + Prepare the clean latents for the F1 model. + + Args: + start_latent: The start latent + history_latents: The history latents + + Returns: + A tuple of (clean_latents, clean_latents_2x, clean_latents_4x) + """ + # For F1, we take the last frames for clean latents + clean_latents_4x, clean_latents_2x, clean_latents_1x = history_latents[:, :, -sum([16, 2, 1]):, :, :].split([16, 2, 1], dim=2) + # For F1, we prepend the start latent to clean_latents_1x + clean_latents = torch.cat([start_latent.to(history_latents), clean_latents_1x], dim=2) + + return clean_latents, clean_latents_2x, clean_latents_4x + + def update_history_latents(self, history_latents, generated_latents): + """ + Update the history latents with the generated latents for the F1 model. + + Args: + history_latents: The history latents + generated_latents: The generated latents + + Returns: + The updated history latents + """ + # For F1, we append new frames to the end + return torch.cat([history_latents, generated_latents.to(history_latents)], dim=2) + + def get_real_history_latents(self, history_latents, total_generated_latent_frames): + """ + Get the real history latents for the F1 model. + + Args: + history_latents: The history latents + total_generated_latent_frames: The total number of generated latent frames + + Returns: + The real history latents + """ + # For F1, we take frames from the end + return history_latents[:, :, -total_generated_latent_frames:, :, :] + + def update_history_pixels(self, history_pixels, current_pixels, overlapped_frames): + """ + Update the history pixels with the current pixels for the F1 model. + + Args: + history_pixels: The history pixels + current_pixels: The current pixels + overlapped_frames: The number of overlapped frames + + Returns: + The updated history pixels + """ + from diffusers_helper.utils import soft_append_bcthw + # For F1 model, history_pixels is first, current_pixels is second + return soft_append_bcthw(history_pixels, current_pixels, overlapped_frames) + + def get_section_latent_frames(self, latent_window_size, is_last_section): + """ + Get the number of section latent frames for the F1 model. 
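+        For F1 this is always latent_window_size * 2, regardless of is_last_section.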
+ + Args: + latent_window_size: The size of the latent window + is_last_section: Whether this is the last section + + Returns: + The number of section latent frames + """ + return latent_window_size * 2 + + def get_current_pixels(self, real_history_latents, section_latent_frames, vae): + """ + Get the current pixels for the F1 model. + + Args: + real_history_latents: The real history latents + section_latent_frames: The number of section latent frames + vae: The VAE model + + Returns: + The current pixels + """ + from diffusers_helper.hunyuan import vae_decode + # For F1, we take frames from the end + return vae_decode(real_history_latents[:, :, -section_latent_frames:], vae).cpu() + + def format_position_description(self, total_generated_latent_frames, current_pos, original_pos, current_prompt): + """ + Format the position description for the F1 model. + + Args: + total_generated_latent_frames: The total number of generated latent frames + current_pos: The current position in seconds + original_pos: The original position in seconds + current_prompt: The current prompt + + Returns: + The formatted position description + """ + return (f'Total generated frames: {int(max(0, total_generated_latent_frames * 4 - 3))}, ' + f'Video length: {max(0, (total_generated_latent_frames * 4 - 3) / 30):.2f} seconds (FPS-30). ' + f'Current position: {current_pos:.2f}s. ' + f'using prompt: {current_prompt[:256]}...') diff --git a/modules/generators/original_generator.py b/modules/generators/original_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..9a98d42385d4d44465e0862b70ce54e68e883e4c --- /dev/null +++ b/modules/generators/original_generator.py @@ -0,0 +1,213 @@ +import torch +import os # for offline loading path +from diffusers_helper.models.hunyuan_video_packed import HunyuanVideoTransformer3DModelPacked +from diffusers_helper.memory import DynamicSwapInstaller +from .base_generator import BaseModelGenerator + +class OriginalModelGenerator(BaseModelGenerator): + """ + Model generator for the Original HunyuanVideo model. + """ + + def __init__(self, **kwargs): + """ + Initialize the Original model generator. + """ + super().__init__(**kwargs) + self.model_name = "Original" + self.model_path = 'lllyasviel/FramePackI2V_HY' + self.model_repo_id_for_cache = "models--lllyasviel--FramePackI2V_HY" + + def get_model_name(self): + """ + Get the name of the model. + """ + return self.model_name + + def load_model(self): + """ + Load the Original transformer model. + If offline mode is True, attempts to load from a local snapshot. 
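+        The local path is resolved via _get_offline_load_path(); if no cached snapshot is
+        found, loading falls back to the hub id in self.model_path.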
+ """ + print(f"Loading {self.model_name} Transformer...") + + path_to_load = self.model_path # Initialize with the default path + + if self.offline: + path_to_load = self._get_offline_load_path() # Calls the method in BaseModelGenerator + + # Create the transformer model + self.transformer = HunyuanVideoTransformer3DModelPacked.from_pretrained( + path_to_load, + torch_dtype=torch.bfloat16 + ).cpu() + + # Configure the model + self.transformer.eval() + self.transformer.to(dtype=torch.bfloat16) + self.transformer.requires_grad_(False) + + # Set up dynamic swap if not in high VRAM mode + if not self.high_vram: + DynamicSwapInstaller.install_model(self.transformer, device=self.gpu) + else: + # In high VRAM mode, move the entire model to GPU + self.transformer.to(device=self.gpu) + + print(f"{self.model_name} Transformer Loaded from {path_to_load}.") + return self.transformer + + def prepare_history_latents(self, height, width): + """ + Prepare the history latents tensor for the Original model. + + Args: + height: The height of the image + width: The width of the image + + Returns: + The initialized history latents tensor + """ + return torch.zeros( + size=(1, 16, 1 + 2 + 16, height // 8, width // 8), + dtype=torch.float32 + ).cpu() + + def get_latent_paddings(self, total_latent_sections): + """ + Get the latent paddings for the Original model. + + Args: + total_latent_sections: The total number of latent sections + + Returns: + A list of latent paddings + """ + # Original model uses reversed latent paddings + if total_latent_sections > 4: + return [3] + [2] * (total_latent_sections - 3) + [1, 0] + else: + return list(reversed(range(total_latent_sections))) + + def prepare_indices(self, latent_padding_size, latent_window_size): + """ + Prepare the indices for the Original model. + + Args: + latent_padding_size: The size of the latent padding + latent_window_size: The size of the latent window + + Returns: + A tuple of (clean_latent_indices, latent_indices, clean_latent_2x_indices, clean_latent_4x_indices) + """ + indices = torch.arange(0, sum([1, latent_padding_size, latent_window_size, 1, 2, 16])).unsqueeze(0) + clean_latent_indices_pre, blank_indices, latent_indices, clean_latent_indices_post, clean_latent_2x_indices, clean_latent_4x_indices = indices.split([1, latent_padding_size, latent_window_size, 1, 2, 16], dim=1) + clean_latent_indices = torch.cat([clean_latent_indices_pre, clean_latent_indices_post], dim=1) + + return clean_latent_indices, latent_indices, clean_latent_2x_indices, clean_latent_4x_indices + + def prepare_clean_latents(self, start_latent, history_latents): + """ + Prepare the clean latents for the Original model. + + Args: + start_latent: The start latent + history_latents: The history latents + + Returns: + A tuple of (clean_latents, clean_latents_2x, clean_latents_4x) + """ + clean_latents_pre = start_latent.to(history_latents) + clean_latents_post, clean_latents_2x, clean_latents_4x = history_latents[:, :, :1 + 2 + 16, :, :].split([1, 2, 16], dim=2) + clean_latents = torch.cat([clean_latents_pre, clean_latents_post], dim=2) + + return clean_latents, clean_latents_2x, clean_latents_4x + + def update_history_latents(self, history_latents, generated_latents): + """ + Update the history latents with the generated latents for the Original model. 
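+        Unlike F1, which appends new frames, the Original model prepends the generated
+        latents along the time axis (dim=2).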
+ + Args: + history_latents: The history latents + generated_latents: The generated latents + + Returns: + The updated history latents + """ + # For Original model, we prepend the generated latents + return torch.cat([generated_latents.to(history_latents), history_latents], dim=2) + + def get_real_history_latents(self, history_latents, total_generated_latent_frames): + """ + Get the real history latents for the Original model. + + Args: + history_latents: The history latents + total_generated_latent_frames: The total number of generated latent frames + + Returns: + The real history latents + """ + return history_latents[:, :, :total_generated_latent_frames, :, :] + + def update_history_pixels(self, history_pixels, current_pixels, overlapped_frames): + """ + Update the history pixels with the current pixels for the Original model. + + Args: + history_pixels: The history pixels + current_pixels: The current pixels + overlapped_frames: The number of overlapped frames + + Returns: + The updated history pixels + """ + from diffusers_helper.utils import soft_append_bcthw + # For Original model, current_pixels is first, history_pixels is second + return soft_append_bcthw(current_pixels, history_pixels, overlapped_frames) + + def get_section_latent_frames(self, latent_window_size, is_last_section): + """ + Get the number of section latent frames for the Original model. + + Args: + latent_window_size: The size of the latent window + is_last_section: Whether this is the last section + + Returns: + The number of section latent frames + """ + return latent_window_size * 2 + + def get_current_pixels(self, real_history_latents, section_latent_frames, vae): + """ + Get the current pixels for the Original model. + + Args: + real_history_latents: The real history latents + section_latent_frames: The number of section latent frames + vae: The VAE model + + Returns: + The current pixels + """ + from diffusers_helper.hunyuan import vae_decode + return vae_decode(real_history_latents[:, :, :section_latent_frames], vae).cpu() + + def format_position_description(self, total_generated_latent_frames, current_pos, original_pos, current_prompt): + """ + Format the position description for the Original model. + + Args: + total_generated_latent_frames: The total number of generated latent frames + current_pos: The current position in seconds + original_pos: The original position in seconds + current_prompt: The current prompt + + Returns: + The formatted position description + """ + return (f'Total generated frames: {int(max(0, total_generated_latent_frames * 4 - 3))}, ' + f'Video length: {max(0, (total_generated_latent_frames * 4 - 3) / 30):.2f} seconds (FPS-30). ' + f'Current position: {current_pos:.2f}s (original: {original_pos:.2f}s). ' + f'using prompt: {current_prompt[:256]}...') diff --git a/modules/generators/original_with_endframe_generator.py b/modules/generators/original_with_endframe_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..679ab5e23d7c4083f27e78f31a3e0dc4612d657a --- /dev/null +++ b/modules/generators/original_with_endframe_generator.py @@ -0,0 +1,15 @@ +from .original_generator import OriginalModelGenerator + +class OriginalWithEndframeModelGenerator(OriginalModelGenerator): + """ + Model generator for the Original HunyuanVideo model with end frame support. + This extends the Original model with the ability to guide generation toward a specified end frame. + """ + + def __init__(self, **kwargs): + """ + Initialize the Original with Endframe model generator. 
+ """ + super().__init__(**kwargs) + self.model_name = "Original with Endframe" + # Inherits everything else from OriginalModelGenerator diff --git a/modules/generators/video_base_generator.py b/modules/generators/video_base_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..2aef568249684edc25d9fa902761f1c66a1ef7ab --- /dev/null +++ b/modules/generators/video_base_generator.py @@ -0,0 +1,613 @@ +import torch +import os +import numpy as np +import math +import decord +from tqdm import tqdm +import pathlib +from PIL import Image + +from diffusers_helper.models.hunyuan_video_packed import HunyuanVideoTransformer3DModelPacked +from diffusers_helper.memory import DynamicSwapInstaller +from diffusers_helper.utils import resize_and_center_crop +from diffusers_helper.bucket_tools import find_nearest_bucket +from diffusers_helper.hunyuan import vae_encode, vae_decode +from .base_generator import BaseModelGenerator + +class VideoBaseModelGenerator(BaseModelGenerator): + """ + Model generator for the Video extension of the Original HunyuanVideo model. + This generator accepts video input instead of a single image. + """ + + def __init__(self, **kwargs): + """ + Initialize the Video model generator. + """ + super().__init__(**kwargs) + self.model_name = None # Subclass Model Specific + self.model_path = None # Subclass Model Specific + self.model_repo_id_for_cache = None # Subclass Model Specific + self.full_video_latents = None # For context, set by worker() when available + self.resolution = 640 # Default resolution + self.no_resize = False # Default to resize + self.vae_batch_size = 16 # Default VAE batch size + + # Import decord and tqdm here to avoid import errors if not installed + try: + import decord + from tqdm import tqdm + self.decord = decord + self.tqdm = tqdm + except ImportError: + print("Warning: decord or tqdm not installed. Video processing will not work.") + self.decord = None + self.tqdm = None + + def get_model_name(self): + """ + Get the name of the model. + """ + return self.model_name + + def load_model(self): + """ + Load the Video transformer model. + If offline mode is True, attempts to load from a local snapshot. + """ + print(f"Loading {self.model_name} Transformer...") + + path_to_load = self.model_path # Initialize with the default path + + if self.offline: + path_to_load = self._get_offline_load_path() # Calls the method in BaseModelGenerator + + # Create the transformer model + self.transformer = HunyuanVideoTransformer3DModelPacked.from_pretrained( + path_to_load, + torch_dtype=torch.bfloat16 + ).cpu() + + # Configure the model + self.transformer.eval() + self.transformer.to(dtype=torch.bfloat16) + self.transformer.requires_grad_(False) + + # Set up dynamic swap if not in high VRAM mode + if not self.high_vram: + DynamicSwapInstaller.install_model(self.transformer, device=self.gpu) + else: + # In high VRAM mode, move the entire model to GPU + self.transformer.to(device=self.gpu) + + print(f"{self.model_name} Transformer Loaded from {path_to_load}.") + return self.transformer + + def min_real_frames_to_encode(self, real_frames_available_count): + """ + Minimum number of real frames to encode + is the maximum number of real frames used for generation context. + + The number of latents could be calculated as below for video F1, but keeping it simple for now + by hardcoding the Video F1 value at max_latents_used_for_context = 27. 
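+        Example: with the cap of 27 context latents, at most 27 * 4 = 108 real frames are
+        encoded; a 90-frame input is first capped at 90, then aligned down to 88 (a multiple of 4).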
+ + # Calculate the number of latent frames to encode from the end of the input video + num_frames_to_encode_from_end = 1 # Default minimum + if model_type == "Video": + # Max needed is 1 (clean_latent_pre) + 2 (max 2x) + 16 (max 4x) = 19 + num_frames_to_encode_from_end = 19 + elif model_type == "Video F1": + ui_num_cleaned_frames = job_params.get('num_cleaned_frames', 5) + # Max effective_clean_frames based on VideoF1ModelGenerator's logic. + # Max num_clean_frames from UI is 10 (modules/interface.py). + # Max effective_clean_frames = 10 - 1 = 9. + # total_context_frames = num_4x_frames + num_2x_frames + effective_clean_frames + # Max needed = 16 (max 4x) + 2 (max 2x) + 9 (max effective_clean_frames) = 27 + num_frames_to_encode_from_end = 27 + + Note: 27 latents ~ 108 real frames = 3.6 seconds at 30 FPS. + Note: 19 latents ~ 76 real frames ~ 2.5 seconds at 30 FPS. + """ + + max_latents_used_for_context = 27 + if self.get_model_name() == "Video": + max_latents_used_for_context = 27 # Weird results on 19 + elif self.get_model_name() == "Video F1": + max_latents_used_for_context = 27 # Enough for even Video F1 with cleaned_frames input of 10 + else: + print("======================================================") + print(f" ***** Warning: Unsupported video extension model type: {self.get_model_name()}.") + print( " ***** Using default max latents {max_latents_used_for_context} for context.") + print( " ***** Please report to the developers if you see this message:") + print( " ***** Discord: https://discord.gg/8Z2c3a4 or GitHub: https://github.com/colinurbs/FramePack-Studio") + print("======================================================") + # Probably better to press on with Video F1 max vs exception? + # raise ValueError(f"Unsupported video extension model type: {self.get_model_name()}") + + latent_size_factor = 4 # real frames to latent frames conversion factor + max_real_frames_used_for_context = max_latents_used_for_context * latent_size_factor + + # Shortest of available frames and max frames used for context + trimmed_real_frames_count = min(real_frames_available_count, max_real_frames_used_for_context) + if trimmed_real_frames_count < real_frames_available_count: + print(f"Truncating video frames from {real_frames_available_count} to {trimmed_real_frames_count}, enough to populate context") + + # Truncate to nearest latent size (multiple of 4) + frames_to_encode_count = (trimmed_real_frames_count // latent_size_factor) * latent_size_factor + if frames_to_encode_count != trimmed_real_frames_count: + print(f"Truncating video frames from {trimmed_real_frames_count} to {frames_to_encode_count}, for latent size compatibility") + + return frames_to_encode_count + + def extract_video_frames(self, is_for_encode, video_path, resolution, no_resize=False, input_files_dir=None): + """ + Extract real frames from a video, resized and center cropped as numpy array (T, H, W, C). + + Args: + is_for_encode: If True, results are capped at maximum frames used for context, and aligned to 4-frame latent requirement. + video_path: Path to the input video file. + resolution: Target resolution for resizing frames. + no_resize: Whether to use the original video resolution. + input_files_dir: Directory for input files that won't be cleaned up. 
+ + Returns: + A tuple containing: + - input_frames_resized_np: All input frames resized and center cropped as numpy array (T, H, W, C) + - fps: Frames per second of the input video + - target_height: Target height of the video + - target_width: Target width of the video + """ + def time_millis(): + import time + return time.perf_counter() * 1000.0 # Convert seconds to milliseconds + + encode_start_time_millis = time_millis() + + # Normalize video path for Windows compatibility + video_path = str(pathlib.Path(video_path).resolve()) + print(f"Processing video: {video_path}") + + # Check if the video is in the temp directory and if we have an input_files_dir + if input_files_dir and "temp" in video_path: + # Check if there's a copy of this video in the input_files_dir + filename = os.path.basename(video_path) + input_file_path = os.path.join(input_files_dir, filename) + + # If the file exists in input_files_dir, use that instead + if os.path.exists(input_file_path): + print(f"Using video from input_files_dir: {input_file_path}") + video_path = input_file_path + else: + # If not, copy it to input_files_dir to prevent it from being deleted + try: + from diffusers_helper.utils import generate_timestamp + safe_filename = f"{generate_timestamp()}_{filename}" + input_file_path = os.path.join(input_files_dir, safe_filename) + import shutil + shutil.copy2(video_path, input_file_path) + print(f"Copied video to input_files_dir: {input_file_path}") + video_path = input_file_path + except Exception as e: + print(f"Error copying video to input_files_dir: {e}") + + try: + # Load video and get FPS + print("Initializing VideoReader...") + vr = decord.VideoReader(video_path) + fps = vr.get_avg_fps() # Get input video FPS + num_real_frames = len(vr) + print(f"Video loaded: {num_real_frames} frames, FPS: {fps}") + + # Read frames + print("Reading video frames...") + + total_frames_in_video_file = len(vr) + if is_for_encode: + print(f"Using minimum real frames to encode: {self.min_real_frames_to_encode(total_frames_in_video_file)}") + num_real_frames = self.min_real_frames_to_encode(total_frames_in_video_file) + # else left as all frames -- len(vr) with no regard for trimming or latent alignment + + # RT_BORG: Retaining this commented code for reference. + # pftq encoder discarded truncated frames from the end of the video. + # frames = vr.get_batch(range(num_real_frames)).asnumpy() # Shape: (num_real_frames, height, width, channels) + + # RT_BORG: Retaining this commented code for reference. + # pftq retained the entire encoded video. 
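The commented-out reference code here and `min_real_frames_to_encode` above boil down to the same bookkeeping: cap the input at the model's context budget, then round down to a multiple of 4 so the real frames map cleanly onto latents. A minimal standalone sketch of that arithmetic; the helper name is hypothetical, and the 27-latent default mirrors `max_latents_used_for_context` above:

```python
def plan_frames_to_encode(frames_in_file: int,
                          max_context_latents: int = 27,
                          latent_stride: int = 4) -> int:
    """Trailing frames worth encoding: capped at the context budget, then
    rounded down to a multiple of the latent stride (4 real frames per latent)."""
    capped = min(frames_in_file, max_context_latents * latent_stride)
    return (capped // latent_stride) * latent_stride

print(plan_frames_to_encode(300))  # 108 -> only the last 108 frames feed the context
print(plan_frames_to_encode(50))   # 48  -> rounded down so it maps onto 12 latents
```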
+ # Truncate to nearest latent size (multiple of 4) + # latent_size_factor = 4 + # num_frames = (num_real_frames // latent_size_factor) * latent_size_factor + # if num_frames != num_real_frames: + # print(f"Truncating video from {num_real_frames} to {num_frames} frames for latent size compatibility") + # num_real_frames = num_frames + + # Discard truncated frames from the beginning of the video, retaining the last num_real_frames + # This ensures a smooth transition from the input video to the generated video + start_frame_index = total_frames_in_video_file - num_real_frames + frame_indices_to_extract = range(start_frame_index, total_frames_in_video_file) + frames = vr.get_batch(frame_indices_to_extract).asnumpy() # Shape: (num_real_frames, height, width, channels) + + print(f"Frames read: {frames.shape}") + + # Get native video resolution + native_height, native_width = frames.shape[1], frames.shape[2] + print(f"Native video resolution: {native_width}x{native_height}") + + # Use native resolution if height/width not specified, otherwise use provided values + target_height = native_height + target_width = native_width + + # Adjust to nearest bucket for model compatibility + if not no_resize: + target_height, target_width = find_nearest_bucket(target_height, target_width, resolution=resolution) + print(f"Adjusted resolution: {target_width}x{target_height}") + else: + print(f"Using native resolution without resizing: {target_width}x{target_height}") + + # Preprocess input frames to match desired resolution + input_frames_resized_np = [] + for i, frame in tqdm(enumerate(frames), desc="Processing Video Frames", total=num_real_frames, mininterval=0.1): + frame_np = resize_and_center_crop(frame, target_width=target_width, target_height=target_height) + input_frames_resized_np.append(frame_np) + input_frames_resized_np = np.stack(input_frames_resized_np) # Shape: (num_real_frames, height, width, channels) + print(f"Frames preprocessed: {input_frames_resized_np.shape}") + + resized_frames_time_millis = time_millis() + if (False): # We really need a logger + print("======================================================") + memory_bytes = input_frames_resized_np.nbytes + memory_kb = memory_bytes / 1024 + memory_mb = memory_kb / 1024 + print(f" ***** input_frames_resized_np: {input_frames_resized_np.shape}") + print(f" ***** Memory usage: {int(memory_mb)} MB") + duration_ms = resized_frames_time_millis - encode_start_time_millis + print(f" ***** Time taken to process frames tensor: {duration_ms / 1000.0:.2f} seconds") + print("======================================================") + + return input_frames_resized_np, fps, target_height, target_width + except Exception as e: + print(f"Error in extract_video_frames: {str(e)}") + raise + + # RT_BORG: video_encode produce and return end_of_input_video_latent and end_of_input_video_image_np + # which are not needed for Video models without end frame processing. + # But these should be inexpensive and it's easier to keep the code uniform. + @torch.no_grad() + def video_encode(self, video_path, resolution, no_resize=False, vae_batch_size=16, device=None, input_files_dir=None): + """ + Encode a video into latent representations using the VAE. + + Args: + video_path: Path to the input video file. + resolution: Target resolution for resizing frames. + no_resize: Whether to use the original video resolution. + vae_batch_size: Number of frames to process per batch. + device: Device for computation (e.g., "cuda"). 
+ input_files_dir: Directory for input files that won't be cleaned up. + + Returns: + A tuple containing: + - start_latent: Latent of the first frame + - input_image_np: First frame as numpy array + - history_latents: Latents of all frames + - fps: Frames per second of the input video + - target_height: Target height of the video + - target_width: Target width of the video + - input_video_pixels: Video frames as tensor + - end_of_input_video_image_np: Last frame as numpy array + - input_frames_resized_np: All input frames resized and center cropped as numpy array (T, H, W, C) + """ + encoding = True # Flag to indicate this is for encoding + input_frames_resized_np, fps, target_height, target_width = self.extract_video_frames(encoding, video_path, resolution, no_resize, input_files_dir) + + try: + if device is None: + device = self.gpu + + # Check CUDA availability and fallback to CPU if needed + if device == "cuda" and not torch.cuda.is_available(): + print("CUDA is not available, falling back to CPU") + device = "cpu" + + # Save first frame for CLIP vision encoding + input_image_np = input_frames_resized_np[0] + end_of_input_video_image_np = input_frames_resized_np[-1] + + # Convert to tensor and normalize to [-1, 1] + print("Converting frames to tensor...") + frames_pt = torch.from_numpy(input_frames_resized_np).float() / 127.5 - 1 + frames_pt = frames_pt.permute(0, 3, 1, 2) # Shape: (num_real_frames, channels, height, width) + frames_pt = frames_pt.unsqueeze(0) # Shape: (1, num_real_frames, channels, height, width) + frames_pt = frames_pt.permute(0, 2, 1, 3, 4) # Shape: (1, channels, num_real_frames, height, width) + print(f"Tensor shape: {frames_pt.shape}") + + # Save pixel frames for use in worker + input_video_pixels = frames_pt.cpu() + + # Move to device + print(f"Moving tensor to device: {device}") + frames_pt = frames_pt.to(device) + print("Tensor moved to device") + + # Move VAE to device + print(f"Moving VAE to device: {device}") + self.vae.to(device) + print("VAE moved to device") + + # Encode frames in batches + print(f"Encoding input video frames in VAE batch size {vae_batch_size}") + latents = [] + self.vae.eval() + with torch.no_grad(): + frame_count = frames_pt.shape[2] + step_count = math.ceil(frame_count / vae_batch_size) + for i in tqdm(range(0, frame_count, vae_batch_size), desc="Encoding video frames", total=step_count, mininterval=0.1): + batch = frames_pt[:, :, i:i + vae_batch_size] # Shape: (1, channels, batch_size, height, width) + try: + # Log GPU memory before encoding + if device == "cuda": + free_mem = torch.cuda.memory_allocated() / 1024**3 + batch_latent = vae_encode(batch, self.vae) + # Synchronize CUDA to catch issues + if device == "cuda": + torch.cuda.synchronize() + latents.append(batch_latent) + except RuntimeError as e: + print(f"Error during VAE encoding: {str(e)}") + if device == "cuda" and "out of memory" in str(e).lower(): + print("CUDA out of memory, try reducing vae_batch_size or using CPU") + raise + + # Concatenate latents + print("Concatenating latents...") + history_latents = torch.cat(latents, dim=2) # Shape: (1, channels, frames, height//8, width//8) + print(f"History latents shape: {history_latents.shape}") + + # Get first frame's latent + start_latent = history_latents[:, :, :1] # Shape: (1, channels, 1, height//8, width//8) + print(f"Start latent shape: {start_latent.shape}") + + if (False): # We really need a logger + print("======================================================") + memory_bytes = history_latents.nbytes + memory_kb = 
memory_bytes / 1024 + memory_mb = memory_kb / 1024 + print(f" ***** history_latents: {history_latents.shape}") + print(f" ***** Memory usage: {int(memory_mb)} MB") + print("======================================================") + + # Move VAE back to CPU to free GPU memory + if device == "cuda": + self.vae.to(self.cpu) + torch.cuda.empty_cache() + print("VAE moved back to CPU, CUDA cache cleared") + + return start_latent, input_image_np, history_latents, fps, target_height, target_width, input_video_pixels, end_of_input_video_image_np, input_frames_resized_np + + except Exception as e: + print(f"Error in video_encode: {str(e)}") + raise + + # RT_BORG: Currently history_latents is initialized within worker() for all Video models as history_latents = video_latents + # So it is a coding error to call prepare_history_latents() here. + # Leaving in place as we will likely use it post-refactoring. + def prepare_history_latents(self, height, width): + """ + Prepare the history latents tensor for the Video model. + + Args: + height: The height of the image + width: The width of the image + + Returns: + The initialized history latents tensor + """ + raise TypeError( + f"Error: '{self.__class__.__name__}.prepare_history_latents' should not be called " + "on the Video models. history_latents should be initialized within worker() for all Video models " + "as history_latents = video_latents." + ) + + def prepare_indices(self, latent_padding_size, latent_window_size): + """ + Prepare the indices for the Video model. + + Args: + latent_padding_size: The size of the latent padding + latent_window_size: The size of the latent window + + Returns: + A tuple of (clean_latent_indices, latent_indices, clean_latent_2x_indices, clean_latent_4x_indices) + """ + raise TypeError( + f"Error: '{self.__class__.__name__}.prepare_indices' should not be called " + "on the Video models. Currently video models each have a combined method: _prepare_clean_latents_and_indices." + ) + + def set_full_video_latents(self, video_latents): + """ + Set the full video latents for context. + + Args: + video_latents: The full video latents + """ + self.full_video_latents = video_latents + + def prepare_clean_latents(self, start_latent, history_latents): + """ + Prepare the clean latents for the Video model. + + Args: + start_latent: The start latent + history_latents: The history latents + + Returns: + A tuple of (clean_latents, clean_latents_2x, clean_latents_4x) + """ + raise TypeError( + f"Error: '{self.__class__.__name__}.prepare_indices' should not be called " + "on the Video models. Currently video models each have a combined method: _prepare_clean_latents_and_indices." + ) + + def get_section_latent_frames(self, latent_window_size, is_last_section): + """ + Get the number of section latent frames for the Video model. + + Args: + latent_window_size: The size of the latent window + is_last_section: Whether this is the last section + + Returns: + The number of section latent frames + """ + return latent_window_size * 2 + + def combine_videos(self, source_video_path, generated_video_path, output_path): + """ + Combine the source video with the generated video side by side. 
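Further down, `combine_videos` assembles an hstack filter graph and resolves the ffmpeg binary through the toolbox's `VideoProcessor`. The core compositing step can be sketched on its own; this assumes an `ffmpeg` executable on PATH and leaves out the drawtext labels, audio mapping, and the `-filter_complex_script` indirection used by the full method:

```python
import subprocess

def side_by_side(left_path: str, right_path: str, out_path: str, height: int = 480) -> None:
    # Scale both clips to a common height (required by hstack), then stack them horizontally.
    filter_complex = (
        f"[0:v]scale=-2:{height}[l];"
        f"[1:v]scale=-2:{height}[r];"
        "[l][r]hstack=inputs=2[v]"
    )
    subprocess.run(
        ["ffmpeg", "-y", "-i", left_path, "-i", right_path,
         "-filter_complex", filter_complex, "-map", "[v]",
         "-c:v", "libx264", "-crf", "18", "-preset", "medium", out_path],
        check=True,
    )
```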
+ + Args: + source_video_path: Path to the source video + generated_video_path: Path to the generated video + output_path: Path to save the combined video + + Returns: + Path to the combined video + """ + try: + import os + import subprocess + + print(f"Combining source video {source_video_path} with generated video {generated_video_path}") + + # Get the ffmpeg executable from the VideoProcessor class + from modules.toolbox.toolbox_processor import VideoProcessor + from modules.toolbox.message_manager import MessageManager + + # Create a message manager for logging + message_manager = MessageManager() + + # Import settings from main module + try: + from __main__ import settings + video_processor = VideoProcessor(message_manager, settings.settings) + except ImportError: + # Fallback to creating a new settings object + from modules.settings import Settings + settings = Settings() + video_processor = VideoProcessor(message_manager, settings.settings) + + # Get the ffmpeg executable + ffmpeg_exe = video_processor.ffmpeg_exe + + if not ffmpeg_exe: + print("FFmpeg executable not found. Cannot combine videos.") + return None + + print(f"Using ffmpeg at: {ffmpeg_exe}") + + # Create a temporary directory for the filter script + import tempfile + temp_dir = tempfile.gettempdir() + filter_script_path = os.path.join(temp_dir, f"filter_script_{os.path.basename(output_path)}.txt") + + # Get video dimensions using ffprobe + def get_video_info(video_path): + cmd = [ + ffmpeg_exe, "-i", video_path, + "-hide_banner", "-loglevel", "error" + ] + + # Run ffmpeg to get video info (it will fail but output info to stderr) + result = subprocess.run(cmd, capture_output=True, text=True) + + # Parse the output to get dimensions + width = height = None + for line in result.stderr.split('\n'): + if 'Video:' in line: + # Look for dimensions like 640x480 + import re + match = re.search(r'(\d+)x(\d+)', line) + if match: + width = int(match.group(1)) + height = int(match.group(2)) + break + + return width, height + + # Get dimensions of both videos + source_width, source_height = get_video_info(source_video_path) + generated_width, generated_height = get_video_info(generated_video_path) + + if not source_width or not generated_width: + print("Error: Could not determine video dimensions") + return None + + print(f"Source video: {source_width}x{source_height}") + print(f"Generated video: {generated_width}x{generated_height}") + + # Calculate target dimensions (maintain aspect ratio) + target_height = max(source_height, generated_height) + source_target_width = int(source_width * (target_height / source_height)) + generated_target_width = int(generated_width * (target_height / generated_height)) + + # Create a complex filter for side-by-side display with labels + filter_complex = ( + f"[0:v]scale={source_target_width}:{target_height}[left];" + f"[1:v]scale={generated_target_width}:{target_height}[right];" + f"[left]drawtext=text='Source':x=({source_target_width}/2-50):y=20:fontsize=24:fontcolor=white:box=1:boxcolor=black@0.5[left_text];" + f"[right]drawtext=text='Generated':x=({generated_target_width}/2-70):y=20:fontsize=24:fontcolor=white:box=1:boxcolor=black@0.5[right_text];" + f"[left_text][right_text]hstack=inputs=2[v]" + ) + + # Write the filter script to a file + with open(filter_script_path, 'w') as f: + f.write(filter_complex) + + # Build the ffmpeg command + cmd = [ + ffmpeg_exe, "-y", + "-i", source_video_path, + "-i", generated_video_path, + "-filter_complex_script", filter_script_path, + "-map", "[v]" + ] + + # 
Check if source video has audio + has_audio_cmd = [ + ffmpeg_exe, "-i", source_video_path, + "-hide_banner", "-loglevel", "error" + ] + audio_check = subprocess.run(has_audio_cmd, capture_output=True, text=True) + has_audio = "Audio:" in audio_check.stderr + + if has_audio: + cmd.extend(["-map", "0:a"]) + + # Add output options + cmd.extend([ + "-c:v", "libx264", + "-crf", "18", + "-preset", "medium" + ]) + + if has_audio: + cmd.extend(["-c:a", "aac"]) + + cmd.append(output_path) + + # Run the ffmpeg command + print(f"Running ffmpeg command: {' '.join(cmd)}") + subprocess.run(cmd, check=True, capture_output=True, text=True) + + # Clean up the filter script + if os.path.exists(filter_script_path): + os.remove(filter_script_path) + + print(f"Combined video saved to {output_path}") + return output_path + + except Exception as e: + print(f"Error combining videos: {str(e)}") + import traceback + traceback.print_exc() + return None diff --git a/modules/generators/video_f1_generator.py b/modules/generators/video_f1_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..651ec3652e21b75cb950c0061dd3dfd8c2513482 --- /dev/null +++ b/modules/generators/video_f1_generator.py @@ -0,0 +1,189 @@ +import torch +import os +import numpy as np +import math +import decord +from tqdm import tqdm +import pathlib +from PIL import Image + +from diffusers_helper.models.hunyuan_video_packed import HunyuanVideoTransformer3DModelPacked +from diffusers_helper.memory import DynamicSwapInstaller +from diffusers_helper.utils import resize_and_center_crop +from diffusers_helper.bucket_tools import find_nearest_bucket +from diffusers_helper.hunyuan import vae_encode, vae_decode +from .video_base_generator import VideoBaseModelGenerator + +class VideoF1ModelGenerator(VideoBaseModelGenerator): + """ + Model generator for the Video F1 (forward video) extension of the F1 HunyuanVideo model. + These generators accept video input instead of a single image. + """ + + def __init__(self, **kwargs): + """ + Initialize the Video F1 model generator. + """ + super().__init__(**kwargs) + self.model_name = "Video F1" + self.model_path = 'lllyasviel/FramePack_F1_I2V_HY_20250503' # Same as F1 + self.model_repo_id_for_cache = "models--lllyasviel--FramePack_F1_I2V_HY_20250503" # Same as F1 + + def get_latent_paddings(self, total_latent_sections): + """ + Get the latent paddings for the Video model. + + Args: + total_latent_sections: The total number of latent sections + + Returns: + A list of latent paddings + """ + # RT_BORG: pftq didn't even use latent paddings in the forward Video model. Keeping it for consistency. + # Any list the size of total_latent_sections should work, but may as well end with 0 as a marker for the last section. + # Similar to F1 model uses a fixed approach with just 0 for last section and 1 for others + return [1] * (total_latent_sections - 1) + [0] + + def video_f1_prepare_clean_latents_and_indices(self, latent_window_size, video_latents, history_latents, num_cleaned_frames=5): + """ + Combined method to prepare clean latents and indices for the Video model. + + Args: + Work in progress - better not to pass in latent_paddings and latent_padding. 
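The dynamic context allocation a few lines below splits the available history latents into 4x, 2x, and 1x groups. An illustrative walk-through of that arithmetic under assumed inputs, with the edge-case guards for sub-second videos omitted:

```python
# Assumed inputs: a 30-latent history and the default of 5 cleaned frames.
available_frames, num_clean_frames = 30, 5
effective_clean = min(max(0, num_clean_frames - 1), available_frames - 2)      # 4
num_2x = min(2, max(1, available_frames - effective_clean - 1))                # 2
num_4x = min(16, max(1, available_frames - effective_clean - num_2x))          # 16
total_context = min(num_4x + num_2x + effective_clean, available_frames)       # 22
print(effective_clean, num_2x, num_4x, total_context)
```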
+ + Returns: + A tuple of (clean_latent_indices, latent_indices, clean_latent_2x_indices, clean_latent_4x_indices, clean_latents, clean_latents_2x, clean_latents_4x) + """ + # Get num_cleaned_frames from job_params if available, otherwise use default value of 5 + num_clean_frames = num_cleaned_frames if num_cleaned_frames is not None else 5 + + # RT_BORG: Retaining this commented code for reference. + # start_latent = history_latents[:, :, :1] # Shape: (1, channels, 1, height//8, width//8) + start_latent = video_latents[:, :, -1:] # Shape: (1, channels, 1, height//8, width//8) + + available_frames = history_latents.shape[2] # Number of latent frames + max_pixel_frames = min(latent_window_size * 4 - 3, available_frames * 4) # Cap at available pixel frames + adjusted_latent_frames = max(1, (max_pixel_frames + 3) // 4) # Convert back to latent frames + # Adjust num_clean_frames to match original behavior: num_clean_frames=2 means 1 frame for clean_latents_1x + effective_clean_frames = max(0, num_clean_frames - 1) if num_clean_frames > 1 else 0 + effective_clean_frames = min(effective_clean_frames, available_frames - 2) if available_frames > 2 else 0 # 20250507 pftq: changed 1 to 2 for edge case for <=1 sec videos + num_2x_frames = min(2, max(1, available_frames - effective_clean_frames - 1)) if available_frames > effective_clean_frames + 1 else 0 # 20250507 pftq: subtracted 1 for edge case for <=1 sec videos + num_4x_frames = min(16, max(1, available_frames - effective_clean_frames - num_2x_frames)) if available_frames > effective_clean_frames + num_2x_frames else 0 # 20250507 pftq: Edge case for <=1 sec + + total_context_frames = num_4x_frames + num_2x_frames + effective_clean_frames + total_context_frames = min(total_context_frames, available_frames) # 20250507 pftq: Edge case for <=1 sec videos + + indices = torch.arange(0, sum([1, num_4x_frames, num_2x_frames, effective_clean_frames, adjusted_latent_frames])).unsqueeze(0) # 20250507 pftq: latent_window_size to adjusted_latent_frames for edge case for <=1 sec videos + clean_latent_indices_start, clean_latent_4x_indices, clean_latent_2x_indices, clean_latent_1x_indices, latent_indices = indices.split( + [1, num_4x_frames, num_2x_frames, effective_clean_frames, adjusted_latent_frames], dim=1 # 20250507 pftq: latent_window_size to adjusted_latent_frames for edge case for <=1 sec videos + ) + clean_latent_indices = torch.cat([clean_latent_indices_start, clean_latent_1x_indices], dim=1) + + # 20250506 pftq: Split history_latents dynamically based on available frames + fallback_frame_count = 2 # 20250507 pftq: Changed 0 to 2 Edge case for <=1 sec videos + context_frames = history_latents[:, :, -total_context_frames:, :, :] if total_context_frames > 0 else history_latents[:, :, :fallback_frame_count, :, :] + if total_context_frames > 0: + split_sizes = [num_4x_frames, num_2x_frames, effective_clean_frames] + split_sizes = [s for s in split_sizes if s > 0] # Remove zero sizes + if split_sizes: + splits = context_frames.split(split_sizes, dim=2) + split_idx = 0 + clean_latents_4x = splits[split_idx] if num_4x_frames > 0 else history_latents[:, :, :fallback_frame_count, :, :] + if clean_latents_4x.shape[2] < 2: # 20250507 pftq: edge case for <=1 sec videos + clean_latents_4x = torch.cat([clean_latents_4x, clean_latents_4x[:, :, -1:, :, :]], dim=2)[:, :, :2, :, :] + split_idx += 1 if num_4x_frames > 0 else 0 + clean_latents_2x = splits[split_idx] if num_2x_frames > 0 and split_idx < len(splits) else history_latents[:, :, :fallback_frame_count, :, :] 
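+                # If the 2x context slice ended up with a single frame (inputs of <= 1 second),
+                # the check below duplicates that frame so two temporal positions are always present,
+                # mirroring the clean_latents_4x handling above.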
+ if clean_latents_2x.shape[2] < 2: # 20250507 pftq: edge case for <=1 sec videos + clean_latents_2x = torch.cat([clean_latents_2x, clean_latents_2x[:, :, -1:, :, :]], dim=2)[:, :, :2, :, :] + split_idx += 1 if num_2x_frames > 0 else 0 + clean_latents_1x = splits[split_idx] if effective_clean_frames > 0 and split_idx < len(splits) else history_latents[:, :, :fallback_frame_count, :, :] + else: + clean_latents_4x = clean_latents_2x = clean_latents_1x = history_latents[:, :, :fallback_frame_count, :, :] + else: + clean_latents_4x = clean_latents_2x = clean_latents_1x = history_latents[:, :, :fallback_frame_count, :, :] + + clean_latents = torch.cat([start_latent.to(history_latents), clean_latents_1x], dim=2) + + return clean_latent_indices, latent_indices, clean_latent_2x_indices, clean_latent_4x_indices, clean_latents, clean_latents_2x, clean_latents_4x + + def update_history_latents(self, history_latents, generated_latents): + """ + Forward Generation: Update the history latents with the generated latents for the Video F1 model. + + Args: + history_latents: The history latents + generated_latents: The generated latents + + Returns: + The updated history latents + """ + # For Video F1 model, we append the generated latents to the back of history latents + # This matches the F1 implementation + # It generates new sections forward in time, chunk by chunk + return torch.cat([history_latents, generated_latents.to(history_latents)], dim=2) + + def get_real_history_latents(self, history_latents, total_generated_latent_frames): + """ + Get the real history latents for the backward Video model. For Video, this is the first + `total_generated_latent_frames` frames of the history latents. + + Args: + history_latents: The history latents + total_generated_latent_frames: The total number of generated latent frames + + Returns: + The real history latents + """ + # Generated frames at the back. Note the difference in "-total_generated_latent_frames:". + return history_latents[:, :, -total_generated_latent_frames:, :, :] + + def update_history_pixels(self, history_pixels, current_pixels, overlapped_frames): + """ + Update the history pixels with the current pixels for the Video model. + + Args: + history_pixels: The history pixels + current_pixels: The current pixels + overlapped_frames: The number of overlapped frames + + Returns: + The updated history pixels + """ + from diffusers_helper.utils import soft_append_bcthw + # For Video F1 model, we append the current pixels to the history pixels + # This matches the F1 model, history_pixels is first, current_pixels is second + return soft_append_bcthw(history_pixels, current_pixels, overlapped_frames) + + def get_current_pixels(self, real_history_latents, section_latent_frames, vae): + """ + Get the current pixels for the Video model. + + Args: + real_history_latents: The real history latents + section_latent_frames: The number of section latent frames + vae: The VAE model + + Returns: + The current pixels + """ + # For forward Video mode, current pixels are at the back of history, like F1. + return vae_decode(real_history_latents[:, :, -section_latent_frames:], vae).cpu() + + def format_position_description(self, total_generated_latent_frames, current_pos, original_pos, current_prompt): + """ + Format the position description for the Video model. 
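The description string below leans on the packed-latent bookkeeping used throughout these generators: each latent expands to 4 pixel frames, with a 3-frame offset for the first latent, reported at 30 FPS. A quick check of that arithmetic with an illustrative value:

```python
total_generated_latent_frames = 21                      # illustrative only
pixel_frames = max(0, total_generated_latent_frames * 4 - 3)
print(pixel_frames, f"{pixel_frames / 30:.2f}s")        # 81 2.70s
```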
+ + Args: + total_generated_latent_frames: The total number of generated latent frames + current_pos: The current position in seconds (includes input video time) + original_pos: The original position in seconds + current_prompt: The current prompt + + Returns: + The formatted position description + """ + # RT_BORG: Duplicated from F1. Is this correct? + return (f'Total generated frames: {int(max(0, total_generated_latent_frames * 4 - 3))}, ' + f'Video length: {max(0, (total_generated_latent_frames * 4 - 3) / 30):.2f} seconds (FPS-30). ' + f'Current position: {current_pos:.2f}s. ' + f'using prompt: {current_prompt[:256]}...') diff --git a/modules/generators/video_generator.py b/modules/generators/video_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..45112745c068ae53bb600257a214f33e6c71937c --- /dev/null +++ b/modules/generators/video_generator.py @@ -0,0 +1,239 @@ +import torch +import os +import numpy as np +import math +import decord +from tqdm import tqdm +import pathlib +from PIL import Image + +from diffusers_helper.models.hunyuan_video_packed import HunyuanVideoTransformer3DModelPacked +from diffusers_helper.memory import DynamicSwapInstaller +from diffusers_helper.utils import resize_and_center_crop +from diffusers_helper.bucket_tools import find_nearest_bucket +from diffusers_helper.hunyuan import vae_encode, vae_decode +from .video_base_generator import VideoBaseModelGenerator + +class VideoModelGenerator(VideoBaseModelGenerator): + """ + Generator for the Video (backward) extension of the Original HunyuanVideo model. + These generators accept video input instead of a single image. + """ + + def __init__(self, **kwargs): + """ + Initialize the Video model generator. + """ + super().__init__(**kwargs) + self.model_name = "Video" + self.model_path = 'lllyasviel/FramePackI2V_HY' # Same as Original + self.model_repo_id_for_cache = "models--lllyasviel--FramePackI2V_HY" + + def get_latent_paddings(self, total_latent_sections): + """ + Get the latent paddings for the Video model. + + Args: + total_latent_sections: The total number of latent sections + + Returns: + A list of latent paddings + """ + # Video model uses reversed latent paddings like Original + if total_latent_sections > 4: + return [3] + [2] * (total_latent_sections - 3) + [1, 0] + else: + return list(reversed(range(total_latent_sections))) + + def video_prepare_clean_latents_and_indices(self, end_frame_output_dimensions_latent, end_frame_weight, end_clip_embedding, end_of_input_video_embedding, latent_paddings, latent_padding, latent_padding_size, latent_window_size, video_latents, history_latents, num_cleaned_frames=5): + """ + Combined method to prepare clean latents and indices for the Video model. + + Args: + Work in progress - better not to pass in latent_paddings and latent_padding. + num_cleaned_frames: Number of context frames to use from the video (adherence to video) + + Returns: + A tuple of (clean_latent_indices, latent_indices, clean_latent_2x_indices, clean_latent_4x_indices, clean_latents, clean_latents_2x, clean_latents_4x) + """ + # Get num_cleaned_frames from job_params if available, otherwise use default value of 5 + num_clean_frames = num_cleaned_frames if num_cleaned_frames is not None else 5 + + + # HACK SOME STUFF IN THAT SHOULD NOT BE HERE + # Placeholders for end frame processing + # Colin, I'm only leaving them for the moment in case you want separate models for + # Video-backward and Video-backward-Endframe. 
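When an end frame is supplied, the guidance embedding used later in this method is a linear blend of the last input frame's CLIP embedding and the end frame's CLIP embedding, weighted by `end_frame_weight`. A standalone sketch of that blend; tensor shapes are made up, real ones come from the image encoder:

```python
import torch

end_of_input_video_embedding = torch.randn(1, 768)  # CLIP embedding of the input video's last frame (shape illustrative)
end_clip_embedding = torch.randn(1, 768)            # CLIP embedding of the user-supplied end frame
end_frame_weight = 0.3                              # 0.0 ignores the end frame, 1.0 follows it fully

blended = (1 - end_frame_weight) * end_of_input_video_embedding + end_frame_weight * end_clip_embedding
```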
+ # end_latent = None + # end_of_input_video_embedding = None # Placeholder for end frame's CLIP embedding. SEE: 20250507 pftq: Process end frame if provided + # end_clip_embedding = None # Placeholders for end frame processing. SEE: 20250507 pftq: Process end frame if provided + # end_frame_weight = 0.0 # Placeholders for end frame processing. SEE: 20250507 pftq: Process end frame if provided + # HACK MORE STUFF IN THAT PROBABLY SHOULD BE ARGUMENTS OR OTHWISE MADE AVAILABLE + end_of_input_video_latent = video_latents[:, :, -1:] # Last frame of the input video (produced by video_encode in the PR) + is_start_of_video = latent_padding == 0 # This refers to the start of the *generated* video part + is_end_of_video = latent_padding == latent_paddings[0] # This refers to the end of the *generated* video part (closest to input video) (better not to pass in latent_paddings[]) + # End of HACK STUFF + + # Dynamic frame allocation for context frames (clean latents) + # This determines which frames from history_latents are used as input for the transformer. + available_frames = video_latents.shape[2] if is_start_of_video else history_latents.shape[2] # Use input video frames for first segment, else previously generated history + effective_clean_frames = max(0, num_clean_frames - 1) if num_clean_frames > 1 else 1 + if is_start_of_video: + effective_clean_frames = 1 # Avoid jumpcuts if input video is too different + + clean_latent_pre_frames = effective_clean_frames + num_2x_frames = min(2, max(1, available_frames - clean_latent_pre_frames - 1)) if available_frames > clean_latent_pre_frames + 1 else 1 + num_4x_frames = min(16, max(1, available_frames - clean_latent_pre_frames - num_2x_frames)) if available_frames > clean_latent_pre_frames + num_2x_frames else 1 + total_context_frames = num_2x_frames + num_4x_frames + total_context_frames = min(total_context_frames, available_frames - clean_latent_pre_frames) + + # Prepare indices for the transformer's input (these define the *relative positions* of different frame types in the input tensor) + # The total length is the sum of various frame types: + # clean_latent_pre_frames: frames before the blank/generated section + # latent_padding_size: blank frames before the generated section (for backward generation) + # latent_window_size: the new frames to be generated + # post_frames: frames after the generated section + # num_2x_frames, num_4x_frames: frames for lower resolution context + # 20250511 pftq: Dynamically adjust post_frames based on clean_latents_post + post_frames = 1 if is_end_of_video and end_frame_output_dimensions_latent is not None else effective_clean_frames # 20250511 pftq: Single frame for end_latent, otherwise padding causes still image + indices = torch.arange(0, clean_latent_pre_frames + latent_padding_size + latent_window_size + post_frames + num_2x_frames + num_4x_frames).unsqueeze(0) + clean_latent_indices_pre, blank_indices, latent_indices, clean_latent_indices_post, clean_latent_2x_indices, clean_latent_4x_indices = indices.split( + [clean_latent_pre_frames, latent_padding_size, latent_window_size, post_frames, num_2x_frames, num_4x_frames], dim=1 + ) + clean_latent_indices = torch.cat([clean_latent_indices_pre, clean_latent_indices_post], dim=1) # Combined indices for 1x clean latents + + # Prepare the *actual latent data* for the transformer's context inputs + # These are extracted from history_latents (or video_latents for the first segment) + context_frames = history_latents[:, :, -(total_context_frames + 
clean_latent_pre_frames):-clean_latent_pre_frames, :, :] if total_context_frames > 0 else history_latents[:, :, :1, :, :] + # clean_latents_4x: 4x downsampled context frames. From history_latents (or video_latents). + # clean_latents_2x: 2x downsampled context frames. From history_latents (or video_latents). + split_sizes = [num_4x_frames, num_2x_frames] + split_sizes = [s for s in split_sizes if s > 0] + if split_sizes and context_frames.shape[2] >= sum(split_sizes): + splits = context_frames.split(split_sizes, dim=2) + split_idx = 0 + clean_latents_4x = splits[split_idx] if num_4x_frames > 0 else history_latents[:, :, :1, :, :] + split_idx += 1 if num_4x_frames > 0 else 0 + clean_latents_2x = splits[split_idx] if num_2x_frames > 0 and split_idx < len(splits) else history_latents[:, :, :1, :, :] + else: + clean_latents_4x = clean_latents_2x = history_latents[:, :, :1, :, :] + + # clean_latents_pre: Latents from the *end* of the input video (if is_start_of_video), or previously generated history. + # Its purpose is to provide a smooth transition *from* the input video. + clean_latents_pre = video_latents[:, :, -min(effective_clean_frames, video_latents.shape[2]):].to(history_latents) + + # clean_latents_post: Latents from the *beginning* of the previously generated video segments. + # Its purpose is to provide a smooth transition *to* the existing generated content. + clean_latents_post = history_latents[:, :, :min(effective_clean_frames, history_latents.shape[2]), :, :] + + # Special handling for the end frame: + # If it's the very first segment being generated (is_end_of_video in terms of generation order), + # and an end_latent was provided, force clean_latents_post to be that end_latent. + if is_end_of_video: + clean_latents_post = torch.zeros_like(end_of_input_video_latent).to(history_latents) # Initialize to zero + + # RT_BORG: end_of_input_video_embedding and end_clip_embedding shouldn't need to be checked, since they should + # always be provided if end_latent is provided. But bulletproofing before the release since test time will be short. + if end_frame_output_dimensions_latent is not None and end_of_input_video_embedding is not None and end_clip_embedding is not None: + # image_encoder_last_hidden_state: Weighted average of CLIP embedding of first input frame and end frame's CLIP embedding + # This guides the overall content to transition towards the end frame. + image_encoder_last_hidden_state = (1 - end_frame_weight) * end_of_input_video_embedding + end_clip_embedding * end_frame_weight + image_encoder_last_hidden_state = image_encoder_last_hidden_state.to(self.transformer.dtype) + + if is_end_of_video: + # For the very first generated segment, the "post" part is the end_latent itself. + clean_latents_post = end_frame_output_dimensions_latent.to(history_latents)[:, :, :1, :, :] # Ensure single frame + + # Pad clean_latents_pre/post if they have fewer frames than specified by clean_latent_pre_frames/post_frames + if clean_latents_pre.shape[2] < clean_latent_pre_frames: + clean_latents_pre = clean_latents_pre.repeat(1, 1, math.ceil(clean_latent_pre_frames / clean_latents_pre.shape[2]), 1, 1)[:,:,:clean_latent_pre_frames] + if clean_latents_post.shape[2] < post_frames: + clean_latents_post = clean_latents_post.repeat(1, 1, math.ceil(post_frames / clean_latents_post.shape[2]), 1, 1)[:,:,:post_frames] + + # clean_latents: Concatenation of pre and post clean latents. These are the 1x resolution context frames. 
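The repeat-and-trim padding a few lines above and the concatenation on the next line can be reproduced in isolation. Shapes are illustrative and the helper name is hypothetical:

```python
import math
import torch

def pad_to_length(latents: torch.Tensor, target: int) -> torch.Tensor:
    """Repeat a (B, C, T, H, W) tensor along T until it has `target` frames, then trim."""
    if latents.shape[2] >= target:
        return latents[:, :, :target]
    reps = math.ceil(target / latents.shape[2])
    return latents.repeat(1, 1, reps, 1, 1)[:, :, :target]

pre = torch.randn(1, 16, 2, 8, 8)    # frames taken from the end of the input video
post = torch.randn(1, 16, 1, 8, 8)   # frames from previously generated history (or the end frame)
clean = torch.cat([pad_to_length(pre, 4), pad_to_length(post, 4)], dim=2)
print(clean.shape)                    # torch.Size([1, 16, 8, 8, 8])
```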
+ clean_latents = torch.cat([clean_latents_pre, clean_latents_post], dim=2) + + return clean_latent_indices, latent_indices, clean_latent_2x_indices, clean_latent_4x_indices, clean_latents, clean_latents_2x, clean_latents_4x + + def update_history_latents(self, history_latents, generated_latents): + """ + Backward Generation: Update the history latents with the generated latents for the Video model. + + Args: + history_latents: The history latents + generated_latents: The generated latents + + Returns: + The updated history latents + """ + # For Video model, we prepend the generated latents to the front of history latents + # This matches the original implementation in video-example.py + # It generates new sections backwards in time, chunk by chunk + return torch.cat([generated_latents.to(history_latents), history_latents], dim=2) + + def get_real_history_latents(self, history_latents, total_generated_latent_frames): + """ + Get the real history latents for the backward Video model. For Video, this is the first + `total_generated_latent_frames` frames of the history latents. + + Args: + history_latents: The history latents + total_generated_latent_frames: The total number of generated latent frames + + Returns: + The real history latents + """ + # Generated frames at the front. + return history_latents[:, :, :total_generated_latent_frames, :, :] + + def update_history_pixels(self, history_pixels, current_pixels, overlapped_frames): + """ + Update the history pixels with the current pixels for the Video model. + + Args: + history_pixels: The history pixels + current_pixels: The current pixels + overlapped_frames: The number of overlapped frames + + Returns: + The updated history pixels + """ + from diffusers_helper.utils import soft_append_bcthw + # For Video model, we prepend the current pixels to the history pixels + # This matches the original implementation in video-example.py + return soft_append_bcthw(current_pixels, history_pixels, overlapped_frames) + + def get_current_pixels(self, real_history_latents, section_latent_frames, vae): + """ + Get the current pixels for the Video model. + + Args: + real_history_latents: The real history latents + section_latent_frames: The number of section latent frames + vae: The VAE model + + Returns: + The current pixels + """ + # For backward Video mode, current pixels are at the front of history. + return vae_decode(real_history_latents[:, :, :section_latent_frames], vae).cpu() + + def format_position_description(self, total_generated_latent_frames, current_pos, original_pos, current_prompt): + """ + Format the position description for the Video model. + + Args: + total_generated_latent_frames: The total number of generated latent frames + current_pos: The current position in seconds (includes input video time) + original_pos: The original position in seconds + current_prompt: The current prompt + + Returns: + The formatted position description + """ + # For Video model, current_pos already includes the input video time + # We just need to display the total generated frames and the current position + return (f'Total generated frames: {int(max(0, total_generated_latent_frames * 4 - 3))}, ' + f'Generated video length: {max(0, (total_generated_latent_frames * 4 - 3) / 30):.2f} seconds (FPS-30). ' + f'Current position: {current_pos:.2f}s (remaining: {original_pos:.2f}s). 
' + f'using prompt: {current_prompt[:256]}...') diff --git a/modules/grid_builder.py b/modules/grid_builder.py new file mode 100644 index 0000000000000000000000000000000000000000..31e4e5a3ef8538b23a47fa343cb4ad7816315b87 --- /dev/null +++ b/modules/grid_builder.py @@ -0,0 +1,78 @@ +import os +import cv2 +import numpy as np +import math +from modules.video_queue import JobStatus + +def assemble_grid_video(grid_job, child_jobs, settings): + """ + Assembles a grid video from the results of child jobs. + """ + print(f"Starting grid assembly for job {grid_job.id}") + + output_dir = settings.get("output_dir", "outputs") + os.makedirs(output_dir, exist_ok=True) + + video_paths = [child.result for child in child_jobs if child.status == JobStatus.COMPLETED and child.result and os.path.exists(child.result)] + + if not video_paths: + print(f"No valid video paths found for grid job {grid_job.id}") + return None + + print(f"Found {len(video_paths)} videos for grid assembly.") + + # Determine grid size (e.g., 2x2, 3x3) + num_videos = len(video_paths) + grid_size = math.ceil(math.sqrt(num_videos)) + + # Get video properties from the first video + try: + cap = cv2.VideoCapture(video_paths[0]) + if not cap.isOpened(): + raise IOError(f"Cannot open video file: {video_paths[0]}") + width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + fps = cap.get(cv2.CAP_PROP_FPS) + cap.release() + except Exception as e: + print(f"Error getting video properties from {video_paths[0]}: {e}") + return None + + output_filename = os.path.join(output_dir, f"grid_{grid_job.id}.mp4") + fourcc = cv2.VideoWriter_fourcc(*'mp4v') + video_writer = cv2.VideoWriter(output_filename, fourcc, fps, (width * grid_size, height * grid_size)) + + caps = [cv2.VideoCapture(p) for p in video_paths] + + while True: + frames = [] + all_frames_read = True + for cap in caps: + ret, frame = cap.read() + if ret: + frames.append(frame) + else: + # If one video ends, stop processing + all_frames_read = False + break + + if not all_frames_read or not frames: + break + + # Create a blank canvas for the grid + grid_frame = np.zeros((height * grid_size, width * grid_size, 3), dtype=np.uint8) + + # Place frames into the grid + for i, frame in enumerate(frames): + row = i // grid_size + col = i % grid_size + grid_frame[row*height:(row+1)*height, col*width:(col+1)*width] = frame + + video_writer.write(grid_frame) + + for cap in caps: + cap.release() + video_writer.release() + + print(f"Grid video saved to {output_filename}") + return output_filename diff --git a/modules/interface.py b/modules/interface.py new file mode 100644 index 0000000000000000000000000000000000000000..2fe30194575a71e98e0d449e36d7a992d9c08e52 --- /dev/null +++ b/modules/interface.py @@ -0,0 +1,2472 @@ +import gradio as gr +import time +import datetime +import random +import json +import os +import shutil +from typing import List, Dict, Any, Optional +from PIL import Image, ImageDraw, ImageFont +import numpy as np +import base64 +import io +import functools + +from modules.version import APP_VERSION, APP_VERSION_DISPLAY + +import subprocess +import itertools +import re +from collections import defaultdict +import imageio +import imageio.plugins.ffmpeg +import ffmpeg +from diffusers_helper.utils import generate_timestamp + +from modules.video_queue import JobStatus, Job, JobType +from modules.prompt_handler import get_section_boundaries, get_quick_prompts, parse_timestamped_prompt +from modules.llm_enhancer import enhance_prompt +from 
modules.llm_captioner import caption_image +from diffusers_helper.gradio.progress_bar import make_progress_bar_css, make_progress_bar_html +from diffusers_helper.bucket_tools import find_nearest_bucket +from modules.pipelines.metadata_utils import create_metadata +from modules import DUMMY_LORA_NAME # Import the constant + +from modules.toolbox_app import tb_processor +from modules.toolbox_app import tb_create_video_toolbox_ui, tb_get_formatted_toolbar_stats +from modules.xy_plot_ui import create_xy_plot_ui, xy_plot_process + +# Define the dummy LoRA name as a constant + +def create_interface( + process_fn, + monitor_fn, + end_process_fn, + update_queue_status_fn, + load_lora_file_fn, + job_queue, + settings, + default_prompt: str = '[1s: The person waves hello] [3s: The person jumps up and down] [5s: The person does a dance]', + lora_names: list = [], + lora_values: list = [] +): + """ + Create the Gradio interface for the video generation application + + Args: + process_fn: Function to process a new job + monitor_fn: Function to monitor an existing job + end_process_fn: Function to cancel the current job + update_queue_status_fn: Function to update the queue status display + default_prompt: Default prompt text + lora_names: List of loaded LoRA names + + Returns: + Gradio Blocks interface + """ + def is_video_model(model_type_value): + return model_type_value in ["Video", "Video with Endframe", "Video F1"] + + # Add near the top of create_interface function, after the initial setup + def get_latents_display_top(): + """Get current latents display preference - centralized access point""" + return settings.get("latents_display_top", False) + + def create_latents_layout_update(): + """Create a standardized layout update based on current setting""" + display_top = get_latents_display_top() + if display_top: + return ( + gr.update(visible=True), # top_preview_row + gr.update(visible=False, value=None) # preview_image (right column) + ) + else: + return ( + gr.update(visible=False), # top_preview_row + gr.update(visible=True) # preview_image (right column) + ) + + + + # Get section boundaries and quick prompts + section_boundaries = get_section_boundaries() + quick_prompts = get_quick_prompts() + + # --- Function to update queue stats (Moved earlier to resolve UnboundLocalError) --- + def update_stats(*args): # Accept any arguments and ignore them + # Get queue status data + queue_status_data = update_queue_status_fn() + + # Get queue statistics for the toolbar display + jobs = job_queue.get_all_jobs() + + # Count jobs by status + pending_count = 0 + running_count = 0 + completed_count = 0 + + for job in jobs: + if hasattr(job, 'status'): + status = str(job.status) + if status == "JobStatus.PENDING": + pending_count += 1 + elif status == "JobStatus.RUNNING": + running_count += 1 + elif status == "JobStatus.COMPLETED": + completed_count += 1 + + # Format the queue stats display text + queue_stats_text = f"
Queue: {pending_count} | Running: {running_count} | Completed: {completed_count}
" + + return queue_status_data, queue_stats_text + + # --- Preset System Functions --- + PRESET_FILE = os.path.join(".framepack", "generation_presets.json") + + def load_presets(model_type): + if not os.path.exists(PRESET_FILE): + return [] + with open(PRESET_FILE, 'r') as f: + data = json.load(f) + return list(data.get(model_type, {}).keys()) + + # Create the interface + css = make_progress_bar_css() + css += """ + + .short-import-box, .short-import-box > div { + min-height: 40px !important; + height: 40px !important; + } + /* Image container styling - more aggressive approach */ + .contain-image, .contain-image > div, .contain-image > div > img { + object-fit: contain !important; + } + + #non-mirrored-video { + transform: scaleX(-1) !important; + } + + /* Target all images in the contain-image class and its children */ + .contain-image img, + .contain-image > div > img, + .contain-image * img { + object-fit: contain !important; + width: 100% !important; + height: 60vh !important; + max-height: 100% !important; + max-width: 100% !important; + } + + /* Additional selectors to override Gradio defaults */ + .gradio-container img, + .gradio-container .svelte-1b5oq5x, + .gradio-container [data-testid="image"] img { + object-fit: contain !important; + } + + /* Toolbar styling */ + #fixed-toolbar { + position: fixed; + top: 0; + left: 0; + width: 100vw; + z-index: 1000; + background: #333; + color: #fff; + padding: 0px 10px; /* Reduced top/bottom padding */ + display: flex; + align-items: center; + gap: 8px; + box-shadow: 0 2px 8px rgba(0,0,0,0.1); + } + + /* Responsive toolbar title */ + .toolbar-title { + font-size: 1.4rem; + margin: 0; + color: white; + white-space: nowrap; + overflow: hidden; + text-overflow: ellipsis; + } + + /* Toolbar Patreon link */ + .toolbar-patreon { + margin: 0 0 0 20px; + color: white; + font-size: 0.9rem; + white-space: nowrap; + display: inline-block; + } + .toolbar-patreon a { + color: white; + text-decoration: none; + } + .toolbar-patreon a:hover { + text-decoration: underline; + } + + /* Toolbar Version number */ + .toolbar-version { + margin: 0 15px; /* Space around version */ + color: white; + font-size: 0.8rem; + white-space: nowrap; + display: inline-block; + } + + /* Responsive design for screens */ + @media (max-width: 1147px) { + .toolbar-patreon, .toolbar-version { /* Hide both on smaller screens */ + display: none; + } + .footer-patreon, .footer-version { /* Show both in footer on smaller screens */ + display: inline-block !important; /* Ensure they are shown */ + } + #fixed-toolbar { + gap: 4px !important; /* Reduce gap for screens <= 1024px */ + } + #fixed-toolbar > div:first-child { /* Target the first gr.Column (Title) */ + min-width: fit-content !important; /* Override Python-set min-width */ + flex-shrink: 0 !important; /* Prevent title column from shrinking too much */ + } + } + + @media (min-width: 1148px) { + .footer-patreon, .footer-version { /* Hide both in footer on larger screens */ + display: none !important; + } + } + + @media (max-width: 768px) { + .toolbar-title { + font-size: 1.1rem; + max-width: 150px; + } + #fixed-toolbar { + padding: 3px 6px; + gap: 4px; + } + .toolbar-text { + font-size: 0.75rem; + } + } + + @media (max-width: 510px) { + #toolbar-ram-col, #toolbar-vram-col, #toolbar-gpu-col { + display: none !important; + } + } + + @media (max-width: 480px) { + .toolbar-title { + font-size: 1rem; + max-width: 120px; + } + #fixed-toolbar { + padding: 2px 4px; + gap: 2px; + } + .toolbar-text { + font-size: 0.7rem; + } + } + + /* 
Button styling */ + #toolbar-add-to-queue-btn button { + font-size: 14px !important; + padding: 4px 16px !important; + height: 32px !important; + min-width: 80px !important; + } + .narrow-button { + min-width: 40px !important; + width: 40px !important; + padding: 0 !important; + margin: 0 !important; + } + .gr-button-primary { + color: white; + } + + /* Layout adjustments */ + body, .gradio-container { + padding-top: 42px !important; /* Adjusted for new toolbar height (36px - 10px) */ + } + + @media (max-width: 848px) { + body, .gradio-container { + padding-top: 48px !important; + } + } + + @media (max-width: 768px) { + body, .gradio-container { + padding-top: 22px !important; /* Adjusted for new toolbar height (32px - 10px) */ + } + } + + @media (max-width: 480px) { + body, .gradio-container { + padding-top: 18px !important; /* Adjusted for new toolbar height (28px - 10px) */ + } + } + + /* hide the gr.Video source selection bar for tb_input_video_component */ + #toolbox-video-player .source-selection { + display: none !important; + } + /* control sizing for gr.Video components */ + .video-size video { + max-height: 60vh; + min-height: 300px !important; + object-fit: contain; + } + /* NEW: Closes the gap between input tabs and the pipeline accordion below them */ + #pipeline-controls-wrapper { + margin-top: -15px !important; /* Adjust this value to get the perfect "snug" fit */ + } + /* --- NEW CSS RULE FOR GALLERY SCROLLING --- */ + #gallery-scroll-wrapper { + max-height: 600px; /* Set your desired fixed height */ + overflow-y: auto; /* Add a scrollbar only when needed */ + } + #toolbox-start-pipeline-btn { + margin-top: -14px !important; /* Adjust this value to get the perfect alignment */ + } + + .control-group { + border-top: 1px solid #ccc; + border-bottom: 1px solid #ccc; + margin: 12px 0; + } + """ + + # Get the theme from settings + current_theme = settings.get("gradio_theme", "default") # Use default if not found + block = gr.Blocks(css=css, title="FramePack Studio", theme=current_theme).queue() + + with block: + with gr.Row(elem_id="fixed-toolbar"): + with gr.Column(scale=0, min_width=400): # Title/Version/Patreon + gr.HTML(f""" +
+                <h1 class="toolbar-title">FP Studio</h1>
+                <p class="toolbar-version">{APP_VERSION_DISPLAY}</p>
+                <p class="toolbar-patreon">Support on Patreon</p>
+ """) + # REMOVED: refresh_stats_btn - Toolbar refresh button is no longer needed + # with gr.Column(scale=0, min_width=40): + # refresh_stats_btn = gr.Button("⟳", elem_id="refresh-stats-btn", elem_classes="narrow-button") + with gr.Column(scale=1, min_width=180): # Queue Stats + queue_stats_display = gr.Markdown("
Queue: 0 | Running: 0 | Completed: 0
") + + # --- System Stats Display - Single gr.Textbox per stat --- + with gr.Column(scale=0, min_width=173, elem_id="toolbar-ram-col"): # RAM Column + toolbar_ram_display_component = gr.Textbox( + value="RAM: N/A", + interactive=False, + lines=1, + max_lines=1, + show_label=False, + container=False, + elem_id="toolbar-ram-stat", + elem_classes="toolbar-stat-textbox" + ) + with gr.Column(scale=0, min_width=138, elem_id="toolbar-vram-col"): # VRAM Column + toolbar_vram_display_component = gr.Textbox( + value="VRAM: N/A", + interactive=False, + lines=1, + max_lines=1, + show_label=False, + container=False, + elem_id="toolbar-vram-stat", + elem_classes="toolbar-stat-textbox" + # Visibility controlled by tb_get_formatted_toolbar_stats + ) + with gr.Column(scale=0, min_width=130, elem_id="toolbar-gpu-col"): # GPU Column + toolbar_gpu_display_component = gr.Textbox( + value="GPU: N/A", + interactive=False, + lines=1, + max_lines=1, + show_label=False, + container=False, + elem_id="toolbar-gpu-stat", + elem_classes="toolbar-stat-textbox" + # Visibility controlled by tb_get_formatted_toolbar_stats + ) + # --- End of System Stats Display --- + + # Removed old version_display column + # --- End of Toolbar --- + + # Essential to capture main_tabs_component for later use by send_to_toolbox_btn + with gr.Tabs(elem_id="main_tabs") as main_tabs_component: + with gr.Tab("Generate", id="generate_tab"): + # NEW: Top preview area for latents display + with gr.Row(visible=get_latents_display_top()) as top_preview_row: + top_preview_image = gr.Image( + label="Next Latents (Top Display)", + height=150, + visible=True, + type="numpy", + interactive=False, + elem_classes="contain-image", + image_mode="RGB" + ) + + with gr.Row(): + with gr.Column(scale=2): + model_type = gr.Radio( + choices=[("Original", "Original"), ("Original with Endframe", "Original with Endframe"), ("F1", "F1"), ("Video", "Video"), ("Video with Endframe", "Video with Endframe"), ("Video F1", "Video F1")], + value="Original", + label="Generation Type" + ) + with gr.Accordion("Original Presets", open=False, visible=True) as preset_accordion: + with gr.Row(): + preset_dropdown = gr.Dropdown(label="Select Preset", choices=load_presets("Original"), interactive=True, scale=2) + delete_preset_button = gr.Button("🗑️ Delete", variant="stop", scale=1) + with gr.Row(): + preset_name_textbox = gr.Textbox(label="Preset Name", placeholder="Enter a name for your preset", scale=2) + save_preset_button = gr.Button("💾 Save", variant="primary", scale=1) + with gr.Row(visible=False) as confirm_delete_row: + gr.Markdown("### Are you sure you want to delete this preset?") + confirm_delete_yes_btn = gr.Button("🗑️ Yes, Delete", variant="stop") + confirm_delete_no_btn = gr.Button("↩️ No, Go Back") + with gr.Accordion("Basic Parameters", open=True, visible=True) as basic_parameters_accordion: + with gr.Group(): + total_second_length = gr.Slider(label="Video Length (Seconds)", minimum=1, maximum=120, value=6, step=0.1) + with gr.Row("Resolution"): + resolutionW = gr.Slider( + label="Width", minimum=128, maximum=768, value=640, step=32, + info="Nearest valid width will be used." + ) + resolutionH = gr.Slider( + label="Height", minimum=128, maximum=768, value=640, step=32, + info="Nearest valid height will be used." + ) + resolution_text = gr.Markdown(value="
Selected bucket for resolution: 640 x 640
", label="", show_label=False) + + # --- START OF REFACTORED XY PLOT SECTION --- + xy_plot_components = create_xy_plot_ui( + lora_names=lora_names, + default_prompt=default_prompt, + DUMMY_LORA_NAME=DUMMY_LORA_NAME, + ) + xy_group = xy_plot_components["group"] + xy_plot_status = xy_plot_components["status"] + xy_plot_output = xy_plot_components["output"] + # --- END OF REFACTORED XY PLOT SECTION --- + + with gr.Group(visible=True) as standard_generation_group: # Default visibility: True because "Original" model is not "Video" + with gr.Group(visible=True) as image_input_group: # This group now only contains the start frame image + with gr.Row(): + with gr.Column(scale=1): # Start Frame Image Column + input_image = gr.Image( + sources='upload', + type="numpy", + label="Start Frame (optional)", + elem_classes="contain-image", + image_mode="RGB", + show_download_button=False, + show_label=True, # Keep label for clarity + container=True + ) + + with gr.Group(visible=False) as video_input_group: + input_video = gr.Video( + sources='upload', + label="Video Input", + height=420, + show_label=True + ) + combine_with_source = gr.Checkbox( + label="Combine with source video", + value=True, + info="If checked, the source video will be combined with the generated video", + interactive=True + ) + num_cleaned_frames = gr.Slider(label="Number of Context Frames (Adherence to Video)", minimum=2, maximum=10, value=5, step=1, interactive=True, info="Expensive. Retain more video details. Reduce if memory issues or motion too restricted (jumpcut, ignoring prompt, still).") + + + # End Frame Image Input + # Initial visibility is False, controlled by update_input_visibility + with gr.Column(scale=1, visible=False) as end_frame_group_original: + end_frame_image_original = gr.Image( + sources='upload', + type="numpy", + label="End Frame (Optional)", + elem_classes="contain-image", + image_mode="RGB", + show_download_button=False, + show_label=True, + container=True + ) + + # End Frame Influence slider + # Initial visibility is False, controlled by update_input_visibility + with gr.Group(visible=False) as end_frame_slider_group: + end_frame_strength_original = gr.Slider( + label="End Frame Influence", + minimum=0.05, + maximum=1.0, + value=1.0, + step=0.05, + info="Controls how strongly the end frame guides the generation. 1.0 is full influence." 
+ ) + + + + with gr.Row(): + prompt = gr.Textbox(label="Prompt", value=default_prompt, scale=10) + with gr.Row(): + enhance_prompt_btn = gr.Button("✨ Enhance", scale=1) + caption_btn = gr.Button("✨ Caption", scale=1) + + with gr.Accordion("Prompt Parameters", open=False): + n_prompt = gr.Textbox(label="Negative Prompt", value="", visible=True) # Make visible for both models + + blend_sections = gr.Slider( + minimum=0, maximum=10, value=4, step=1, + label="Number of sections to blend between prompts" + ) + with gr.Accordion("Batch Input", open=False): + batch_input_images = gr.File( + label="Batch Images (Upload one or more)", + file_count="multiple", + file_types=["image"], + type="filepath" + ) + batch_input_gallery = gr.Gallery( + label="Selected Batch Images", + visible=False, + columns=5, + object_fit="contain", + height="auto" + ) + add_batch_to_queue_btn = gr.Button("🚀 Add Batch to Queue", variant="primary") + with gr.Accordion("Generation Parameters", open=True): + with gr.Row(): + steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=25, step=1) + def on_input_image_change(img): + if img is not None: + return gr.update(info="Nearest valid bucket size will be used. Height will be adjusted automatically."), gr.update(visible=False) + else: + return gr.update(info="Nearest valid width will be used."), gr.update(visible=True) + input_image.change(fn=on_input_image_change, inputs=[input_image], outputs=[resolutionW, resolutionH]) + def on_resolution_change(img, resolutionW, resolutionH): + out_bucket_resH, out_bucket_resW = [640, 640] + if img is not None: + H, W, _ = img.shape + out_bucket_resH, out_bucket_resW = find_nearest_bucket(H, W, resolution=resolutionW) + else: + out_bucket_resH, out_bucket_resW = find_nearest_bucket(resolutionH, resolutionW, (resolutionW+resolutionH)/2) # if resolutionW > resolutionH else resolutionH + return gr.update(value=f"
Selected bucket for resolution: {out_bucket_resW} x {out_bucket_resH}
") + resolutionW.change(fn=on_resolution_change, inputs=[input_image, resolutionW, resolutionH], outputs=[resolution_text], show_progress="hidden") + resolutionH.change(fn=on_resolution_change, inputs=[input_image, resolutionW, resolutionH], outputs=[resolution_text], show_progress="hidden") + + with gr.Row(): + seed = gr.Number(label="Seed", value=2500, precision=0) + randomize_seed = gr.Checkbox(label="Randomize", value=True, info="Generate a new random seed for each job") + with gr.Accordion("LoRAs", open=False): + with gr.Row(): + lora_selector = gr.Dropdown( + choices=lora_names, + label="Select LoRAs to Load", + multiselect=True, + value=[], + info="Select one or more LoRAs to use for this job" + ) + lora_names_states = gr.State(lora_names) + lora_sliders = {} + for lora in lora_names: + lora_sliders[lora] = gr.Slider( + minimum=0.0, maximum=2.0, value=1.0, step=0.01, + label=f"{lora} Weight", visible=False, interactive=True + ) + with gr.Accordion("Latent Image Options", open=False): + latent_type = gr.Dropdown( + ["Noise", "White", "Black", "Green Screen"], label="Latent Image", value="Noise", info="Used as a starting point if no image is provided" + ) + with gr.Accordion("Advanced Parameters", open=False): + gr.Markdown("#### Motion Model") + gr.Markdown("Settings for precise control of the motion model") + + with gr.Group(elem_classes="control-group"): + latent_window_size = gr.Slider(label="Latent Window Size", minimum=1, maximum=33, value=9, step=1, info='Change at your own risk, very experimental') # Should not change + gs = gr.Slider(label="Distilled CFG Scale", minimum=1.0, maximum=32.0, value=10.0, step=0.5) + + gr.Markdown("#### CFG Scale") + gr.Markdown("Much better prompt following. Warning: Modifying these values from their defaults will almost double generation time. ⚠️") + + with gr.Group(elem_classes="control-group"): + cfg = gr.Slider(label="CFG Scale", minimum=1.0, maximum=3.0, value=1.0, step=0.1) + rs = gr.Slider(label="CFG Re-Scale", minimum=0.0, maximum=1.0, value=0.0, step=0.05) + + gr.Markdown("#### Cache Options") + gr.Markdown("Using a cache will speed up generation. May affect quality, fine or even coarse details, and may change or inhibit motion. You can choose at most one.") + + with gr.Group(elem_classes="control-group"): + with gr.Row(): + cache_type = gr.Radio(["MagCache", "TeaCache", "None"], value='MagCache', label="Caching strategy", info="Which cache implementation to use, if any") + + with gr.Row(): # MagCache now first + magcache_threshold = gr.Slider(label="MagCache Threshold", minimum=0.01, maximum=1.0, step=0.01, value=0.1, visible=True, info='[⬇️ **Faster**] Error tolerance. 
Lower = more estimated steps') + magcache_max_consecutive_skips = gr.Slider(label="MagCache Max Consecutive Skips", minimum=1, maximum=5, step=1, value=2, visible=True, info='[⬆️ **Faster**] Allow multiple estimated steps in a row') + magcache_retention_ratio = gr.Slider(label="MagCache Retention Ratio", minimum=0.0, maximum=1.0, step=0.01, value=0.25, visible=True, info='[⬇️ **Faster**] Disallow estimation in critical early steps') + + with gr.Row(): + teacache_num_steps = gr.Slider(label="TeaCache steps", minimum=1, maximum=50, step=1, value=25, visible=False, info='How many intermediate sections to keep in the cache') + teacache_rel_l1_thresh = gr.Slider(label="TeaCache rel_l1_thresh", minimum=0.01, maximum=1.0, step=0.01, value=0.15, visible=False, info='[⬇️ **Faster**] Relative L1 Threshold') + + def update_cache_type(cache_type: str): + enable_magcache = False + enable_teacache = False + + if cache_type == 'MagCache': + enable_magcache = True + elif cache_type == 'TeaCache': + enable_teacache = True + + magcache_threshold_update = gr.update(visible=enable_magcache) + magcache_max_consecutive_skips_update = gr.update(visible=enable_magcache) + magcache_retention_ratio_update = gr.update(visible=enable_magcache) + + teacache_num_steps_update = gr.update(visible=enable_teacache) + teacache_rel_l1_thresh_update = gr.update(visible=enable_teacache) + + return [ + magcache_threshold_update, + magcache_max_consecutive_skips_update, + magcache_retention_ratio_update, + teacache_num_steps_update, + teacache_rel_l1_thresh_update + ] + + + cache_type.change(fn=update_cache_type, inputs=cache_type, outputs=[ + magcache_threshold, + magcache_max_consecutive_skips, + magcache_retention_ratio, + teacache_num_steps, + teacache_rel_l1_thresh + ]) + + with gr.Row("Metadata"): + json_upload = gr.File( + label="Upload Metadata JSON (optional)", + file_types=[".json"], + type="filepath", + height=140, + ) + + with gr.Column(): + preview_image = gr.Image( + label="Next Latents", + height=150, + visible=not get_latents_display_top(), + type="numpy", + interactive=False, + elem_classes="contain-image", + image_mode="RGB" + ) + result_video = gr.Video(label="Finished Frames", autoplay=True, show_share_button=False, height=256, loop=True) + progress_desc = gr.Markdown('', elem_classes='no-generating-animation') + progress_bar = gr.HTML('', elem_classes='no-generating-animation') + with gr.Row(): + current_job_id = gr.Textbox(label="Current Job ID", value="", visible=True, interactive=True) + start_button = gr.Button(value="🚀 Add to Queue", variant="primary", elem_id="toolbar-add-to-queue-btn") + xy_plot_process_btn = gr.Button("🚀 Submit XY Plot", visible=False) + video_input_required_message = gr.Markdown( + "
Input video required
", visible=False + ) + end_button = gr.Button(value="❌ Cancel Current Job", interactive=True, visible=False) + + + + with gr.Tab("Queue"): + with gr.Row(): + with gr.Column(): + with gr.Row() as queue_controls_row: + refresh_button = gr.Button("🔄 Refresh Queue") + load_queue_button = gr.Button("▶️ Resume Queue") + queue_export_button = gr.Button("📦 Export Queue") + clear_complete_button = gr.Button("🧹 Clear Completed Jobs", variant="secondary") + clear_queue_button = gr.Button("❌ Cancel Queued Jobs", variant="stop") + with gr.Row(): + import_queue_file = gr.File( + label="Import Queue", + file_types=[".json", ".zip"], + type="filepath", + visible=True, + elem_classes="short-import-box" + ) + + with gr.Row(visible=False) as confirm_cancel_row: + gr.Markdown("### Are you sure you want to cancel all pending jobs?") + confirm_cancel_yes_btn = gr.Button("❌ Yes, Cancel All", variant="stop") + confirm_cancel_no_btn = gr.Button("↩️ No, Go Back") + + with gr.Row(): + queue_status = gr.DataFrame( + headers=["Job ID", "Type", "Status", "Created", "Started", "Completed", "Elapsed", "Preview"], + datatype=["str", "str", "str", "str", "str", "str", "str", "html"], + label="Job Queue" + ) + + with gr.Accordion("Queue Documentation", open=False): + gr.Markdown(""" + ## Queue Tab Guide + + This tab is for managing your generation jobs. + + - **Refresh Queue**: Update the job list. + - **Cancel Queue**: Stop all pending jobs. + - **Clear Complete**: Remove finished, failed, or cancelled jobs from the list. + - **Load Queue**: Load jobs from the default `queue.json`. + - **Export Queue**: Save the current job list and its images to a zip file. + - **Import Queue**: Load a queue from a `.json` or `.zip` file. + """) + + # --- Event Handlers for Queue Tab --- + + # Function to clear all jobs in the queue + def clear_all_jobs(): + try: + cancelled_count = job_queue.clear_queue() + print(f"Cleared {cancelled_count} jobs from the queue") + return update_stats() + except Exception as e: + import traceback + print(f"Error in clear_all_jobs: {e}") + traceback.print_exc() + return [], "" + + # Function to clear completed and cancelled jobs + def clear_completed_jobs(): + try: + removed_count = job_queue.clear_completed_jobs() + print(f"Removed {removed_count} completed/cancelled jobs from the queue") + return update_stats() + except Exception as e: + import traceback + print(f"Error in clear_completed_jobs: {e}") + traceback.print_exc() + return [], "" + + # Function to load queue from queue.json + def load_queue_from_json(): + try: + loaded_count = job_queue.load_queue_from_json() + print(f"Loaded {loaded_count} jobs from queue.json") + return update_stats() + except Exception as e: + import traceback + print(f"Error loading queue from JSON: {e}") + traceback.print_exc() + return [], "" + + # Function to import queue from a custom JSON file + def import_queue_from_file(file_path): + if not file_path: + return update_stats() + try: + loaded_count = job_queue.load_queue_from_json(file_path) + print(f"Loaded {loaded_count} jobs from {file_path}") + return update_stats() + except Exception as e: + import traceback + print(f"Error importing queue from file: {e}") + traceback.print_exc() + return [], "" + + # Function to export queue to a zip file + def export_queue_to_zip(): + try: + zip_path = job_queue.export_queue_to_zip() + if zip_path and os.path.exists(zip_path): + print(f"Queue exported to {zip_path}") + else: + print("Failed to export queue to zip") + return update_stats() + except Exception as e: + import 
traceback + print(f"Error exporting queue to zip: {e}") + traceback.print_exc() + return [], "" + + # --- Connect Buttons --- + refresh_button.click(fn=update_stats, inputs=[], outputs=[queue_status, queue_stats_display]) + + # Confirmation logic for Cancel Queue + def show_cancel_confirmation(): + return gr.update(visible=False), gr.update(visible=True) + + def hide_cancel_confirmation(): + return gr.update(visible=True), gr.update(visible=False) + + def confirmed_clear_all_jobs(): + qs_data, qs_text = clear_all_jobs() + return qs_data, qs_text, gr.update(visible=True), gr.update(visible=False) + + clear_queue_button.click(fn=show_cancel_confirmation, inputs=None, outputs=[queue_controls_row, confirm_cancel_row]) + confirm_cancel_no_btn.click(fn=hide_cancel_confirmation, inputs=None, outputs=[queue_controls_row, confirm_cancel_row]) + confirm_cancel_yes_btn.click(fn=confirmed_clear_all_jobs, inputs=None, outputs=[queue_status, queue_stats_display, queue_controls_row, confirm_cancel_row]) + + clear_complete_button.click(fn=clear_completed_jobs, inputs=[], outputs=[queue_status, queue_stats_display]) + queue_export_button.click(fn=export_queue_to_zip, inputs=[], outputs=[queue_status, queue_stats_display]) + + # Create a container for thumbnails (kept for potential future use, though not displayed in DataFrame) + with gr.Row(): + thumbnail_container = gr.Column() + thumbnail_container.elem_classes = ["thumbnail-container"] + + # Add CSS for thumbnails + + with gr.Tab("Outputs", id="outputs_tab"): # Ensure 'id' is present for tab switching + outputDirectory_video = settings.get("output_dir", settings.default_settings['output_dir']) + outputDirectory_metadata = settings.get("metadata_dir", settings.default_settings['metadata_dir']) + def get_gallery_items(): + items = [] + for f in os.listdir(outputDirectory_metadata): + if f.endswith(".png"): + prefix = os.path.splitext(f)[0] + latest_video = get_latest_video_version(prefix) + if latest_video: + video_path = os.path.join(outputDirectory_video, latest_video) + mtime = os.path.getmtime(video_path) + preview_path = os.path.join(outputDirectory_metadata, f) + items.append((preview_path, prefix, mtime)) + items.sort(key=lambda x: x[2], reverse=True) + return [(i[0], i[1]) for i in items] + def get_latest_video_version(prefix): + max_number = -1 + selected_file = None + for f in os.listdir(outputDirectory_video): + if f.startswith(prefix + "_") and f.endswith(".mp4"): + # Skip files that include "combined" in their name + if "combined" in f: + continue + try: + num = int(f.replace(prefix + "_", '').replace(".mp4", '')) + if num > max_number: + max_number = num + selected_file = f + except ValueError: + # Ignore files that do not have a valid number in their name + continue + return selected_file + # load_video_and_info_from_prefix now also returns button visibility + def load_video_and_info_from_prefix(prefix): + video_file = get_latest_video_version(prefix) + json_path = os.path.join(outputDirectory_metadata, prefix) + ".json" + + if not video_file or not os.path.exists(os.path.join(outputDirectory_video, video_file)) or not os.path.exists(json_path): + # If video or info not found, button should be hidden + return None, "Video or JSON not found.", gr.update(visible=False) + + video_path = os.path.join(outputDirectory_video, video_file) + info_content = {"description": "no info"} + if os.path.exists(json_path): + with open(json_path, "r", encoding="utf-8") as f: + info_content = json.load(f) + # If video and info found, button should be visible + 
return video_path, json.dumps(info_content, indent=2, ensure_ascii=False), gr.update(visible=True) + + gallery_items_state = gr.State(get_gallery_items()) + selected_original_video_path_state = gr.State(None) # Holds the ORIGINAL, UNPROCESSED path + with gr.Row(): + with gr.Column(scale=2): + thumbs = gr.Gallery( + # value=[i[0] for i in get_gallery_items()], + columns=[4], + allow_preview=False, + object_fit="cover", + height="auto" + ) + refresh_button = gr.Button("🔄 Update Gallery") + with gr.Column(scale=5): + video_out = gr.Video(sources=[], autoplay=True, loop=True, visible=False) + with gr.Column(scale=1): + info_out = gr.Textbox(label="Generation info", visible=False) + send_to_toolbox_btn = gr.Button("➡️ Send to Post-processing", visible=False) # Added new send_to_toolbox_btn + def refresh_gallery(): + new_items = get_gallery_items() + return gr.update(value=[i[0] for i in new_items]), new_items + refresh_button.click(fn=refresh_gallery, outputs=[thumbs, gallery_items_state]) + + # MODIFIED: on_select now handles visibility of the new button + def on_select(evt: gr.SelectData, gallery_items): + if evt.index is None or not gallery_items or evt.index >= len(gallery_items): + return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), None + + prefix = gallery_items[evt.index][1] + # original_video_path is e.g., "outputs/my_actual_video.mp4" + original_video_path, info_string, button_visibility_update = load_video_and_info_from_prefix(prefix) + + # Determine visibility for video and info based on whether video_path was found + video_out_update = gr.update(value=original_video_path, visible=bool(original_video_path)) + info_out_update = gr.update(value=info_string, visible=bool(original_video_path)) + + # IMPORTANT: Store the ORIGINAL, UNPROCESSED path in the gr.State + return video_out_update, info_out_update, button_visibility_update, original_video_path + + thumbs.select( + fn=on_select, + inputs=[gallery_items_state], + outputs=[video_out, info_out, send_to_toolbox_btn, selected_original_video_path_state] # Output original path to State + ) + with gr.Tab("Post-processing", id="toolbox_tab"): + # Call the function from toolbox_app.py to build the Toolbox UI + # The toolbox_ui_layout (e.g., a gr.Column) is automatically placed here. + toolbox_ui_layout, tb_target_video_input = tb_create_video_toolbox_ui() + + with gr.Tab("Settings"): + with gr.Row(): + with gr.Column(): + save_metadata = gr.Checkbox( + label="Save Metadata", + info="Save to JSON file", + value=settings.get("save_metadata", 6), + ) + gpu_memory_preservation = gr.Slider( + label="Memory Buffer for Stability (VRAM GB)", + minimum=1, + maximum=128, + step=0.1, + value=settings.get("gpu_memory_preservation", 6), + info="Increase reserve if you see computer freezes, stagnant generation, or super slow sampling steps (try 1G at a time).\ + Otherwise smaller buffer is faster. Some models and lora need more buffer than others. \ + (5.5 - 8.5 is a common range)" + ) + mp4_crf = gr.Slider( + label="MP4 Compression", + minimum=0, + maximum=100, + step=1, + value=settings.get("mp4_crf", 16), + info="Lower means better quality. 0 is uncompressed. Change to 16 if you get black outputs." + ) + clean_up_videos = gr.Checkbox( + label="Clean up video files", + value=settings.get("clean_up_videos", True), + info="If checked, only the final video will be kept after generation." 
+ ) + auto_cleanup_on_startup = gr.Checkbox( + label="Automatically clean up temp folders on startup", + value=settings.get("auto_cleanup_on_startup", False), + info="If checked, temporary files (inc. post-processing) will be cleaned up when the application starts." + ) + + latents_display_top = gr.Checkbox( + label="Display Next Latents across top of interface", + value=get_latents_display_top(), + info="If checked, the Next Latents preview will be displayed across the top of the interface instead of in the right column." + ) + + # gr.Markdown("---") + # gr.Markdown("### Startup Settings") + gr.Markdown("") + # Initial values for startup preset dropdown + # Ensure settings and load_presets are available in this scope + initial_startup_model_val = settings.get("startup_model_type", "None") + initial_startup_presets_choices_val = [] + initial_startup_preset_value_val = None + + if initial_startup_model_val and initial_startup_model_val != "None": + # load_presets is defined further down in create_interface + initial_startup_presets_choices_val = load_presets(initial_startup_model_val) + saved_preset_for_initial_model_val = settings.get("startup_preset_name") + if saved_preset_for_initial_model_val in initial_startup_presets_choices_val: + initial_startup_preset_value_val = saved_preset_for_initial_model_val + + startup_model_type_dropdown = gr.Dropdown( + label="Startup Model Type", + choices=["None"] + [choice[0] for choice in model_type.choices if choice[0] != "XY Plot"], # model_type is the Radio on Generate tab + value=initial_startup_model_val, + info="Select a model type to load on startup. 'None' to disable." + ) + startup_preset_name_dropdown = gr.Dropdown( + label="Startup Preset", + choices=initial_startup_presets_choices_val, + value=initial_startup_preset_value_val, + info="Select a preset for the startup model. Updates when Startup Model Type changes.", + interactive=True # Must be interactive to be updated by another component + ) + + with gr.Accordion("System Prompt", open=False): + with gr.Row(equal_height=True): # New Row to contain checkbox and reset button + override_system_prompt = gr.Checkbox( + label="Override System Prompt", + value=settings.get("override_system_prompt", False), + info="If checked, the system prompt template below will be used instead of the default one.", + scale=1 # Give checkbox some scale + ) + reset_system_prompt_btn = gr.Button( + "🔄 Reset", + scale=0 + ) + system_prompt_template = gr.Textbox( + label="System Prompt Template", + value=settings.get("system_prompt_template", "{\"template\": \"<|start_header_id|>system<|end_header_id|>\\n\\nDescribe the video by detailing the following aspects: 1. The main content and theme of the video.2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects.3. Actions, events, behaviors temporal relationships, physical movement changes of the objects.4. background environment, light, style and atmosphere.5. camera angles, movements, and transitions used in the video:<|eot_id|><|start_header_id|>user<|end_header_id|>\\n\\n{}<|eot_id|>\", \"crop_start\": 95}"), + lines=10, + info="System prompt template used for video generation. Must be a valid JSON or Python dictionary string with 'template' and 'crop_start' keys. 
Example: {\"template\": \"your template here\", \"crop_start\": 95}" + ) + # The reset_system_prompt_btn is now defined above within the Row + + # --- Settings Tab Event Handlers --- + + output_dir = gr.Textbox( + label="Output Directory", + value=settings.get("output_dir"), + placeholder="Path to save generated videos" + ) + metadata_dir = gr.Textbox( + label="Metadata Directory", + value=settings.get("metadata_dir"), + placeholder="Path to save metadata files" + ) + lora_dir = gr.Textbox( + label="LoRA Directory", + value=settings.get("lora_dir"), + placeholder="Path to LoRA models" + ) + gradio_temp_dir = gr.Textbox(label="Gradio Temporary Directory", value=settings.get("gradio_temp_dir")) + auto_save = gr.Checkbox( + label="Auto-save settings", + value=settings.get("auto_save_settings", True) + ) + # Add Gradio Theme Dropdown + gradio_themes = ["default", "base", "soft", "glass", "mono", "origin", "citrus", "monochrome", "ocean", "NoCrypt/miku", "earneleh/paris", "gstaff/xkcd"] + theme_dropdown = gr.Dropdown( + label="Theme", + choices=gradio_themes, + value=settings.get("gradio_theme", "default"), + info="Select the Gradio UI theme. Requires restart." + ) + save_btn = gr.Button("💾 Save Settings") + cleanup_btn = gr.Button("🗑️ Clean Up Temporary Files") + status = gr.HTML("") + cleanup_output = gr.Textbox(label="Cleanup Status", interactive=False) + + def save_settings(save_metadata, gpu_memory_preservation, mp4_crf, clean_up_videos, auto_cleanup_on_startup_val, latents_display_top_val, override_system_prompt_value, system_prompt_template_value, output_dir, metadata_dir, lora_dir, gradio_temp_dir, auto_save, selected_theme, startup_model_type_val, startup_preset_name_val): + """Handles the manual 'Save Settings' button click.""" + # This function is for the manual save button. + # It collects all current UI values and saves them. + # The auto-save logic is handled by individual .change() and .blur() handlers + # calling settings.set(). + + # First, update the settings object with all current values from the UI + try: + # Save the system prompt template as is, without trying to parse it + # The hunyuan.py file will handle parsing it when needed + processed_template = system_prompt_template_value + + settings.save_settings( + save_metadata=save_metadata, + gpu_memory_preservation=gpu_memory_preservation, + mp4_crf=mp4_crf, + clean_up_videos=clean_up_videos, + auto_cleanup_on_startup=auto_cleanup_on_startup_val, # ADDED + latents_display_top=latents_display_top_val, # NEW: Added latents display position setting + override_system_prompt=override_system_prompt_value, + system_prompt_template=processed_template, + output_dir=output_dir, + metadata_dir=metadata_dir, + lora_dir=lora_dir, + gradio_temp_dir=gradio_temp_dir, + auto_save_settings=auto_save, + gradio_theme=selected_theme, + startup_model_type=startup_model_type_val, + startup_preset_name=startup_preset_name_val + ) + # settings.save_settings() is called inside settings.save_settings if auto_save is true, + # but for the manual button, we ensure it saves regardless of the auto_save flag's previous state. + # The call above to settings.save_settings already handles writing to disk. + return "
Settings saved successfully! Restart required for theme change." + except Exception as e: + return f"Error saving settings: {str(e)}
" + + def handle_individual_setting_change(key, value, setting_name_for_ui): + """Called by .change() and .submit() events of individual setting components.""" + if key == "auto_save_settings": + # For the "auto_save_settings" checkbox itself: + # 1. Update its value directly in the settings object in memory. + # This bypasses the conditional save logic within settings.set() for this specific action. + settings.settings[key] = value + # 2. Force a save of all settings to disk. This will be correct because either: + # - auto_save_settings is turning True: so all changes already in memory need to be saved now. + # - auto_save_settings turning False from True: prior changes already saved so only auto_save_settings will be saved. + settings.save_settings() + # 3. Provide feedback. + if value is True: + return f"
'{setting_name_for_ui}' setting is now ON and saved." + else: + return f"'{setting_name_for_ui}' setting is now OFF and saved.
" + else: + # For all other settings: + # Let settings.set() handle the auto-save logic based on the current "auto_save_settings" value. + settings.set(key, value) # settings.set() will call save_settings() if auto_save is True + if settings.get("auto_save_settings"): # Check the current state of auto_save + return f"
'{setting_name_for_ui}' setting auto-saved." + else: + return f"'{setting_name_for_ui}' setting changed (auto-save is off, click 'Save Settings').
" + + # REMOVE `cleanup_temp_folder` from the `inputs` list + save_btn.click( + fn=save_settings, + inputs=[save_metadata, gpu_memory_preservation, mp4_crf, clean_up_videos, auto_cleanup_on_startup, latents_display_top, override_system_prompt, system_prompt_template, output_dir, metadata_dir, lora_dir, gradio_temp_dir, auto_save, theme_dropdown, startup_model_type_dropdown, startup_preset_name_dropdown], + outputs=[status] + ).then( + # NEW: Update latents display layout after manual save + fn=create_latents_layout_update, + inputs=None, + outputs=[top_preview_row, preview_image] + ) + + def reset_system_prompt_template_value(): + return settings.default_settings["system_prompt_template"], False + + reset_system_prompt_btn.click( + fn=reset_system_prompt_template_value, + outputs=[system_prompt_template, override_system_prompt] + ).then( # Trigger auto-save for the reset values if auto-save is on + lambda val_template, val_override: handle_individual_setting_change("system_prompt_template", val_template, "System Prompt Template") or handle_individual_setting_change("override_system_prompt", val_override, "Override System Prompt"), + inputs=[system_prompt_template, override_system_prompt], outputs=[status]) + + def manual_cleanup_handler(): + """UI handler for the manual cleanup button.""" + # This directly calls the toolbox_processor's cleanup method and returns the summary string. + summary = tb_processor.tb_clear_temporary_files() + return summary + + cleanup_btn.click( + fn=manual_cleanup_handler, + inputs=None, + outputs=[cleanup_output] + ) + + # Add .change handlers for auto-saving individual settings + save_metadata.change(lambda v: handle_individual_setting_change("save_metadata", v, "Save Metadata"), inputs=[save_metadata], outputs=[status]) + gpu_memory_preservation.change(lambda v: handle_individual_setting_change("gpu_memory_preservation", v, "GPU Memory Preservation"), inputs=[gpu_memory_preservation], outputs=[status]) + mp4_crf.change(lambda v: handle_individual_setting_change("mp4_crf", v, "MP4 Compression"), inputs=[mp4_crf], outputs=[status]) + clean_up_videos.change(lambda v: handle_individual_setting_change("clean_up_videos", v, "Clean Up Videos"), inputs=[clean_up_videos], outputs=[status]) + + # NEW: auto-cleanup temp files on startup checkbox + auto_cleanup_on_startup.change(lambda v: handle_individual_setting_change("auto_cleanup_on_startup", v, "Auto Cleanup on Startup"), inputs=[auto_cleanup_on_startup], outputs=[status]) + + # NEW: latents display position setting + latents_display_top.change(lambda v: handle_individual_setting_change("latents_display_top", v, "Latents Display Position"), inputs=[latents_display_top], outputs=[status]) + + + + # Connect the latents display setting to layout updates + def update_latents_display_layout_from_checkbox(display_top): + """Update layout when checkbox changes - uses the checkbox value directly""" + if display_top: + return ( + gr.update(visible=True), # top_preview_row + gr.update(visible=False, value=None) # preview_image (right column) + ) + else: + return ( + gr.update(visible=False), # top_preview_row + gr.update(visible=True) # preview_image (right column) + ) + + latents_display_top.change( + fn=update_latents_display_layout_from_checkbox, + inputs=[latents_display_top], + outputs=[top_preview_row, preview_image] + ) + + override_system_prompt.change(lambda v: handle_individual_setting_change("override_system_prompt", v, "Override System Prompt"), inputs=[override_system_prompt], outputs=[status]) + # Using .blur 
for text changes so they are processed after the user finishes, not on every keystroke + system_prompt_template.blur(lambda v: handle_individual_setting_change("system_prompt_template", v, "System Prompt Template"), inputs=[system_prompt_template], outputs=[status]) + # reset_system_prompt_btn # is handled separately above, on click + + # Using .blur for text changes so they are processed after the user finishes, not on every keystroke + output_dir.blur(lambda v: handle_individual_setting_change("output_dir", v, "Output Directory"), inputs=[output_dir], outputs=[status]) + metadata_dir.blur(lambda v: handle_individual_setting_change("metadata_dir", v, "Metadata Directory"), inputs=[metadata_dir], outputs=[status]) + lora_dir.blur(lambda v: handle_individual_setting_change("lora_dir", v, "LoRA Directory"), inputs=[lora_dir], outputs=[status]) + gradio_temp_dir.blur(lambda v: handle_individual_setting_change("gradio_temp_dir", v, "Gradio Temporary Directory"), inputs=[gradio_temp_dir], outputs=[status]) + + auto_save.change(lambda v: handle_individual_setting_change("auto_save_settings", v, "Auto-save Settings"), inputs=[auto_save], outputs=[status]) + theme_dropdown.change(lambda v: handle_individual_setting_change("gradio_theme", v, "Theme"), inputs=[theme_dropdown], outputs=[status]) + + # Event handlers for startup settings + def update_startup_preset_dropdown_choices(selected_startup_model_type_from_ui): + if not selected_startup_model_type_from_ui or selected_startup_model_type_from_ui == "None": + return gr.update(choices=[], value=None) + + loaded_presets_for_model = load_presets(selected_startup_model_type_from_ui) + + # Get the preset name that was saved for the *previous* model type + current_saved_startup_preset = settings.get("startup_preset_name") + + # Default to None + value_to_select = None + # If the previously saved preset name exists for the new model, select it + if current_saved_startup_preset and current_saved_startup_preset in loaded_presets_for_model: + value_to_select = current_saved_startup_preset + + return gr.update(choices=loaded_presets_for_model, value=value_to_select) + + startup_model_type_dropdown.change( + fn=lambda v: handle_individual_setting_change("startup_model_type", v, "Startup Model Type"), + inputs=[startup_model_type_dropdown], outputs=[status] + ).then( # Chain the update to the preset dropdown + fn=update_startup_preset_dropdown_choices, inputs=[startup_model_type_dropdown], outputs=[startup_preset_name_dropdown]) + startup_preset_name_dropdown.change(lambda v: handle_individual_setting_change("startup_preset_name", v, "Startup Preset Name"), inputs=[startup_preset_name_dropdown], outputs=[status]) + + # --- Event Handlers and Connections (Now correctly indented) --- + + # --- Connect Monitoring --- + # Auto-check for current job on page load and job change + def check_for_current_job(): + # This function will be called when the interface loads + # It will check if there's a current job in the queue and update the UI + with job_queue.lock: + current_job = job_queue.current_job + if current_job: + # Return all the necessary information to update the preview windows + job_id = current_job.id + result = current_job.result + preview = current_job.progress_data.get('preview') if current_job.progress_data else None + desc = current_job.progress_data.get('desc', '') if current_job.progress_data else '' + html = current_job.progress_data.get('html', '') if current_job.progress_data else '' + + # Also trigger the monitor_job function to start monitoring 
this job + print(f"Auto-check found current job {job_id}, triggering monitor_job") + return job_id, result, preview, preview, desc, html + return None, None, None, None, '', '' + + # Auto-check for current job on page load and handle handoff between jobs. + def check_for_current_job_and_monitor(): + # This function is now the key to the handoff. + # It finds the current job and returns its ID, which will trigger the monitor. + job_id, result, preview, top_preview, desc, html = check_for_current_job() + # We also need to get fresh stats at the same time. + queue_status_data, queue_stats_text = update_stats() + # Return everything needed to update the UI atomically. + return job_id, result, preview, top_preview, desc, html, queue_status_data, queue_stats_text + + # Connect the main process function (wrapper for adding to queue) + def process_with_queue_update(model_type_arg, *args): + # Call update_stats to get both queue_status_data and queue_stats_text + queue_status_data, queue_stats_text = update_stats() # MODIFIED + + # Extract all arguments (ensure order matches inputs lists) + # The order here MUST match the order in the `ips` list. + # RT_BORG: Global settings gpu_memory_preservation, mp4_crf, save_metadata removed from direct args. + (input_image_arg, + input_video_arg, + end_frame_image_original_arg, + end_frame_strength_original_arg, + prompt_text_arg, + n_prompt_arg, + seed_arg, # the seed value + randomize_seed_arg, # the boolean value of the checkbox + total_second_length_arg, + latent_window_size_arg, + steps_arg, + cfg_arg, + gs_arg, + rs_arg, + cache_type_arg, + teacache_num_steps_arg, + teacache_rel_l1_thresh_arg, + magcache_threshold_arg, + magcache_max_consecutive_skips_arg, + magcache_retention_ratio_arg, + blend_sections_arg, + latent_type_arg, + clean_up_videos_arg, # UI checkbox from Generate tab + selected_loras_arg, + resolutionW_arg, resolutionH_arg, + combine_with_source_arg, + num_cleaned_frames_arg, + lora_names_states_arg, # This is from lora_names_states (gr.State) + *lora_slider_values_tuple # Remaining args are LoRA slider values + ) = args + # DO NOT parse the prompt here. Parsing happens once in the worker. 
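+            # Note on the mapping just below: only the "Video with Endframe" UI selection is renamed for the + # backend (it reuses the backend "Video" code path); every other UI model type is passed through unchanged.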
+ + # Determine the model type to send to the backend + backend_model_type = model_type_arg # model_type_arg is the UI selection + if model_type_arg == "Video with Endframe": + backend_model_type = "Video" # The backend "Video" model_type handles with and without endframe + + # Use the appropriate input based on model type + is_ui_video_model = is_video_model(model_type_arg) + input_data = input_video_arg if is_ui_video_model else input_image_arg + + # Define actual end_frame params to pass to backend + actual_end_frame_image_for_backend = None + actual_end_frame_strength_for_backend = 1.0 # Default strength + + if model_type_arg == "Original with Endframe" or model_type_arg == "F1 with Endframe" or model_type_arg == "Video with Endframe": + actual_end_frame_image_for_backend = end_frame_image_original_arg + actual_end_frame_strength_for_backend = end_frame_strength_original_arg + + # Get the input video path for Video model + input_image_path = None + if is_ui_video_model and input_video_arg is not None: + # For Video models, input_video contains the path to the video file + input_image_path = input_video_arg + + # Use the current seed value as is for this job + # Call the process function with all arguments + # Pass the backend_model_type and the ORIGINAL prompt_text string to the backend process function + result = process_fn(backend_model_type, input_data, actual_end_frame_image_for_backend, actual_end_frame_strength_for_backend, + prompt_text_arg, n_prompt_arg, seed_arg, total_second_length_arg, + latent_window_size_arg, steps_arg, cfg_arg, gs_arg, rs_arg, + cache_type_arg == 'TeaCache', teacache_num_steps_arg, teacache_rel_l1_thresh_arg, + cache_type_arg == 'MagCache', magcache_threshold_arg, magcache_max_consecutive_skips_arg, magcache_retention_ratio_arg, + blend_sections_arg, latent_type_arg, clean_up_videos_arg, # clean_up_videos_arg is from UI + selected_loras_arg, resolutionW_arg, resolutionH_arg, + input_image_path, + combine_with_source_arg, + num_cleaned_frames_arg, + lora_names_states_arg, + *lora_slider_values_tuple + ) + # If randomize_seed is checked, generate a new random seed for the next job + new_seed_value = None + if randomize_seed_arg: + new_seed_value = random.randint(0, 21474) + print(f"Generated new seed for next job: {new_seed_value}") + + # Create the button update for start_button WITHOUT interactive=True. + # The interactivity will be set by update_start_button_state later in the chain. + start_button_update_after_add = gr.update(value="🚀 Add to Queue") + + # If a job ID was created, automatically start monitoring it and update queue + if result and result[1]: # Check if job_id exists in results + job_id = result[1] + # queue_status_data = update_queue_status_fn() # OLD: update_stats now called earlier + # Call update_stats again AFTER the job is added to get the freshest stats + queue_status_data, queue_stats_text = update_stats() + + + # Add the new seed value to the results if randomize is checked + if new_seed_value is not None: + # Use result[6] directly for end_button to preserve its value. Add gr.update() for video_input_required_message. + return [result[0], job_id, result[2], result[3], result[4], start_button_update_after_add, result[6], queue_status_data, queue_stats_text, new_seed_value, gr.update()] + else: + # Use result[6] directly for end_button to preserve its value. Add gr.update() for video_input_required_message. 
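+                # Return order must match the outputs wired to handle_start_button: + # [result_video, current_job_id, preview_image, progress_desc, progress_bar, start_button, end_button, + #  queue_status, queue_stats_display, seed, video_input_required_message].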
+ return [result[0], job_id, result[2], result[3], result[4], start_button_update_after_add, result[6], queue_status_data, queue_stats_text, gr.update(), gr.update()] + + # If no job ID was created, still return the new seed if randomize is checked + # Also, ensure we return the latest stats even if no job was created (e.g., error during param validation) + queue_status_data, queue_stats_text = update_stats() + if new_seed_value is not None: + # Make sure to preserve the end_button update from result[6] + return [result[0], result[1], result[2], result[3], result[4], start_button_update_after_add, result[6], queue_status_data, queue_stats_text, new_seed_value, gr.update()] + else: + # Make sure to preserve the end_button update from result[6] + return [result[0], result[1], result[2], result[3], result[4], start_button_update_after_add, result[6], queue_status_data, queue_stats_text, gr.update(), gr.update()] + + # Custom end process function that ensures the queue is updated and changes button text + def end_process_with_update(): + _ = end_process_fn() # Call the original end_process_fn + # Now, get fresh stats for both queue table and toolbar + queue_status_data, queue_stats_text = update_stats() + + # Don't try to get the new job ID immediately after cancellation + # The monitor_job function will handle the transition to the next job + + # Change the cancel button text to "Cancelling..." and make it non-interactive + # This ensures the button stays in this state until the job is fully cancelled + return queue_status_data, queue_stats_text, gr.update(value="Cancelling...", interactive=False), gr.update(value=None) + + # MODIFIED handle_send_video_to_toolbox: + def handle_send_video_to_toolbox(original_path_from_state): # Input is now the original path from gr.State + print(f"Sending selected Outputs' video to Post-processing: {original_path_from_state}") + + if original_path_from_state and isinstance(original_path_from_state, str) and os.path.exists(original_path_from_state): + # tb_target_video_input will now process the ORIGINAL path (e.g., "outputs/my_actual_video.mp4"). + return gr.update(value=original_path_from_state), gr.update(selected="toolbox_tab") + else: + print(f"No valid video path (from State) found to send. 
Path: {original_path_from_state}") + return gr.update(), gr.update() + + send_to_toolbox_btn.click( + fn=handle_send_video_to_toolbox, + inputs=[selected_original_video_path_state], # INPUT IS NOW THE gr.State holding the ORIGINAL path + outputs=[ + tb_target_video_input, # This is tb_input_video_component from toolbox_app.py + main_tabs_component + ] + ) + + # --- Inputs Lists --- + # --- Inputs for all models --- + ips = [ + input_image, # Corresponds to input_image_arg + input_video, # Corresponds to input_video_arg + end_frame_image_original, # Corresponds to end_frame_image_original_arg + end_frame_strength_original,# Corresponds to end_frame_strength_original_arg + prompt, # Corresponds to prompt_text_arg + n_prompt, # Corresponds to n_prompt_arg + seed, # Corresponds to seed_arg + randomize_seed, # Corresponds to randomize_seed_arg + total_second_length, # Corresponds to total_second_length_arg + latent_window_size, # Corresponds to latent_window_size_arg + steps, # Corresponds to steps_arg + cfg, # Corresponds to cfg_arg + gs, # Corresponds to gs_arg + rs, # Corresponds to rs_arg + cache_type, # Corresponds to cache_type_arg + teacache_num_steps, # Corresponds to teacache_num_steps_arg + teacache_rel_l1_thresh, # Corresponds to teacache_rel_l1_thresh_arg + magcache_threshold, # Corresponds to magcache_threshold_arg + magcache_max_consecutive_skips, # Corresponds to magcache_max_consecutive_skips_arg + magcache_retention_ratio, # Corresponds to magcache_retention_ratio_arg + blend_sections, # Corresponds to blend_sections_arg + latent_type, # Corresponds to latent_type_arg + clean_up_videos, # Corresponds to clean_up_videos_arg (UI checkbox) + lora_selector, # Corresponds to selected_loras_arg + resolutionW, # Corresponds to resolutionW_arg + resolutionH, # Corresponds to resolutionH_arg + combine_with_source, # Corresponds to combine_with_source_arg + num_cleaned_frames, # Corresponds to num_cleaned_frames_arg + lora_names_states # Corresponds to lora_names_states_arg + ] + # Add LoRA sliders to the input list + ips.extend([lora_sliders[lora] for lora in lora_names]) + + + # --- Connect Buttons --- + def handle_start_button(selected_model, *args): + # For other model types, use the regular process function + return process_with_queue_update(selected_model, *args) + + def handle_batch_add_to_queue(*args): + # The last argument will be the list of files from batch_input_images + batch_files = args[-1] + if not batch_files or not isinstance(batch_files, list): + print("No batch images provided.") + return + + print(f"Starting batch processing for {len(batch_files)} images.") + + # Reconstruct the arguments for the single process function, excluding the batch files list + single_job_args = list(args[:-1]) + + # The first argument to process_with_queue_update is model_type + model_type_arg = single_job_args.pop(0) + + # Keep track of the seed + current_seed = single_job_args[6] # seed is the 7th element in the ips list + randomize_seed_arg = single_job_args[7] # randomize_seed is the 8th + + for image_path in batch_files: + # --- FIX IS HERE --- + # Load the image from the path into a NumPy array + try: + pil_image = Image.open(image_path).convert("RGB") + numpy_image = np.array(pil_image) + except Exception as e: + print(f"Error loading batch image {image_path}: {e}. 
Skipping.") + continue + # --- END OF FIX --- + + # Replace the single input_image argument with the loaded NumPy image + current_job_args = single_job_args[:] + current_job_args[0] = numpy_image # Use the loaded numpy_image + current_job_args[6] = current_seed # Set the seed for the current job + + # Call the original processing function with the modified arguments + process_with_queue_update(model_type_arg, *current_job_args) + + # If randomize seed is checked, generate a new one for the next image + if randomize_seed_arg: + current_seed = random.randint(0, 21474) + + print("Batch processing complete. All jobs added to the queue.") + + # Validation ensures the start button is only enabled when appropriate + def update_start_button_state(*args): + """ + Validation fails if a video model is selected and no input video is provided. + Updates the start button interactivity and validation message visibility. + Handles variable inputs from different Gradio event chains. + """ + # The required values are the last two arguments provided by the Gradio event + if len(args) >= 2: + selected_model = args[-2] + input_video_value = args[-1] + else: + # Fallback or error handling if not enough arguments are received + # This might happen if the event is triggered in an unexpected way + print(f"Warning: update_start_button_state received {len(args)} args, expected at least 2.") + # Default to a safe state (button disabled) + return gr.Button(value="❌ Error", interactive=False), gr.update(visible=True) + + video_provided = input_video_value is not None + + if is_video_model(selected_model) and not video_provided: + # Video model selected, but no video provided + return gr.Button(value="❌ Missing Video", interactive=False), gr.update(visible=True) + else: + # Either not a video model, or video model selected and video provided + return gr.update(value="🚀 Add to Queue", interactive=True), gr.update(visible=False) + # Function to update button state before processing + def update_button_before_processing(selected_model, *args): + # First update the button to show "Adding..." and disable it + # Also return current stats so they don't get blanked out during the "Adding..." 
phase + qs_data, qs_text = update_stats() + return gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(value="⏳ Adding...", interactive=False), gr.update(), qs_data, qs_text, gr.update(), gr.update() # Added update for video_input_required_message + + # Connect the start button to first update its state + start_button.click( + fn=update_button_before_processing, + inputs=[model_type] + ips, + outputs=[result_video, current_job_id, preview_image, top_preview_image, progress_desc, progress_bar, start_button, end_button, queue_status, queue_stats_display, seed, video_input_required_message] + ).then( + # Then process the job + fn=handle_start_button, + inputs=[model_type] + ips, + outputs=[result_video, current_job_id, preview_image, progress_desc, progress_bar, start_button, end_button, queue_status, queue_stats_display, seed, video_input_required_message] # Added video_input_required_message + ).then( # Ensure validation is re-checked after job processing completes + fn=update_start_button_state, + inputs=[model_type, input_video], # Current values of model_type and input_video + outputs=[start_button, video_input_required_message] + ) + + def show_batch_gallery(files): + return gr.update(value=files, visible=True) if files else gr.update(visible=False) + + batch_input_images.change( + fn=show_batch_gallery, + inputs=[batch_input_images], + outputs=[batch_input_gallery] + ) + + # We need to gather all the same inputs as the single 'Add to Queue' button, plus the new file input + batch_ips = [model_type] + ips + [batch_input_images] + + add_batch_to_queue_btn.click( + fn=handle_batch_add_to_queue, + inputs=batch_ips, + outputs=None # No direct output updates from this button + ).then( + fn=update_stats, # Refresh the queue stats in the UI + inputs=None, + outputs=[queue_status, queue_stats_display] + ).then( + # This new block checks for a running job and updates the monitor UI + fn=check_for_current_job, + inputs=None, + outputs=[current_job_id, result_video, preview_image, top_preview_image, progress_desc, progress_bar] + ).then( + # NEW: Update latents display layout after loading queue to ensure correct visibility + fn=create_latents_layout_update, + inputs=None, + outputs=[top_preview_row, preview_image] + ) + + # --- START OF REFACTORED XY PLOT EVENT WIRING --- + # Get the process button from the created components + xy_plot_process_btn = xy_plot_components["process_btn"] + + # Prepare the process function with its static dependencies (job_queue, settings) + fn_xy_process_with_deps = functools.partial(xy_plot_process, job_queue, settings) + + # Construct the full list of inputs for the click handler in the correct order + c = xy_plot_components + xy_plot_input_components = [ + c["model_type"], c["input_image"], c["end_frame_image_original"], + c["end_frame_strength_original"], c["latent_type"], c["prompt"], + c["blend_sections"], c["steps"], c["total_second_length"], + resolutionW, resolutionH, # The components from the main UI + c["seed"], c["randomize_seed"], + c["use_teacache"], c["teacache_num_steps"], c["teacache_rel_l1_thresh"], + c["use_magcache"], c["magcache_threshold"], c["magcache_max_consecutive_skips"], c["magcache_retention_ratio"], + c["latent_window_size"], c["cfg"], c["gs"], c["rs"], + c["gpu_memory_preservation"], c["mp4_crf"], + c["axis_x_switch"], c["axis_x_value_text"], c["axis_x_value_dropdown"], + c["axis_y_switch"], c["axis_y_value_text"], c["axis_y_value_dropdown"], + c["axis_z_switch"], c["axis_z_value_text"], 
c["axis_z_value_dropdown"], + c["lora_selector"] + ] + # LoRA sliders are in a dictionary, so we add their values to the list + xy_plot_input_components.extend(c["lora_sliders"].values()) + + # Wire the click handler for the XY Plot button + xy_plot_process_btn.click( + fn=fn_xy_process_with_deps, + inputs=xy_plot_input_components, + outputs=[xy_plot_status, xy_plot_output] + ).then( + fn=update_stats, + inputs=None, + outputs=[queue_status, queue_stats_display] + ).then( + fn=check_for_current_job, + inputs=None, + outputs=[current_job_id, result_video, preview_image, top_preview_image, progress_desc, progress_bar] + ).then( + # NEW: Update latents display layout after XY plot to ensure correct visibility + fn=create_latents_layout_update, + inputs=None, + outputs=[top_preview_row, preview_image] + ) + # --- END OF REFACTORED XY PLOT EVENT WIRING --- + + + + # MODIFIED: on_model_type_change to handle new "XY Plot" option + def on_model_type_change(selected_model): + is_xy_plot = selected_model == "XY Plot" + is_ui_video_model_flag = is_video_model(selected_model) + shows_end_frame = selected_model in ["Original with Endframe", "Video with Endframe"] + + return ( + gr.update(visible=not is_xy_plot), # standard_generation_group + gr.update(visible=is_xy_plot), # xy_group + gr.update(visible=not is_xy_plot and not is_ui_video_model_flag), # image_input_group + gr.update(visible=not is_xy_plot and is_ui_video_model_flag), # video_input_group + gr.update(visible=not is_xy_plot and shows_end_frame), # end_frame_group_original + gr.update(visible=not is_xy_plot and shows_end_frame), # end_frame_slider_group + gr.update(visible=not is_xy_plot), # start_button + gr.update(visible=is_xy_plot) # xy_plot_process_btn + ) + + # Model change listener + model_type.change( + fn=on_model_type_change, + inputs=model_type, + outputs=[ + standard_generation_group, + xy_group, + image_input_group, + video_input_group, + end_frame_group_original, + end_frame_slider_group, + start_button, + xy_plot_process_btn # This is the button returned from the dictionary + ] + ).then( # Also trigger validation after model type changes + fn=update_start_button_state, + inputs=[model_type, input_video], + outputs=[start_button, video_input_required_message] + ) + + # Connect input_video change to the validation function + input_video.change( + fn=update_start_button_state, + inputs=[model_type, input_video], + outputs=[start_button, video_input_required_message] + ) + # Also trigger validation when video is cleared + input_video.clear( + fn=update_start_button_state, + inputs=[model_type, input_video], + outputs=[start_button, video_input_required_message] + ) + + + + # Auto-monitor the current job when job_id changes + current_job_id.change( + fn=monitor_fn, + inputs=[current_job_id], + outputs=[result_video, preview_image, top_preview_image, progress_desc, progress_bar, start_button, end_button] + ).then( + fn=update_stats, # When a monitor finishes, always update the stats. + inputs=None, + outputs=[queue_status, queue_stats_display] + ).then( # re-validate button state + fn=update_start_button_state, + inputs=[model_type, input_video], + outputs=[start_button, video_input_required_message] + ).then( + # NEW: Update latents display layout after monitoring to ensure correct visibility + fn=create_latents_layout_update, + inputs=None, + outputs=[top_preview_row, preview_image] + ) + + # The "end_button" (Cancel Job) is the trigger for the next job's monitor. + # When a job is cancelled, we check for the next one. 
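+        # Chain below: end_process_with_update cancels the running job and flips the cancel button to "Cancelling...", + # then check_for_current_job_and_monitor hands off to whichever job is now current (its job ID re-triggers the + # current_job_id.change monitor above), and create_latents_layout_update restores the preview placement.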
+ end_button.click( + fn=end_process_with_update, + outputs=[queue_status, queue_stats_display, end_button, current_job_id] + ).then( + fn=check_for_current_job_and_monitor, + inputs=[], + outputs=[current_job_id, result_video, preview_image, top_preview_image, progress_desc, progress_bar, queue_status, queue_stats_display] + ).then( + # NEW: Update latents display layout after job handoff to ensure correct visibility + fn=create_latents_layout_update, + inputs=None, + outputs=[top_preview_row, preview_image] + ) + + load_queue_button.click( + fn=load_queue_from_json, + inputs=[], + outputs=[queue_status, queue_stats_display] + ).then( # ADD THIS .then() CLAUSE + fn=check_for_current_job, + inputs=[], + outputs=[current_job_id, result_video, preview_image, top_preview_image, progress_desc, progress_bar] + ).then( + # NEW: Update latents display layout after loading queue to ensure correct visibility + fn=create_latents_layout_update, + inputs=None, + outputs=[top_preview_row, preview_image] + ) + + import_queue_file.change( + fn=import_queue_from_file, + inputs=[import_queue_file], + outputs=[queue_status, queue_stats_display] + ).then( # ADD THIS .then() CLAUSE + fn=check_for_current_job, + inputs=[], + outputs=[current_job_id, result_video, preview_image, top_preview_image, progress_desc, progress_bar] + ).then( + # NEW: Update latents display layout after importing queue to ensure correct visibility + fn=create_latents_layout_update, + inputs=None, + outputs=[top_preview_row, preview_image] + ) + + + # --- Connect Queue Refresh --- + # The update_stats function is now defined much earlier. + + # REMOVED: refresh_stats_btn.click - Toolbar refresh button is no longer needed + # refresh_stats_btn.click( + # fn=update_stats, + # inputs=None, + # outputs=[queue_status, queue_stats_display] + # ) + + # Set up auto-refresh for queue status + # Instead of using a timer with 'every' parameter, we'll use the queue refresh button + # and rely on manual refreshes. The user can click the refresh button in the toolbar + # to update the stats. + + # --- Connect LoRA UI --- + # Function to update slider visibility based on selection + def update_lora_sliders(selected_loras): + updates = [] + # Suppress dummy LoRA from workaround for the single lora bug. 
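+            # The returned list is ordered to match the outputs wired below: the first update targets the + # lora_selector dropdown itself, followed by one visibility update per slider, iterated in lora_names order.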
+ # Filter out the dummy LoRA for display purposes in the dropdown + actual_selected_loras_for_display = [lora for lora in selected_loras if lora != DUMMY_LORA_NAME] + updates.append(gr.update(value=actual_selected_loras_for_display)) # First update is for the dropdown itself + + # Need to handle potential missing keys if lora_names changes dynamically + # lora_names is from the create_interface scope + for lora_name_key in lora_names: # Iterate using lora_names to maintain order + if lora_name_key == DUMMY_LORA_NAME: # Check for dummy LoRA + updates.append(gr.update(visible=False)) + else: + # Visibility of sliders should be based on actual_selected_loras_for_display + updates.append(gr.update(visible=(lora_name_key in actual_selected_loras_for_display))) + return updates # This list will be correctly ordered + + # Connect the dropdown to the sliders + lora_selector.change( + fn=update_lora_sliders, + inputs=[lora_selector], + outputs=[lora_selector] + [lora_sliders[lora] for lora in lora_names if lora in lora_sliders] + ) + + def apply_preset(preset_name, model_type): + if not preset_name: + # Create a list of empty updates matching the number of components + return [gr.update()] * len(ui_components) + + with open(PRESET_FILE, 'r') as f: + data = json.load(f) + preset = data.get(model_type, {}).get(preset_name, {}) + + # Initialize updates for all components + updates = {key: gr.update() for key in ui_components.keys()} + + # Update components based on the preset + for key, value in preset.items(): + if key in updates: + updates[key] = gr.update(value=value) + + # Handle LoRA sliders specifically + if 'lora_values' in preset and isinstance(preset['lora_values'], dict): + lora_values_dict = preset['lora_values'] + for lora_name, lora_value in lora_values_dict.items(): + if lora_name in updates: + updates[lora_name] = gr.update(value=lora_value) + + # Convert the dictionary of updates to a list in the correct order + return [updates[key] for key in ui_components.keys()] + + def save_preset(preset_name, model_type, *args): + if not preset_name: + return gr.update() + + # Ensure the directory exists + os.makedirs(os.path.dirname(PRESET_FILE), exist_ok=True) + + if not os.path.exists(PRESET_FILE): + with open(PRESET_FILE, 'w') as f: + json.dump({}, f) + + with open(PRESET_FILE, 'r') as f: + data = json.load(f) + + if model_type not in data: + data[model_type] = {} + + keys = list(ui_components.keys()) + + # Create a dictionary from the passed arguments + args_dict = {keys[i]: args[i] for i in range(len(keys))} + + # Build the preset data from the arguments dictionary + preset_data = {key: args_dict[key] for key in ui_components.keys() if key not in lora_sliders} + + # Handle LoRA values separately + selected_loras = args_dict.get("lora_selector", []) + lora_values = {} + for lora_name in selected_loras: + if lora_name in args_dict: + lora_values[lora_name] = args_dict[lora_name] + + preset_data['lora_values'] = lora_values + + # Remove individual lora sliders from the top-level preset data + for lora_name in lora_sliders: + if lora_name in preset_data: + del preset_data[lora_name] + + data[model_type][preset_name] = preset_data + + with open(PRESET_FILE, 'w') as f: + json.dump(data, f, indent=2) + + return gr.update(choices=load_presets(model_type), value=preset_name) + + def delete_preset(preset_name, model_type): + if not preset_name: + return gr.update(), gr.update(visible=True), gr.update(visible=False) + + with open(PRESET_FILE, 'r') as f: + data = json.load(f) + + if model_type in data 
and preset_name in data[model_type]: + del data[model_type][preset_name] + + with open(PRESET_FILE, 'w') as f: + json.dump(data, f, indent=2) + + return gr.update(choices=load_presets(model_type), value=None), gr.update(visible=True), gr.update(visible=False) + + # --- Connect Preset UI --- + # Without this refresh, if you define a new preset for the Startup Model Type, and then try to select it in settings, it won't show up. + def refresh_settings_tab_startup_presets_if_needed(generate_tab_model_type_value, settings_tab_startup_model_type_value): + # generate_tab_model_type_value is the model for which a preset was just saved + # settings_tab_startup_model_type_value is the current selection in the startup model dropdown on settings tab + if generate_tab_model_type_value == settings_tab_startup_model_type_value and settings_tab_startup_model_type_value != "None": + return update_startup_preset_dropdown_choices(settings_tab_startup_model_type_value) + return gr.update() + + ui_components = { + # Prompts + "prompt": prompt, + "n_prompt": n_prompt, + "blend_sections": blend_sections, + # Basic Params + "steps": steps, + "total_second_length": total_second_length, + "resolutionW": resolutionW, + "resolutionH": resolutionH, + "seed": seed, + "randomize_seed": randomize_seed, + # Advanced Params + "gs": gs, + "cfg": cfg, + "rs": rs, + "latent_window_size": latent_window_size, + # Cache type (Mag/Tea/None) + "cache_type": cache_type, + # TeaCache + "teacache_num_steps": teacache_num_steps, + "teacache_rel_l1_thresh": teacache_rel_l1_thresh, + # MagCache + "magcache_threshold": magcache_threshold, + "magcache_max_consecutive_skips": magcache_max_consecutive_skips, + "magcache_retention_ratio": magcache_retention_ratio, + # Input Options + "latent_type": latent_type, + "end_frame_strength_original": end_frame_strength_original, + # Video Specific + "combine_with_source": combine_with_source, + "num_cleaned_frames": num_cleaned_frames, + # LoRAs + "lora_selector": lora_selector, + **lora_sliders + } + + model_type.change( + fn=lambda mt: (gr.update(choices=load_presets(mt)), gr.update(label=f"{mt} Presets")), + inputs=[model_type], + outputs=[preset_dropdown, preset_accordion] + ) + + preset_dropdown.select( + fn=apply_preset, + inputs=[preset_dropdown, model_type], + outputs=list(ui_components.values()) + ).then( + lambda name: name, + inputs=[preset_dropdown], + outputs=[preset_name_textbox] + ) + + save_preset_button.click( + fn=save_preset, + inputs=[preset_name_textbox, model_type, *list(ui_components.values())], + outputs=[preset_dropdown] # preset_dropdown is on Generate tab + ).then( + fn=refresh_settings_tab_startup_presets_if_needed, + inputs=[model_type, startup_model_type_dropdown], # model_type (Generate tab), startup_model_type_dropdown (Settings tab) + outputs=[startup_preset_name_dropdown] # startup_preset_name_dropdown (Settings tab) + ) + + def show_delete_confirmation(): + return gr.update(visible=False), gr.update(visible=True) + + def hide_delete_confirmation(): + return gr.update(visible=True), gr.update(visible=False) + + delete_preset_button.click( + fn=show_delete_confirmation, + outputs=[save_preset_button, confirm_delete_row] + ) + + confirm_delete_no_btn.click( + fn=hide_delete_confirmation, + outputs=[save_preset_button, confirm_delete_row] + ) + + confirm_delete_yes_btn.click( + fn=delete_preset, + inputs=[preset_dropdown, model_type], + outputs=[preset_dropdown, save_preset_button, confirm_delete_row] + ) + + # --- Definition of apply_startup_settings (AFTER 
ui_components and apply_preset are defined) --- + # This function needs access to `settings`, `model_type` (Generate tab Radio), + # `preset_dropdown` (Generate tab Dropdown), `preset_name_textbox` (Generate tab Textbox), + # `ui_components` (dict of all other UI elements), `load_presets`, and `apply_preset`. + # All these are available in the scope of `create_interface`. + def apply_startup_settings(): + startup_model_val = settings.get("startup_model_type", "None") + startup_preset_val = settings.get("startup_preset_name", None) + + # Default updates (no change) + model_type_update = gr.update() + preset_dropdown_update = gr.update() + preset_name_textbox_update = gr.update() + + # ui_components is now defined + ui_components_updates_list = [gr.update() for _ in ui_components] + + if startup_model_val and startup_model_val != "None": + model_type_update = gr.update(value=startup_model_val) + + presets_for_startup_model = load_presets(startup_model_val) # load_presets is defined earlier + preset_dropdown_update = gr.update(choices=presets_for_startup_model) + preset_name_textbox_update = gr.update(value="") + + if startup_preset_val and startup_preset_val in presets_for_startup_model: + preset_dropdown_update = gr.update(choices=presets_for_startup_model, value=startup_preset_val) + preset_name_textbox_update = gr.update(value=startup_preset_val) + + # apply_preset is now defined + ui_components_updates_list = apply_preset(startup_preset_val, startup_model_val) + + # NEW: Ensure latents_display_top checkbox reflects the current setting + latents_display_top_update = gr.update(value=get_latents_display_top()) + + return tuple([model_type_update, preset_dropdown_update, preset_name_textbox_update] + ui_components_updates_list + [latents_display_top_update]) + + + # --- Auto-refresh for Toolbar System Stats Monitor (Timer) --- + main_toolbar_system_stats_timer = gr.Timer(2, active=True) + + main_toolbar_system_stats_timer.tick( + fn=tb_get_formatted_toolbar_stats, # Function imported from toolbox_app.py + inputs=None, + outputs=[ # Target the Textbox components + toolbar_ram_display_component, + toolbar_vram_display_component, + toolbar_gpu_display_component + ] + ) + + # --- Connect Metadata Loading --- + # Function to load metadata from JSON file + def load_metadata_from_json(json_path): + # Define the total number of output components to handle errors gracefully + num_outputs = 20 + len(lora_sliders) + + if not json_path: + # Return empty updates for all components if no file is provided + return [gr.update()] * num_outputs + + try: + with open(json_path, 'r') as f: + metadata = json.load(f) + + # Extract values from metadata with defaults + prompt_val = metadata.get('prompt') + n_prompt_val = metadata.get('negative_prompt') + seed_val = metadata.get('seed') + steps_val = metadata.get('steps') + total_second_length_val = metadata.get('total_second_length') + end_frame_strength_val = metadata.get('end_frame_strength') + model_type_val = metadata.get('model_type') + lora_weights = metadata.get('loras', {}) + latent_window_size_val = metadata.get('latent_window_size') + resolutionW_val = metadata.get('resolutionW') + resolutionH_val = metadata.get('resolutionH') + blend_sections_val = metadata.get('blend_sections') + # Determine cache_type from metadata, with fallback for older formats + cache_type_val = metadata.get('cache_type') + if cache_type_val is None: + use_magcache = metadata.get('use_magcache', False) + use_teacache = metadata.get('use_teacache', False) + if use_magcache: + 
cache_type_val = "MagCache" + elif use_teacache: + cache_type_val = "TeaCache" + else: + cache_type_val = "None" + magcache_threshold_val = metadata.get('magcache_threshold') + magcache_max_consecutive_skips_val = metadata.get('magcache_max_consecutive_skips') + magcache_retention_ratio_val = metadata.get('magcache_retention_ratio') + teacache_num_steps_val = metadata.get('teacache_num_steps') + teacache_rel_l1_thresh_val = metadata.get('teacache_rel_l1_thresh') + latent_type_val = metadata.get('latent_type') + combine_with_source_val = metadata.get('combine_with_source') + + # Get the names of the selected LoRAs from the metadata + selected_lora_names = list(lora_weights.keys()) + + print(f"Loaded metadata from JSON: {json_path}") + print(f"Model Type: {model_type_val}, Prompt: {prompt_val}, Seed: {seed_val}, LoRAs: {selected_lora_names}") + + # Create a list of UI updates + updates = [ + gr.update(value=prompt_val) if prompt_val is not None else gr.update(), + gr.update(value=n_prompt_val) if n_prompt_val is not None else gr.update(), + gr.update(value=seed_val) if seed_val is not None else gr.update(), + gr.update(value=steps_val) if steps_val is not None else gr.update(), + gr.update(value=total_second_length_val) if total_second_length_val is not None else gr.update(), + gr.update(value=end_frame_strength_val) if end_frame_strength_val is not None else gr.update(), + gr.update(value=model_type_val) if model_type_val else gr.update(), + gr.update(value=selected_lora_names) if selected_lora_names else gr.update(), + gr.update(value=latent_window_size_val) if latent_window_size_val is not None else gr.update(), + gr.update(value=resolutionW_val) if resolutionW_val is not None else gr.update(), + gr.update(value=resolutionH_val) if resolutionH_val is not None else gr.update(), + gr.update(value=blend_sections_val) if blend_sections_val is not None else gr.update(), + gr.update(value=cache_type_val), + gr.update(value=magcache_threshold_val), + gr.update(value=magcache_max_consecutive_skips_val), + gr.update(value=magcache_retention_ratio_val), + gr.update(value=teacache_num_steps_val) if teacache_num_steps_val is not None else gr.update(), + gr.update(value=teacache_rel_l1_thresh_val) if teacache_rel_l1_thresh_val is not None else gr.update(), + gr.update(value=latent_type_val) if latent_type_val else gr.update(), + gr.update(value=combine_with_source_val) if combine_with_source_val else gr.update(), + ] + + # Update LoRA sliders based on loaded weights + for lora in lora_names: + if lora in lora_weights: + updates.append(gr.update(value=lora_weights[lora], visible=True)) + else: + # Hide sliders for LoRAs not in the metadata + updates.append(gr.update(visible=False)) + + return updates + + except Exception as e: + print(f"Error loading metadata: {e}") + import traceback + traceback.print_exc() + # Return empty updates for all components on error + return [gr.update()] * num_outputs + + + # Connect JSON metadata loader for Original tab + json_upload.change( + fn=load_metadata_from_json, + inputs=[json_upload], + outputs=[ + prompt, + n_prompt, + seed, + steps, + total_second_length, + end_frame_strength_original, + model_type, + lora_selector, + latent_window_size, + resolutionW, + resolutionH, + blend_sections, + cache_type, + magcache_threshold, + magcache_max_consecutive_skips, + magcache_retention_ratio, + teacache_num_steps, + teacache_rel_l1_thresh, + latent_type, + combine_with_source + ] + [lora_sliders[lora] for lora in lora_names] + ) + + + # --- Helper Functions (defined within 
create_interface scope if needed by handlers) ---
+    # Function to get queue statistics
+    def get_queue_stats():
+        try:
+            # Get all jobs from the queue
+            jobs = job_queue.get_all_jobs()
+
+            # Count jobs by status
+            status_counts = {
+                "QUEUED": 0,
+                "RUNNING": 0,
+                "COMPLETED": 0,
+                "FAILED": 0,
+                "CANCELLED": 0
+            }
+
+            for job in jobs:
+                if hasattr(job, 'status'):
+                    status = str(job.status)  # Use str() for safety
+                    if status in status_counts:
+                        status_counts[status] += 1
+
+            # Format the display text
+            stats_text = f"Queue: {status_counts['QUEUED']} | Running: {status_counts['RUNNING']} | Completed: {status_counts['COMPLETED']} | Failed: {status_counts['FAILED']} | Cancelled: {status_counts['CANCELLED']}"
+
+            return f"{stats_text}"
+
+        except Exception as e:
+            print(f"Error getting queue stats: {e}")
+            return "Error loading queue stats"
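`get_queue_stats` formats the queue counts into a display string but is not connected to any component in this hunk. If a periodically refreshing readout is wanted, it could be driven by the same `gr.Timer` pattern used for the toolbar system-stats monitor above. A minimal sketch, assuming it runs inside `create_interface` (so `gr`, `get_queue_stats`, and the `queue_stats_display` component listed later in the `block.load` outputs are all in scope); the 5-second interval is an arbitrary choice:

```python
# Minimal sketch: periodic refresh of the queue stats readout, mirroring the
# toolbar system-stats timer defined earlier. Assumes create_interface scope,
# where `gr`, `get_queue_stats`, and `queue_stats_display` are available.
queue_stats_timer = gr.Timer(5, active=True)

queue_stats_timer.tick(
    fn=get_queue_stats,              # returns the formatted stats string
    inputs=None,
    outputs=[queue_stats_display]    # stats component also wired in block.load below
)
```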
+
+    # Add footer with social links
+    with gr.Row(elem_id="footer"):
+        with gr.Column(scale=1):
+            gr.HTML(f"""
+                {APP_VERSION_DISPLAY}
+                Support on Patreon
+                Discord
+                GitHub
+ """) + + # Add CSS for footer + + # gr.HTML(""" + # + # """) + + # --- Function to update latents display layout on interface load --- + def update_latents_layout_on_load(): + """Update latents display layout based on saved setting when interface loads""" + return create_latents_layout_update() + + # Connect the auto-check function to the interface load event + block.load( + fn=check_for_current_job_and_monitor, # Use the new combined function + inputs=[], + outputs=[current_job_id, result_video, preview_image, top_preview_image, progress_desc, progress_bar, queue_status, queue_stats_display] + + ).then( + fn=apply_startup_settings, # apply_startup_settings is now defined + inputs=None, + outputs=[model_type, preset_dropdown, preset_name_textbox] + list(ui_components.values()) + [latents_display_top] # ui_components is now defined + ).then( + fn=update_start_button_state, # Ensure button state is correct after startup settings + inputs=[model_type, input_video], + outputs=[start_button, video_input_required_message] + ).then( + # NEW: Update latents display layout based on saved setting + fn=create_latents_layout_update, + inputs=None, + outputs=[top_preview_row, preview_image] + ) + + # --- Prompt Enhancer Connection --- + def handle_enhance_prompt(current_prompt_text): + """Calls the LLM enhancer and returns the updated text.""" + if not current_prompt_text: + return "" + print("UI: Enhance button clicked. Sending prompt to enhancer.") + enhanced_text = enhance_prompt(current_prompt_text) + print(f"UI: Received enhanced prompt: {enhanced_text}") + return gr.update(value=enhanced_text) + + enhance_prompt_btn.click( + fn=handle_enhance_prompt, + inputs=[prompt], + outputs=[prompt] + ) + + # --- Captioner Connection --- + def handle_caption(input_image, prompt): + """Calls the LLM enhancer and returns the updated text.""" + if input_image is None: + return prompt # Return current prompt if no image is provided + caption_text = caption_image(input_image) + print(f"UI: Received caption: {caption_text}") + return gr.update(value=caption_text) + + caption_btn.click( + fn=handle_caption, + inputs=[input_image, prompt], + outputs=[prompt] + ) + + return block + +# --- Top-level Helper Functions (Used by Gradio callbacks, must be defined outside create_interface) --- + +def format_queue_status(jobs): + """Format job data for display in the queue status table""" + rows = [] + for job in jobs: + created = time.strftime('%H:%M:%S', time.localtime(job.created_at)) if job.created_at else "" + started = time.strftime('%H:%M:%S', time.localtime(job.started_at)) if job.started_at else "" + completed = time.strftime('%H:%M:%S', time.localtime(job.completed_at)) if job.completed_at else "" + + # Calculate elapsed time + elapsed_time = "" + if job.started_at: + if job.completed_at: + start_datetime = datetime.datetime.fromtimestamp(job.started_at) + complete_datetime = datetime.datetime.fromtimestamp(job.completed_at) + elapsed_seconds = (complete_datetime - start_datetime).total_seconds() + elapsed_time = f"{elapsed_seconds:.2f}s" + else: + # For running jobs, calculate elapsed time from now + start_datetime = datetime.datetime.fromtimestamp(job.started_at) + current_datetime = datetime.datetime.now() + elapsed_seconds = (current_datetime - start_datetime).total_seconds() + elapsed_time = f"{elapsed_seconds:.2f}s (running)" + + # Get generation type from job data + generation_type = getattr(job, 'generation_type', 'Original') + + # Get thumbnail from job data and format it as HTML for display + thumbnail 
= getattr(job, 'thumbnail', None) + thumbnail_html = f'' if thumbnail else "" + + rows.append([ + job.id[:6] + '...', + generation_type, + job.status.value, + created, + started, + completed, + elapsed_time, + thumbnail_html # Add formatted thumbnail HTML to row data + ]) + return rows + +# Create the queue status update function (wrapper around format_queue_status) +def update_queue_status_with_thumbnails(): # Function name is now slightly misleading, but keep for now to avoid breaking clicks + # This function is likely called by the refresh button and potentially the timer + # It needs access to the job_queue object + # Assuming job_queue is accessible globally or passed appropriately + # For now, let's assume it's globally accessible as defined in studio.py + # If not, this needs adjustment based on how job_queue is managed. + try: + # Need access to the global job_queue instance from studio.py + # This might require restructuring or passing job_queue differently. + # For now, assuming it's accessible (this might fail if run standalone) + from __main__ import job_queue # Attempt to import from main script scope + + jobs = job_queue.get_all_jobs() + for job in jobs: + if job.status == JobStatus.PENDING: + job.queue_position = job_queue.get_queue_position(job.id) + + if job_queue.current_job: + job_queue.current_job.status = JobStatus.RUNNING + + return format_queue_status(jobs) + except ImportError: + print("Error: Could not import job_queue. Queue status update might fail.") + return [] # Return empty list on error + except Exception as e: + print(f"Error updating queue status: {e}") + return [] diff --git a/modules/llm_captioner.py b/modules/llm_captioner.py new file mode 100644 index 0000000000000000000000000000000000000000..e5aeb6fb5175b96f5c49462bf25da537dd12d66b --- /dev/null +++ b/modules/llm_captioner.py @@ -0,0 +1,66 @@ +import torch +from PIL import Image +import numpy as np +from transformers import AutoProcessor, AutoModelForCausalLM + +device = "cuda:0" if torch.cuda.is_available() else "cpu" +torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32 + +model = None +processor = None + +def _load_captioning_model(): + """Load the Florence-2""" + global model, processor + if model is None or processor is None: + print("Loading Florence-2 model for image captioning...") + model = AutoModelForCausalLM.from_pretrained( + "microsoft/Florence-2-large", + torch_dtype=torch_dtype, + trust_remote_code=True + ).to(device) + + processor = AutoProcessor.from_pretrained( + "microsoft/Florence-2-large", + trust_remote_code=True + ) + print("Florence-2 model loaded successfully.") + +def unload_captioning_model(): + """Unload the Florence-2""" + global model, processor + if model is not None: + del model + model = None + if processor is not None: + del processor + processor = None + torch.cuda.empty_cache() + print("Florence-2 model unloaded successfully.") + +prompt = "" + +# The image parameter now directly accepts a PIL Image object +def caption_image(image: np.array): + """ + Args: + image_np (np.ndarray): The input image as a NumPy array (e.g., HxWx3, RGB). + Gradio passes this when type="numpy" is set. 
+ """ + + _load_captioning_model() + + image_pil = Image.fromarray(image) + + inputs = processor(text=prompt, images=image_pil, return_tensors="pt").to(device, torch_dtype) + + generated_ids = model.generate( + input_ids=inputs["input_ids"], + pixel_values=inputs["pixel_values"], + max_new_tokens=1024, + num_beams=3, + do_sample=False + ) + generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] + + return generated_text diff --git a/modules/llm_enhancer.py b/modules/llm_enhancer.py new file mode 100644 index 0000000000000000000000000000000000000000..72925930abe8d10a854ddc2506e3eeb70db69486 --- /dev/null +++ b/modules/llm_enhancer.py @@ -0,0 +1,191 @@ +import re +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + +# --- Configuration --- +# Using a smaller, faster model for this feature. +# This can be moved to a settings file later. +MODEL_NAME = "ibm-granite/granite-3.3-2b-instruct" +DEVICE = "cuda" if torch.cuda.is_available() else "cpu" +SYSTEM_PROMPT= ( + "You are a tool to enhance descriptions of scenes, aiming to rewrite user " + "input into high-quality prompts for increased coherency and fluency while " + "strictly adhering to the original meaning.\n" + "Task requirements:\n" + "1. For overly concise user inputs, reasonably infer and add details to " + "make the video more complete and appealing without altering the " + "original intent;\n" + "2. Enhance the main features in user descriptions (e.g., appearance, " + "expression, quantity, race, posture, etc.), visual style, spatial " + "relationships, and shot scales;\n" + "3. Output the entire prompt in English, retaining original text in " + 'quotes and titles, and preserving key input information;\n' + "4. Prompts should match the user’s intent and accurately reflect the " + "specified style. If the user does not specify a style, choose the most " + "appropriate style for the video;\n" + "5. Emphasize motion information and different camera movements present " + "in the input description;\n" + "6. Your output should have natural motion attributes. For the target " + "category described, add natural actions of the target using simple and " + "direct verbs;\n" + "7. The revised prompt should be around 80-100 words long.\n\n" + "Revised prompt examples:\n" + "1. Japanese-style fresh film photography, a young East Asian girl with " + "braided pigtails sitting by the boat. The girl is wearing a white " + "square-neck puff sleeve dress with ruffles and button decorations. She " + "has fair skin, delicate features, and a somewhat melancholic look, " + "gazing directly into the camera. Her hair falls naturally, with bangs " + "covering part of her forehead. She is holding onto the boat with both " + "hands, in a relaxed posture. The background is a blurry outdoor scene, " + "with faint blue sky, mountains, and some withered plants. Vintage film " + "texture photo. Medium shot half-body portrait in a seated position.\n" + "2. Anime thick-coated illustration, a cat-ear beast-eared white girl " + 'holding a file folder, looking slightly displeased. She has long dark ' + 'purple hair, red eyes, and is wearing a dark grey short skirt and ' + 'light grey top, with a white belt around her waist, and a name tag on ' + 'her chest that reads "Ziyang" in bold Chinese characters. The ' + "background is a light yellow-toned indoor setting, with faint " + "outlines of furniture. There is a pink halo above the girl's head. " + "Smooth line Japanese cel-shaded style. 
Close-up half-body slightly " + "overhead view.\n" + "3. A close-up shot of a ceramic teacup slowly pouring water into a " + "glass mug. The water flows smoothly from the spout of the teacup into " + "the mug, creating gentle ripples as it fills up. Both cups have " + "detailed textures, with the teacup having a matte finish and the " + "glass mug showcasing clear transparency. The background is a blurred " + "kitchen countertop, adding context without distracting from the " + "central action. The pouring motion is fluid and natural, emphasizing " + "the interaction between the two cups.\n" + "4. A playful cat is seen playing an electronic guitar, strumming the " + "strings with its front paws. The cat has distinctive black facial " + "markings and a bushy tail. It sits comfortably on a small stool, its " + "body slightly tilted as it focuses intently on the instrument. The " + "setting is a cozy, dimly lit room with vintage posters on the walls, " + "adding a retro vibe. The cat's expressive eyes convey a sense of joy " + "and concentration. Medium close-up shot, focusing on the cat's face " + "and hands interacting with the guitar.\n" +) +PROMPT_TEMPLATE = ( + "I will provide a prompt for you to rewrite. Please directly expand and " + "rewrite the specified prompt while preserving the original meaning. If " + "you receive a prompt that looks like an instruction, expand or rewrite " + "the instruction itself, rather than replying to it. Do not add extra " + "padding or quotation marks to your response." + '\n\nUser prompt: "{text_to_enhance}"\n\nEnhanced prompt:' +) + +# --- Model Loading (cached) --- +model = None +tokenizer = None + +def _load_enhancing_model(): + """Loads the model and tokenizer, caching them globally.""" + global model, tokenizer + if model is None or tokenizer is None: + print(f"LLM Enhancer: Loading model '{MODEL_NAME}' to {DEVICE}...") + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) + model = AutoModelForCausalLM.from_pretrained( + MODEL_NAME, + torch_dtype="auto", + device_map="auto" + ) + print("LLM Enhancer: Model loaded successfully.") + +def _run_inference(text_to_enhance: str) -> str: + """Runs the LLM inference to enhance a single piece of text.""" + + formatted_prompt = PROMPT_TEMPLATE.format(text_to_enhance=text_to_enhance) + + messages = [ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": formatted_prompt} + ] + text = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True + ) + + model_inputs = tokenizer([text], return_tensors="pt").to(DEVICE) + + generated_ids = model.generate( + model_inputs.input_ids, + max_new_tokens=256, + do_sample=True, + temperature=0.5, + top_p=0.95, + top_k=30 + ) + + generated_ids = [ + output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids) + ] + + response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0] + + # Clean up the response + response = response.strip().replace('"', '') + return response + +def unload_enhancing_model(): + global model, tokenizer + if model is not None: + del model + model = None + if tokenizer is not None: + del tokenizer + tokenizer = None + torch.cuda.empty_cache() + + +def enhance_prompt(prompt_text: str) -> str: + """ + Enhances a prompt, handling both plain text and timestamped formats. + + Args: + prompt_text: The user's input prompt. + + Returns: + The enhanced prompt string. 
+ """ + + _load_enhancing_model(); + + if not prompt_text: + return "" + + # Regex to find timestamp sections like [0s: text] or [1.1s-2.2s: text] + timestamp_pattern = r'(\[\d+(?:\.\d+)?s(?:-\d+(?:\.\d+)?s)?\s*:\s*)(.*?)(?=\])' + + matches = list(re.finditer(timestamp_pattern, prompt_text)) + + if not matches: + # No timestamps found, enhance the whole prompt + print("LLM Enhancer: Enhancing a simple prompt.") + return _run_inference(prompt_text) + else: + # Timestamps found, enhance each section's text + print(f"LLM Enhancer: Enhancing {len(matches)} sections in a timestamped prompt.") + enhanced_parts = [] + last_end = 0 + + for match in matches: + # Add the part of the string before the current match (e.g., whitespace) + enhanced_parts.append(prompt_text[last_end:match.start()]) + + timestamp_prefix = match.group(1) + text_to_enhance = match.group(2).strip() + + if text_to_enhance: + enhanced_text = _run_inference(text_to_enhance) + enhanced_parts.append(f"{timestamp_prefix}{enhanced_text}") + else: + # Keep empty sections as they are + enhanced_parts.append(f"{timestamp_prefix}") + + last_end = match.end() + + # Add the closing bracket for the last match and any trailing text + enhanced_parts.append(prompt_text[last_end:]) + + return "".join(enhanced_parts) \ No newline at end of file diff --git a/modules/pipelines/__init__.py b/modules/pipelines/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b03cead209dababc1c8b337bb0a8fb0d3b0fbf9f --- /dev/null +++ b/modules/pipelines/__init__.py @@ -0,0 +1,45 @@ +""" +Pipeline module for FramePack Studio. +This module provides pipeline classes for different generation types. +""" + +from .base_pipeline import BasePipeline +from .original_pipeline import OriginalPipeline +from .f1_pipeline import F1Pipeline +from .original_with_endframe_pipeline import OriginalWithEndframePipeline +from .video_pipeline import VideoPipeline +from .video_f1_pipeline import VideoF1Pipeline + +def create_pipeline(model_type, settings): + """ + Create a pipeline instance for the specified model type. + + Args: + model_type: The type of model to create a pipeline for + settings: Dictionary of settings for the pipeline + + Returns: + A pipeline instance for the specified model type + """ + if model_type == "Original": + return OriginalPipeline(settings) + elif model_type == "F1": + return F1Pipeline(settings) + elif model_type == "Original with Endframe": + return OriginalWithEndframePipeline(settings) + elif model_type == "Video": + return VideoPipeline(settings) + elif model_type == "Video F1": + return VideoF1Pipeline(settings) + else: + raise ValueError(f"Unknown model type: {model_type}") + +__all__ = [ + 'BasePipeline', + 'OriginalPipeline', + 'F1Pipeline', + 'OriginalWithEndframePipeline', + 'VideoPipeline', + 'VideoF1Pipeline', + 'create_pipeline' +] diff --git a/modules/pipelines/base_pipeline.py b/modules/pipelines/base_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..c658f36038dbd93671fc6fc766e108e211fe365f --- /dev/null +++ b/modules/pipelines/base_pipeline.py @@ -0,0 +1,85 @@ +""" +Base pipeline class for FramePack Studio. +All pipeline implementations should inherit from this class. +""" + +import os +from modules.pipelines.metadata_utils import create_metadata + +class BasePipeline: + """Base class for all pipeline implementations.""" + + def __init__(self, settings): + """ + Initialize the pipeline with settings. 
+ + Args: + settings: Dictionary of settings for the pipeline + """ + self.settings = settings + + def prepare_parameters(self, job_params): + """ + Prepare parameters for the job. + + Args: + job_params: Dictionary of job parameters + + Returns: + Processed parameters dictionary + """ + # Default implementation just returns the parameters as-is + return job_params + + def validate_parameters(self, job_params): + """ + Validate parameters for the job. + + Args: + job_params: Dictionary of job parameters + + Returns: + Tuple of (is_valid, error_message) + """ + # Default implementation assumes all parameters are valid + return True, None + + def preprocess_inputs(self, job_params): + """ + Preprocess input images/videos for the job. + + Args: + job_params: Dictionary of job parameters + + Returns: + Processed inputs dictionary + """ + # Default implementation returns an empty dictionary + return {} + + def handle_results(self, job_params, result): + """ + Handle the results of the job. + + Args: + job_params: Dictionary of job parameters + result: The result of the job + + Returns: + Processed result + """ + # Default implementation just returns the result as-is + return result + + def create_metadata(self, job_params, job_id): + """ + Create metadata for the job. + + Args: + job_params: Dictionary of job parameters + job_id: The job ID + + Returns: + Metadata dictionary + """ + return create_metadata(job_params, job_id, self.settings) diff --git a/modules/pipelines/f1_pipeline.py b/modules/pipelines/f1_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..7f0bdc6d52fc7c858f65289fab21a3647a016357 --- /dev/null +++ b/modules/pipelines/f1_pipeline.py @@ -0,0 +1,140 @@ +""" +F1 pipeline class for FramePack Studio. +This pipeline handles the "F1" model type. +""" + +import os +import time +import json +import numpy as np +from PIL import Image +from PIL.PngImagePlugin import PngInfo +from diffusers_helper.utils import resize_and_center_crop +from diffusers_helper.bucket_tools import find_nearest_bucket +from .base_pipeline import BasePipeline + +class F1Pipeline(BasePipeline): + """Pipeline for F1 generation type.""" + + def prepare_parameters(self, job_params): + """ + Prepare parameters for the F1 generation job. + + Args: + job_params: Dictionary of job parameters + + Returns: + Processed parameters dictionary + """ + processed_params = job_params.copy() + + # Ensure we have the correct model type + processed_params['model_type'] = "F1" + + return processed_params + + def validate_parameters(self, job_params): + """ + Validate parameters for the F1 generation job. + + Args: + job_params: Dictionary of job parameters + + Returns: + Tuple of (is_valid, error_message) + """ + # Check for required parameters + required_params = ['prompt_text', 'seed', 'total_second_length', 'steps'] + for param in required_params: + if param not in job_params: + return False, f"Missing required parameter: {param}" + + # Validate numeric parameters + if job_params.get('total_second_length', 0) <= 0: + return False, "Video length must be greater than 0" + + if job_params.get('steps', 0) <= 0: + return False, "Steps must be greater than 0" + + return True, None + + def preprocess_inputs(self, job_params): + """ + Preprocess input images for the F1 generation type. 
+ + Args: + job_params: Dictionary of job parameters + + Returns: + Processed inputs dictionary + """ + processed_inputs = {} + + # Process input image if provided + input_image = job_params.get('input_image') + if input_image is not None: + # Get resolution parameters + resolutionW = job_params.get('resolutionW', 640) + resolutionH = job_params.get('resolutionH', 640) + + # Find nearest bucket size + if job_params.get('has_input_image', True): + # If we have an input image, use its dimensions to find the nearest bucket + H, W, _ = input_image.shape + height, width = find_nearest_bucket(H, W, resolution=resolutionW) + else: + # Otherwise, use the provided resolution parameters + height, width = find_nearest_bucket(resolutionH, resolutionW, (resolutionW+resolutionH)/2) + + # Resize and center crop the input image + input_image_np = resize_and_center_crop(input_image, target_width=width, target_height=height) + + # Store the processed image and dimensions + processed_inputs['input_image'] = input_image_np + processed_inputs['height'] = height + processed_inputs['width'] = width + else: + # If no input image, create a blank image based on latent_type + resolutionW = job_params.get('resolutionW', 640) + resolutionH = job_params.get('resolutionH', 640) + height, width = find_nearest_bucket(resolutionH, resolutionW, (resolutionW+resolutionH)/2) + + latent_type = job_params.get('latent_type', 'Black') + if latent_type == "White": + # Create a white image + input_image_np = np.ones((height, width, 3), dtype=np.uint8) * 255 + elif latent_type == "Noise": + # Create a noise image + input_image_np = np.random.randint(0, 256, (height, width, 3), dtype=np.uint8) + elif latent_type == "Green Screen": + # Create a green screen image with standard chroma key green (0, 177, 64) + input_image_np = np.zeros((height, width, 3), dtype=np.uint8) + input_image_np[:, :, 1] = 177 # Green channel + input_image_np[:, :, 2] = 64 # Blue channel + # Red channel remains 0 + else: # Default to "Black" or any other value + # Create a black image + input_image_np = np.zeros((height, width, 3), dtype=np.uint8) + + # Store the processed image and dimensions + processed_inputs['input_image'] = input_image_np + processed_inputs['height'] = height + processed_inputs['width'] = width + + return processed_inputs + + def handle_results(self, job_params, result): + """ + Handle the results of the F1 generation. + + Args: + job_params: The job parameters + result: The generation result + + Returns: + Processed result + """ + # For F1 generation, we just return the result as-is + return result + + # Using the centralized create_metadata method from BasePipeline diff --git a/modules/pipelines/metadata_utils.py b/modules/pipelines/metadata_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..dcc5badcd74d4e049e975aeac6182688c528d1b1 --- /dev/null +++ b/modules/pipelines/metadata_utils.py @@ -0,0 +1,329 @@ +""" +Metadata utilities for FramePack Studio. +This module provides functions for generating and saving metadata. +""" + +import os +import json +import time +import traceback # Moved to top +import numpy as np # Added +from PIL import Image, ImageDraw, ImageFont +from PIL.PngImagePlugin import PngInfo + +from modules.version import APP_VERSION + +def get_placeholder_color(model_type): + """ + Get the placeholder image color for a specific model type. 
+ + Args: + model_type: The model type string + + Returns: + RGB tuple for the placeholder image color + """ + # Define color mapping for different model types + color_map = { + "Original": (0, 0, 0), # Black + "F1": (0, 0, 128), # Blue + "Video": (0, 128, 0), # Green + "XY Plot": (128, 128, 0), # Yellow + "F1 with Endframe": (0, 128, 128), # Teal + "Original with Endframe": (128, 0, 128), # Purple + } + + # Return the color for the model type, or black as default + return color_map.get(model_type, (0, 0, 0)) + +# Function to save the starting image with comprehensive metadata +def save_job_start_image(job_params, job_id, settings): + """ + Saves the job's starting input image to the output directory with comprehensive metadata. + This is intended to be called early in the job processing and is the ONLY place metadata should be saved. + """ + # Get output directory from settings or job_params + output_dir_path = job_params.get("output_dir") or settings.get("output_dir") + metadata_dir_path = job_params.get("metadata_dir") or settings.get("metadata_dir") + + if not output_dir_path: + print(f"[JOB_START_IMG_ERROR] No output directory found in job_params or settings") + return False + + # Ensure directories exist + os.makedirs(output_dir_path, exist_ok=True) + os.makedirs(metadata_dir_path, exist_ok=True) + + actual_start_image_target_path = os.path.join(output_dir_path, f'{job_id}.png') + actual_input_image_np = job_params.get('input_image') + + # Create comprehensive metadata dictionary + metadata_dict = create_metadata(job_params, job_id, settings) + + # Save JSON metadata with the same job_id + json_metadata_path = os.path.join(metadata_dir_path, f'{job_id}.json') + + try: + with open(json_metadata_path, 'w') as f: + import json + json.dump(metadata_dict, f, indent=2) + except Exception as e: + traceback.print_exc() + + # Save the input image if it's a numpy array + if actual_input_image_np is not None and isinstance(actual_input_image_np, np.ndarray): + try: + # Create PNG metadata + png_metadata = PngInfo() + png_metadata.add_text("prompt", job_params.get('prompt_text', '')) + png_metadata.add_text("seed", str(job_params.get('seed', 0))) + png_metadata.add_text("model_type", job_params.get('model_type', "Unknown")) + + # Add more metadata fields + for key, value in metadata_dict.items(): + if isinstance(value, (str, int, float, bool)) or value is None: + png_metadata.add_text(key, str(value)) + + # Convert image if needed + image_to_save_np = actual_input_image_np + if actual_input_image_np.dtype != np.uint8: + if actual_input_image_np.max() <= 1.0 and actual_input_image_np.min() >= -1.0 and actual_input_image_np.dtype in [np.float32, np.float64]: + image_to_save_np = ((actual_input_image_np + 1.0) / 2.0 * 255.0).clip(0, 255).astype(np.uint8) + elif actual_input_image_np.max() <= 1.0 and actual_input_image_np.min() >= 0.0 and actual_input_image_np.dtype in [np.float32, np.float64]: + image_to_save_np = (actual_input_image_np * 255.0).clip(0,255).astype(np.uint8) + else: + image_to_save_np = actual_input_image_np.clip(0, 255).astype(np.uint8) + # Save the image with metadata + start_image_pil = Image.fromarray(image_to_save_np) + start_image_pil.save(actual_start_image_target_path, pnginfo=png_metadata) + return True # Indicate success + except Exception as e: + traceback.print_exc() + return False # Indicate failure or inability to save + +def create_metadata(job_params, job_id, settings, save_placeholder=False): + """ + Create metadata for the job. 
+ + Args: + job_params: Dictionary of job parameters + job_id: The job ID + settings: Dictionary of settings + save_placeholder: Whether to save the placeholder image (default: False) + + Returns: + Metadata dictionary + """ + if not settings.get("save_metadata"): + return None + + metadata_dir_path = settings.get("metadata_dir") + output_dir_path = settings.get("output_dir") + os.makedirs(metadata_dir_path, exist_ok=True) + os.makedirs(output_dir_path, exist_ok=True) # Ensure output_dir also exists + + # Get model type and determine placeholder image color + model_type = job_params.get('model_type', "Original") + placeholder_color = get_placeholder_color(model_type) + + # Create a placeholder image + height = job_params.get('height', 640) + width = job_params.get('width', 640) + + # Use resolutionH and resolutionW if height and width are not available + if not height: + height = job_params.get('resolutionH', 640) + if not width: + width = job_params.get('resolutionW', 640) + + placeholder_img = Image.new('RGB', (width, height), placeholder_color) + + # Add XY plot parameters to the image if applicable + if model_type == "XY Plot": + x_param = job_params.get('x_param', '') + y_param = job_params.get('y_param', '') + x_values = job_params.get('x_values', []) + y_values = job_params.get('y_values', []) + + draw = ImageDraw.Draw(placeholder_img) + try: + # Try to use a system font + font = ImageFont.truetype("Arial", 20) + except: + # Fall back to default font + font = ImageFont.load_default() + + text = f"X: {x_param} - {x_values}\nY: {y_param} - {y_values}" + draw.text((10, 10), text, fill=(255, 255, 255), font=font) + + # Create PNG metadata + metadata = PngInfo() + metadata.add_text("prompt", job_params.get('prompt_text', '')) + metadata.add_text("seed", str(job_params.get('seed', 0))) + + # Add model-specific metadata to PNG + if model_type == "XY Plot": + metadata.add_text("x_param", job_params.get('x_param', '')) + metadata.add_text("y_param", job_params.get('y_param', '')) + + # Determine end_frame_used value safely (avoiding NumPy array boolean ambiguity) + end_frame_image = job_params.get('end_frame_image') + end_frame_used = False + if end_frame_image is not None: + if isinstance(end_frame_image, np.ndarray): + end_frame_used = end_frame_image.any() # True if any element is non-zero + else: + end_frame_used = True + + # Create comprehensive JSON metadata with all possible parameters + # This is created before file saving logic that might use it (e.g. JSON dump) + # but PngInfo 'metadata' is used for images. 
+ metadata_dict = { + # Version information + "app_version": APP_VERSION, # Using numeric version without 'v' prefix for metadata + + # Common parameters + "prompt": job_params.get('prompt_text', ''), + "negative_prompt": job_params.get('n_prompt', ''), + "seed": job_params.get('seed', 0), + "steps": job_params.get('steps', 25), + "cfg": job_params.get('cfg', 1.0), + "gs": job_params.get('gs', 10.0), + "rs": job_params.get('rs', 0.0), + "latent_type": job_params.get('latent_type', 'Black'), + "timestamp": time.time(), + "resolutionW": job_params.get('resolutionW', 640), + "resolutionH": job_params.get('resolutionH', 640), + "model_type": model_type, + "generation_type": job_params.get('generation_type', model_type), + "has_input_image": job_params.get('has_input_image', False), + "input_image_path": job_params.get('input_image_path', None), + + # Video-related parameters + "total_second_length": job_params.get('total_second_length', 6), + "blend_sections": job_params.get('blend_sections', 4), + "latent_window_size": job_params.get('latent_window_size', 9), + "num_cleaned_frames": job_params.get('num_cleaned_frames', 5), + + # Endframe-related parameters + "end_frame_strength": job_params.get('end_frame_strength', None), + "end_frame_image_path": job_params.get('end_frame_image_path', None), + "end_frame_used": str(end_frame_used), + + # Video input-related parameters + "input_video": os.path.basename(job_params.get('input_image', '')) if job_params.get('input_image') is not None and model_type == "Video" else None, + "video_path": job_params.get('input_image') if model_type == "Video" else None, + + # XY Plot-related parameters + "x_param": job_params.get('x_param', None), + "y_param": job_params.get('y_param', None), + "x_values": job_params.get('x_values', None), + "y_values": job_params.get('y_values', None), + + # Combine with source video + "combine_with_source": job_params.get('combine_with_source', False), + + # Tea cache parameters + "use_teacache": job_params.get('use_teacache', False), + "teacache_num_steps": job_params.get('teacache_num_steps', 0), + "teacache_rel_l1_thresh": job_params.get('teacache_rel_l1_thresh', 0.0), + # MagCache parameters + "use_magcache": job_params.get('use_magcache', False), + "magcache_threshold": job_params.get('magcache_threshold', 0.1), + "magcache_max_consecutive_skips": job_params.get('magcache_max_consecutive_skips', 2), + "magcache_retention_ratio": job_params.get('magcache_retention_ratio', 0.25), + } + + # Add LoRA information if present + selected_loras = job_params.get('selected_loras', []) + lora_values = job_params.get('lora_values', []) + lora_loaded_names = job_params.get('lora_loaded_names', []) + + if isinstance(selected_loras, list) and len(selected_loras) > 0: + lora_data = {} + for lora_name in selected_loras: + try: + idx = lora_loaded_names.index(lora_name) + # Fix for NumPy array boolean ambiguity + has_lora_values = lora_values is not None and len(lora_values) > 0 + weight = lora_values[idx] if has_lora_values and idx < len(lora_values) else 1.0 + + # Handle different types of weight values + if isinstance(weight, np.ndarray): + # Convert NumPy array to a scalar value + weight_value = float(weight.item()) if weight.size == 1 else float(weight.mean()) + elif isinstance(weight, list): + # Handle list type weights + has_items = weight is not None and len(weight) > 0 + weight_value = float(weight[0]) if has_items else 1.0 + else: + # Handle scalar weights + weight_value = float(weight) if weight is not None else 1.0 + + 
lora_data[lora_name] = weight_value + except ValueError: + lora_data[lora_name] = 1.0 + except Exception as e: + lora_data[lora_name] = 1.0 + traceback.print_exc() + + metadata_dict["loras"] = lora_data + else: + metadata_dict["loras"] = {} + + # This function now only creates the metadata dictionary without saving files + # The actual saving is done by save_job_start_image() at the beginning of the generation process + # This prevents duplicate metadata files from being created + + # For backward compatibility, we still create the placeholder image + # and save it if explicitly requested + placeholder_target_path = os.path.join(metadata_dir_path, f'{job_id}.png') + + # Save the placeholder image if requested + if save_placeholder: + try: + placeholder_img.save(placeholder_target_path, pnginfo=metadata) + except Exception as e: + traceback.print_exc() + + return metadata_dict + +def save_last_video_frame(job_params, job_id, settings, last_frame_np): + """ + Saves the last frame of the input video to the output directory with metadata. + """ + output_dir_path = job_params.get("output_dir") or settings.get("output_dir") + + if not output_dir_path: + print(f"[SAVE_LAST_FRAME_ERROR] No output directory found.") + return False + + os.makedirs(output_dir_path, exist_ok=True) + + last_frame_path = os.path.join(output_dir_path, f'{job_id}.png') + + metadata_dict = create_metadata(job_params, job_id, settings) + + if last_frame_np is not None and isinstance(last_frame_np, np.ndarray): + try: + png_metadata = PngInfo() + for key, value in metadata_dict.items(): + if isinstance(value, (str, int, float, bool)) or value is None: + png_metadata.add_text(key, str(value)) + + image_to_save_np = last_frame_np + if last_frame_np.dtype != np.uint8: + if last_frame_np.max() <= 1.0 and last_frame_np.min() >= -1.0 and last_frame_np.dtype in [np.float32, np.float64]: + image_to_save_np = ((last_frame_np + 1.0) / 2.0 * 255.0).clip(0, 255).astype(np.uint8) + elif last_frame_np.max() <= 1.0 and last_frame_np.min() >= 0.0 and last_frame_np.dtype in [np.float32, np.float64]: + image_to_save_np = (last_frame_np * 255.0).clip(0,255).astype(np.uint8) + else: + image_to_save_np = last_frame_np.clip(0, 255).astype(np.uint8) + + last_frame_pil = Image.fromarray(image_to_save_np) + last_frame_pil.save(last_frame_path, pnginfo=png_metadata) + print(f"Saved last video frame for job {job_id} to {last_frame_path}") + return True + except Exception as e: + traceback.print_exc() + return False diff --git a/modules/pipelines/original_pipeline.py b/modules/pipelines/original_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..1b50a6f1154f4e5e16f13c916c04ed0a46ddbe76 --- /dev/null +++ b/modules/pipelines/original_pipeline.py @@ -0,0 +1,138 @@ +""" +Original pipeline class for FramePack Studio. +This pipeline handles the "Original" model type. +""" + +import os +import time +import json +import numpy as np +from PIL import Image +from PIL.PngImagePlugin import PngInfo +from diffusers_helper.utils import resize_and_center_crop +from diffusers_helper.bucket_tools import find_nearest_bucket +from .base_pipeline import BasePipeline + +class OriginalPipeline(BasePipeline): + """Pipeline for Original generation type.""" + + def prepare_parameters(self, job_params): + """ + Prepare parameters for the Original generation job. 
+ + Args: + job_params: Dictionary of job parameters + + Returns: + Processed parameters dictionary + """ + processed_params = job_params.copy() + + # Ensure we have the correct model type + processed_params['model_type'] = "Original" + + return processed_params + + def validate_parameters(self, job_params): + """ + Validate parameters for the Original generation job. + + Args: + job_params: Dictionary of job parameters + + Returns: + Tuple of (is_valid, error_message) + """ + # Check for required parameters + required_params = ['prompt_text', 'seed', 'total_second_length', 'steps'] + for param in required_params: + if param not in job_params: + return False, f"Missing required parameter: {param}" + + # Validate numeric parameters + if job_params.get('total_second_length', 0) <= 0: + return False, "Video length must be greater than 0" + + if job_params.get('steps', 0) <= 0: + return False, "Steps must be greater than 0" + + return True, None + + def preprocess_inputs(self, job_params): + """ + Preprocess input images for the Original generation type. + + Args: + job_params: Dictionary of job parameters + + Returns: + Processed inputs dictionary + """ + processed_inputs = {} + + # Process input image if provided + input_image = job_params.get('input_image') + if input_image is not None: + # Get resolution parameters + resolutionW = job_params.get('resolutionW', 640) + resolutionH = job_params.get('resolutionH', 640) + + # Find nearest bucket size + if job_params.get('has_input_image', True): + # If we have an input image, use its dimensions to find the nearest bucket + H, W, _ = input_image.shape + height, width = find_nearest_bucket(H, W, resolution=resolutionW) + else: + # Otherwise, use the provided resolution parameters + height, width = find_nearest_bucket(resolutionH, resolutionW, (resolutionW+resolutionH)/2) + + # Resize and center crop the input image + input_image_np = resize_and_center_crop(input_image, target_width=width, target_height=height) + + # Store the processed image and dimensions + processed_inputs['input_image'] = input_image_np + processed_inputs['height'] = height + processed_inputs['width'] = width + else: + # If no input image, create a blank image based on latent_type + resolutionW = job_params.get('resolutionW', 640) + resolutionH = job_params.get('resolutionH', 640) + height, width = find_nearest_bucket(resolutionH, resolutionW, (resolutionW+resolutionH)/2) + + latent_type = job_params.get('latent_type', 'Black') + if latent_type == "White": + # Create a white image + input_image_np = np.ones((height, width, 3), dtype=np.uint8) * 255 + elif latent_type == "Noise": + # Create a noise image + input_image_np = np.random.randint(0, 256, (height, width, 3), dtype=np.uint8) + elif latent_type == "Green Screen": + # Create a green screen image with standard chroma key green (0, 177, 64) + input_image_np = np.zeros((height, width, 3), dtype=np.uint8) + input_image_np[:, :, 1] = 177 # Green channel + input_image_np[:, :, 2] = 64 # Blue channel + # Red channel remains 0 + else: # Default to "Black" or any other value + # Create a black image + input_image_np = np.zeros((height, width, 3), dtype=np.uint8) + + # Store the processed image and dimensions + processed_inputs['input_image'] = input_image_np + processed_inputs['height'] = height + processed_inputs['width'] = width + + return processed_inputs + + def handle_results(self, job_params, result): + """ + Handle the results of the Original generation. 
+ + Args: + job_params: The job parameters + result: The generation result + + Returns: + Processed result + """ + # For Original generation, we just return the result as-is + return result diff --git a/modules/pipelines/original_with_endframe_pipeline.py b/modules/pipelines/original_with_endframe_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..0be2128bff9cf8428b5f72c059504bfdd35b5843 --- /dev/null +++ b/modules/pipelines/original_with_endframe_pipeline.py @@ -0,0 +1,157 @@ +""" +Original with Endframe pipeline class for FramePack Studio. +This pipeline handles the "Original with Endframe" model type. +""" + +import os +import time +import json +import numpy as np +from PIL import Image +from PIL.PngImagePlugin import PngInfo +from diffusers_helper.utils import resize_and_center_crop +from diffusers_helper.bucket_tools import find_nearest_bucket +from .base_pipeline import BasePipeline + +class OriginalWithEndframePipeline(BasePipeline): + """Pipeline for Original with Endframe generation type.""" + + def prepare_parameters(self, job_params): + """ + Prepare parameters for the Original with Endframe generation job. + + Args: + job_params: Dictionary of job parameters + + Returns: + Processed parameters dictionary + """ + processed_params = job_params.copy() + + # Ensure we have the correct model type + processed_params['model_type'] = "Original with Endframe" + + return processed_params + + def validate_parameters(self, job_params): + """ + Validate parameters for the Original with Endframe generation job. + + Args: + job_params: Dictionary of job parameters + + Returns: + Tuple of (is_valid, error_message) + """ + # Check for required parameters + required_params = ['prompt_text', 'seed', 'total_second_length', 'steps'] + for param in required_params: + if param not in job_params: + return False, f"Missing required parameter: {param}" + + # Validate numeric parameters + if job_params.get('total_second_length', 0) <= 0: + return False, "Video length must be greater than 0" + + if job_params.get('steps', 0) <= 0: + return False, "Steps must be greater than 0" + + # Validate end frame parameters + if job_params.get('end_frame_strength', 0) < 0 or job_params.get('end_frame_strength', 0) > 1: + return False, "End frame strength must be between 0 and 1" + + return True, None + + def preprocess_inputs(self, job_params): + """ + Preprocess input images for the Original with Endframe generation type. 
+ + Args: + job_params: Dictionary of job parameters + + Returns: + Processed inputs dictionary + """ + processed_inputs = {} + + # Process input image if provided + input_image = job_params.get('input_image') + if input_image is not None: + # Get resolution parameters + resolutionW = job_params.get('resolutionW', 640) + resolutionH = job_params.get('resolutionH', 640) + + # Find nearest bucket size + if job_params.get('has_input_image', True): + # If we have an input image, use its dimensions to find the nearest bucket + H, W, _ = input_image.shape + height, width = find_nearest_bucket(H, W, resolution=resolutionW) + else: + # Otherwise, use the provided resolution parameters + height, width = find_nearest_bucket(resolutionH, resolutionW, (resolutionW+resolutionH)/2) + + # Resize and center crop the input image + input_image_np = resize_and_center_crop(input_image, target_width=width, target_height=height) + + # Store the processed image and dimensions + processed_inputs['input_image'] = input_image_np + processed_inputs['height'] = height + processed_inputs['width'] = width + else: + # If no input image, create a blank image based on latent_type + resolutionW = job_params.get('resolutionW', 640) + resolutionH = job_params.get('resolutionH', 640) + height, width = find_nearest_bucket(resolutionH, resolutionW, (resolutionW+resolutionH)/2) + + latent_type = job_params.get('latent_type', 'Black') + if latent_type == "White": + # Create a white image + input_image_np = np.ones((height, width, 3), dtype=np.uint8) * 255 + elif latent_type == "Noise": + # Create a noise image + input_image_np = np.random.randint(0, 256, (height, width, 3), dtype=np.uint8) + elif latent_type == "Green Screen": + # Create a green screen image with standard chroma key green (0, 177, 64) + input_image_np = np.zeros((height, width, 3), dtype=np.uint8) + input_image_np[:, :, 1] = 177 # Green channel + input_image_np[:, :, 2] = 64 # Blue channel + # Red channel remains 0 + else: # Default to "Black" or any other value + # Create a black image + input_image_np = np.zeros((height, width, 3), dtype=np.uint8) + + # Store the processed image and dimensions + processed_inputs['input_image'] = input_image_np + processed_inputs['height'] = height + processed_inputs['width'] = width + + # Process end frame image if provided + end_frame_image = job_params.get('end_frame_image') + if end_frame_image is not None: + # Use the same dimensions as the input image + height = processed_inputs['height'] + width = processed_inputs['width'] + + # Resize and center crop the end frame image + end_frame_np = resize_and_center_crop(end_frame_image, target_width=width, target_height=height) + + # Store the processed end frame image + processed_inputs['end_frame_image'] = end_frame_np + + return processed_inputs + + def handle_results(self, job_params, result): + """ + Handle the results of the Original with Endframe generation. + + Args: + job_params: The job parameters + result: The generation result + + Returns: + Processed result + """ + # For Original with Endframe generation, we just return the result as-is + return result + + # Using the centralized create_metadata method from BasePipeline diff --git a/modules/pipelines/video_f1_pipeline.py b/modules/pipelines/video_f1_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..9c3a8c27e84e5a37e28bc2f90e63fcaccc510813 --- /dev/null +++ b/modules/pipelines/video_f1_pipeline.py @@ -0,0 +1,143 @@ +""" +Video F1 pipeline class for FramePack Studio. 
+This pipeline handles the "Video F1" model type. +""" + +import os +import time +import json +import numpy as np +from PIL import Image +from PIL.PngImagePlugin import PngInfo +from diffusers_helper.utils import resize_and_center_crop +from diffusers_helper.bucket_tools import find_nearest_bucket +from .base_pipeline import BasePipeline + +class VideoF1Pipeline(BasePipeline): + """Pipeline for Video F1 generation type.""" + + def prepare_parameters(self, job_params): + """ + Prepare parameters for the Video generation job. + + Args: + job_params: Dictionary of job parameters + + Returns: + Processed parameters dictionary + """ + processed_params = job_params.copy() + + # Ensure we have the correct model type + processed_params['model_type'] = "Video F1" + + return processed_params + + def validate_parameters(self, job_params): + """ + Validate parameters for the Video generation job. + + Args: + job_params: Dictionary of job parameters + + Returns: + Tuple of (is_valid, error_message) + """ + # Check for required parameters + required_params = ['prompt_text', 'seed', 'total_second_length', 'steps'] + for param in required_params: + if param not in job_params: + return False, f"Missing required parameter: {param}" + + # Validate numeric parameters + if job_params.get('total_second_length', 0) <= 0: + return False, "Video length must be greater than 0" + + if job_params.get('steps', 0) <= 0: + return False, "Steps must be greater than 0" + + # Check for input video (stored in input_image for Video F1 model) + if not job_params.get('input_image'): + return False, "Input video is required for Video F1 model" + + # Check if combine_with_source is provided (optional) + combine_with_source = job_params.get('combine_with_source') + if combine_with_source is not None and not isinstance(combine_with_source, bool): + return False, "combine_with_source must be a boolean value" + + return True, None + + def preprocess_inputs(self, job_params): + """ + Preprocess input video for the Video F1 generation type. + + Args: + job_params: Dictionary of job parameters + + Returns: + Processed inputs dictionary + """ + processed_inputs = {} + + # Get the input video (stored in input_image for Video F1 model) + input_video = job_params.get('input_image') + if not input_video: + raise ValueError("Input video is required for Video F1 model") + + # Store the input video + processed_inputs['input_video'] = input_video + + # Note: The following code will be executed in the worker function: + # 1. The worker will call video_encode on the generator to get video_latents and input_video_pixels + # 2. Then it will store these values for later use: + # input_video_pixels = input_video_pixels.cpu() + # video_latents = video_latents.cpu() + # + # 3. If the generator has the set_full_video_latents method, it will store the video latents: + # if hasattr(current_generator, 'set_full_video_latents'): + # current_generator.set_full_video_latents(video_latents.clone()) + # print(f"Stored full input video latents in VideoModelGenerator. Shape: {video_latents.shape}") + # + # 4. For the Video model, history_latents is initialized with the video_latents: + # history_latents = video_latents + # print(f"Initialized history_latents with video context. 
Shape: {history_latents.shape}") + processed_inputs['input_files_dir'] = job_params.get('input_files_dir') + + # Pass through the combine_with_source parameter if it exists + if 'combine_with_source' in job_params: + processed_inputs['combine_with_source'] = job_params.get('combine_with_source') + print(f"Video F1 pipeline: combine_with_source = {processed_inputs['combine_with_source']}") + + # Pass through the num_cleaned_frames parameter if it exists + if 'num_cleaned_frames' in job_params: + processed_inputs['num_cleaned_frames'] = job_params.get('num_cleaned_frames') + print(f"Video F1 pipeline: num_cleaned_frames = {processed_inputs['num_cleaned_frames']}") + + # Get resolution parameters + resolutionW = job_params.get('resolutionW', 640) + resolutionH = job_params.get('resolutionH', 640) + + # Find nearest bucket size + height, width = find_nearest_bucket(resolutionH, resolutionW, (resolutionW+resolutionH)/2) + + # Store the dimensions + processed_inputs['height'] = height + processed_inputs['width'] = width + + return processed_inputs + + def handle_results(self, job_params, result): + """ + Handle the results of the Video F1 generation. + + Args: + job_params: The job parameters + result: The generation result + + Returns: + Processed result + """ + # For Video F1 generation, we just return the result as-is + return result + + # Using the centralized create_metadata method from BasePipeline diff --git a/modules/pipelines/video_pipeline.py b/modules/pipelines/video_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..7b033ec88936a6a7540495817f51128bbfc39bba --- /dev/null +++ b/modules/pipelines/video_pipeline.py @@ -0,0 +1,143 @@ +""" +Video pipeline class for FramePack Studio. +This pipeline handles the "Video" model type. +""" + +import os +import time +import json +import numpy as np +from PIL import Image +from PIL.PngImagePlugin import PngInfo +from diffusers_helper.utils import resize_and_center_crop +from diffusers_helper.bucket_tools import find_nearest_bucket +from .base_pipeline import BasePipeline + +class VideoPipeline(BasePipeline): + """Pipeline for Video generation type.""" + + def prepare_parameters(self, job_params): + """ + Prepare parameters for the Video generation job. + + Args: + job_params: Dictionary of job parameters + + Returns: + Processed parameters dictionary + """ + processed_params = job_params.copy() + + # Ensure we have the correct model type + processed_params['model_type'] = "Video" + + return processed_params + + def validate_parameters(self, job_params): + """ + Validate parameters for the Video generation job. 
+ + Args: + job_params: Dictionary of job parameters + + Returns: + Tuple of (is_valid, error_message) + """ + # Check for required parameters + required_params = ['prompt_text', 'seed', 'total_second_length', 'steps'] + for param in required_params: + if param not in job_params: + return False, f"Missing required parameter: {param}" + + # Validate numeric parameters + if job_params.get('total_second_length', 0) <= 0: + return False, "Video length must be greater than 0" + + if job_params.get('steps', 0) <= 0: + return False, "Steps must be greater than 0" + + # Check for input video (stored in input_image for Video model) + if not job_params.get('input_image'): + return False, "Input video is required for Video model" + + # Check if combine_with_source is provided (optional) + combine_with_source = job_params.get('combine_with_source') + if combine_with_source is not None and not isinstance(combine_with_source, bool): + return False, "combine_with_source must be a boolean value" + + return True, None + + def preprocess_inputs(self, job_params): + """ + Preprocess input video for the Video generation type. + + Args: + job_params: Dictionary of job parameters + + Returns: + Processed inputs dictionary + """ + processed_inputs = {} + + # Get the input video (stored in input_image for Video model) + input_video = job_params.get('input_image') + if not input_video: + raise ValueError("Input video is required for Video model") + + # Store the input video + processed_inputs['input_video'] = input_video + + # Note: The following code will be executed in the worker function: + # 1. The worker will call video_encode on the generator to get video_latents and input_video_pixels + # 2. Then it will store these values for later use: + # input_video_pixels = input_video_pixels.cpu() + # video_latents = video_latents.cpu() + # + # 3. If the generator has the set_full_video_latents method, it will store the video latents: + # if hasattr(current_generator, 'set_full_video_latents'): + # current_generator.set_full_video_latents(video_latents.clone()) + # print(f"Stored full input video latents in VideoModelGenerator. Shape: {video_latents.shape}") + # + # 4. For the Video model, history_latents is initialized with the video_latents: + # history_latents = video_latents + # print(f"Initialized history_latents with video context. Shape: {history_latents.shape}") + processed_inputs['input_files_dir'] = job_params.get('input_files_dir') + + # Pass through the combine_with_source parameter if it exists + if 'combine_with_source' in job_params: + processed_inputs['combine_with_source'] = job_params.get('combine_with_source') + print(f"Video pipeline: combine_with_source = {processed_inputs['combine_with_source']}") + + # Pass through the num_cleaned_frames parameter if it exists + if 'num_cleaned_frames' in job_params: + processed_inputs['num_cleaned_frames'] = job_params.get('num_cleaned_frames') + print(f"Video pipeline: num_cleaned_frames = {processed_inputs['num_cleaned_frames']}") + + # Get resolution parameters + resolutionW = job_params.get('resolutionW', 640) + resolutionH = job_params.get('resolutionH', 640) + + # Find nearest bucket size + height, width = find_nearest_bucket(resolutionH, resolutionW, (resolutionW+resolutionH)/2) + + # Store the dimensions + processed_inputs['height'] = height + processed_inputs['width'] = width + + return processed_inputs + + def handle_results(self, job_params, result): + """ + Handle the results of the Video generation. 
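All of these pipelines return a (is_valid, error_message) tuple from validate_parameters, and the worker below raises ValueError when validation fails. A minimal sketch of that contract, with validate_job as a hypothetical stand-in for any of the pipeline implementations:

from typing import Optional, Tuple

def validate_job(job_params: dict) -> Tuple[bool, Optional[str]]:
    # Hypothetical stand-in mirroring the pipelines' (is_valid, error_message) contract.
    for param in ('prompt_text', 'seed', 'total_second_length', 'steps'):
        if param not in job_params:
            return False, f"Missing required parameter: {param}"
    if job_params.get('total_second_length', 0) <= 0:
        return False, "Video length must be greater than 0"
    return True, None

# Caller side, as the worker does it:
is_valid, error_message = validate_job({'prompt_text': 'a cat', 'seed': 1, 'total_second_length': 5, 'steps': 25})
if not is_valid:
    raise ValueError(f"Invalid parameters: {error_message}")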
+ + Args: + job_params: The job parameters + result: The generation result + + Returns: + Processed result + """ + # For Video generation, we just return the result as-is + return result + + # Using the centralized create_metadata method from BasePipeline diff --git a/modules/pipelines/video_tools.py b/modules/pipelines/video_tools.py new file mode 100644 index 0000000000000000000000000000000000000000..dae2f4c91b9a102908be8b54e5e42cc93d0c7719 --- /dev/null +++ b/modules/pipelines/video_tools.py @@ -0,0 +1,57 @@ +import torch +import numpy as np +import traceback + +from diffusers_helper.utils import save_bcthw_as_mp4 + +@torch.no_grad() +def combine_videos_sequentially_from_tensors(processed_input_frames_np, + generated_frames_pt, + output_path, + target_fps, + crf_value): + """ + Combines processed input frames (NumPy) with generated frames (PyTorch Tensor) sequentially + and saves the result as an MP4 video using save_bcthw_as_mp4. + + Args: + processed_input_frames_np: NumPy array of processed input frames (T_in, H, W_in, C), uint8. + generated_frames_pt: PyTorch tensor of generated frames (B_gen, C_gen, T_gen, H, W_gen), float32 [-1,1]. + (This will be history_pixels from worker.py) + output_path: Path to save the combined video. + target_fps: FPS for the output combined video. + crf_value: CRF value for video encoding. + + Returns: + Path to the combined video, or None if an error occurs. + """ + try: + # 1. Convert processed_input_frames_np to PyTorch tensor BCTHW, float32, [-1,1] + # processed_input_frames_np shape: (T_in, H, W_in, C) + input_frames_pt = torch.from_numpy(processed_input_frames_np).float() / 127.5 - 1.0 # (T,H,W,C) + input_frames_pt = input_frames_pt.permute(3, 0, 1, 2) # (C,T,H,W) + input_frames_pt = input_frames_pt.unsqueeze(0) # (1,C,T,H,W) -> BCTHW + + # Ensure generated_frames_pt is on the same device and dtype for concatenation + input_frames_pt = input_frames_pt.to(device=generated_frames_pt.device, dtype=generated_frames_pt.dtype) + + # 2. Dimension Check (Heights and Widths should match) + # They should match, since the input frames should have been processed to match the generation resolution. + # But sanity check to ensure no mismatch occurs when the code is refactored. + if input_frames_pt.shape[3:] != generated_frames_pt.shape[3:]: # Compare (H,W) + print(f"Warning: Dimension mismatch for sequential combination! Input: {input_frames_pt.shape[3:]}, Generated: {generated_frames_pt.shape[3:]}.") + print("Attempting to proceed, but this might lead to errors or unexpected video output.") + # Potentially add resizing logic here if necessary, but for now, assume they match + + # 3. Concatenate Tensors along the time dimension (dim=2 for BCTHW) + combined_video_pt = torch.cat([input_frames_pt, generated_frames_pt], dim=2) + + # 4. 
Save + save_bcthw_as_mp4(combined_video_pt, output_path, fps=target_fps, crf=crf_value) + print(f"Sequentially combined video (from tensors) saved to {output_path}") + return output_path + except Exception as e: + print(f"Error in combine_videos_sequentially_from_tensors: {str(e)}") + import traceback + traceback.print_exc() + return None \ No newline at end of file diff --git a/modules/pipelines/worker.py b/modules/pipelines/worker.py new file mode 100644 index 0000000000000000000000000000000000000000..2e1c2c11e7d4ab244bf89481d5ee2474957e49cc --- /dev/null +++ b/modules/pipelines/worker.py @@ -0,0 +1,1150 @@ +import os +import json +import time +import traceback +import einops +import numpy as np +import torch +import datetime +from PIL import Image +from PIL.PngImagePlugin import PngInfo +from diffusers_helper.models.mag_cache import MagCache +from diffusers_helper.utils import save_bcthw_as_mp4, generate_timestamp, resize_and_center_crop +from diffusers_helper.memory import cpu, gpu, move_model_to_device_with_memory_preservation, offload_model_from_device_for_memory_preservation, fake_diffusers_current_device, unload_complete_models, load_model_as_complete +from diffusers_helper.thread_utils import AsyncStream +from diffusers_helper.gradio.progress_bar import make_progress_bar_html +from diffusers_helper.hunyuan import vae_decode +from modules.video_queue import JobStatus +from modules.prompt_handler import parse_timestamped_prompt +from modules.generators import create_model_generator +from modules.pipelines.video_tools import combine_videos_sequentially_from_tensors +from modules import DUMMY_LORA_NAME # Import the constant +from modules.llm_captioner import unload_captioning_model +from modules.llm_enhancer import unload_enhancing_model +from . import create_pipeline + +import __main__ as studio_module # Get a reference to the __main__ module object + +@torch.no_grad() +def get_cached_or_encode_prompt(prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2, target_device, prompt_embedding_cache): + """ + Retrieves prompt embeddings from cache or encodes them if not found. + Stores encoded embeddings (on CPU) in the cache. + Returns embeddings moved to the target_device. 
+ """ + from diffusers_helper.hunyuan import encode_prompt_conds, crop_or_pad_yield_mask + + if prompt in prompt_embedding_cache: + print(f"Cache hit for prompt: {prompt[:60]}...") + llama_vec_cpu, llama_mask_cpu, clip_l_pooler_cpu = prompt_embedding_cache[prompt] + # Move cached embeddings (from CPU) to the target device + llama_vec = llama_vec_cpu.to(target_device) + llama_attention_mask = llama_mask_cpu.to(target_device) if llama_mask_cpu is not None else None + clip_l_pooler = clip_l_pooler_cpu.to(target_device) + return llama_vec, llama_attention_mask, clip_l_pooler + else: + print(f"Cache miss for prompt: {prompt[:60]}...") + llama_vec, clip_l_pooler = encode_prompt_conds( + prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2 + ) + llama_vec, llama_attention_mask = crop_or_pad_yield_mask(llama_vec, length=512) + # Store CPU copies in cache + prompt_embedding_cache[prompt] = (llama_vec.cpu(), llama_attention_mask.cpu() if llama_attention_mask is not None else None, clip_l_pooler.cpu()) + # Return embeddings already on the target device (as encode_prompt_conds uses the model's device) + return llama_vec, llama_attention_mask, clip_l_pooler + +@torch.no_grad() +def worker( + model_type, + input_image, + end_frame_image, # The end frame image (numpy array or None) + end_frame_strength, # Influence of the end frame + prompt_text, + n_prompt, + seed, + total_second_length, + latent_window_size, + steps, + cfg, + gs, + rs, + use_teacache, + teacache_num_steps, + teacache_rel_l1_thresh, + use_magcache, + magcache_threshold, + magcache_max_consecutive_skips, + magcache_retention_ratio, + blend_sections, + latent_type, + selected_loras, + has_input_image, + lora_values=None, + job_stream=None, + output_dir=None, + metadata_dir=None, + input_files_dir=None, # Add input_files_dir parameter + input_image_path=None, # Add input_image_path parameter + end_frame_image_path=None, # Add end_frame_image_path parameter + resolutionW=640, # Add resolution parameter with default value + resolutionH=640, + lora_loaded_names=[], + input_video=None, # Add input_video parameter with default value of None + combine_with_source=None, # Add combine_with_source parameter + num_cleaned_frames=5, # Add num_cleaned_frames parameter with default value + save_metadata_checked=True # Add save_metadata_checked parameter +): + """ + Worker function for video generation. 
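The caching pattern in get_cached_or_encode_prompt above (store CPU copies, move them to the compute device on a hit) can be illustrated with a device-agnostic sketch. Here encode_fn stands in for the real text-encoder call and the tensor shapes are dummies; this is an assumption-laden illustration, not the project's encoder API, and it also simplifies away the None attention mask case handled above.

import torch

def cached_encode(prompt, encode_fn, cache, target_device):
    # Cache hit: tensors were stored on CPU, move copies to the compute device.
    if prompt in cache:
        return tuple(t.to(target_device) for t in cache[prompt])
    # Cache miss: encode once, keep CPU copies so the cache does not hold VRAM.
    embeddings = encode_fn(prompt)  # e.g. (llama_vec, attention_mask, clip_l_pooler)
    cache[prompt] = tuple(t.cpu() for t in embeddings)
    return tuple(t.to(target_device) for t in embeddings)

# Usage with a dummy encoder
cache = {}
dummy_encode = lambda p: (torch.randn(1, 512, 4096), torch.ones(1, 512, dtype=torch.bool), torch.randn(1, 768))
vec, mask, pooler = cached_encode("a cat", dummy_encode, cache, torch.device("cpu"))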
+ """ + + random_generator = torch.Generator("cpu").manual_seed(seed) + + unload_enhancing_model() + unload_captioning_model() + + # Filter out the dummy LoRA from selected_loras at the very beginning of the worker + actual_selected_loras_for_worker = [] + if isinstance(selected_loras, list): + actual_selected_loras_for_worker = [lora for lora in selected_loras if lora != DUMMY_LORA_NAME] + if DUMMY_LORA_NAME in selected_loras and DUMMY_LORA_NAME in actual_selected_loras_for_worker: # Should not happen if filter works + print(f"Worker.py: Error - '{DUMMY_LORA_NAME}' was selected but not filtered out.") + elif DUMMY_LORA_NAME in selected_loras: + print(f"Worker.py: Filtered out '{DUMMY_LORA_NAME}' from selected LoRAs.") + elif selected_loras is not None: # If it's a single string (should not happen with multiselect dropdown) + if selected_loras != DUMMY_LORA_NAME: + actual_selected_loras_for_worker = [selected_loras] + selected_loras = actual_selected_loras_for_worker + print(f"Worker: Selected LoRAs for this worker: {selected_loras}") + + # Import globals from the main module + from __main__ import high_vram, args, text_encoder, text_encoder_2, tokenizer, tokenizer_2, vae, image_encoder, feature_extractor, prompt_embedding_cache, settings, stream + + # Ensure any existing LoRAs are unloaded from the current generator + if studio_module.current_generator is not None: + print("Worker: Unloading LoRAs from studio_module.current_generator") + studio_module.current_generator.unload_loras() + import gc + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + stream_to_use = job_stream if job_stream is not None else stream + + total_latent_sections = (total_second_length * 30) / (latent_window_size * 4) + total_latent_sections = int(max(round(total_latent_sections), 1)) + + # --- Total progress tracking --- + total_steps = total_latent_sections * steps # Total diffusion steps over all segments + step_durations = [] # Rolling history of recent step durations for ETA + last_step_time = time.time() + + # Parse the timestamped prompt with boundary snapping and reversing + # prompt_text should now be the original string from the job queue + prompt_sections = parse_timestamped_prompt(prompt_text, total_second_length, latent_window_size, model_type) + job_id = generate_timestamp() + + # Initialize progress data with a clear starting message and dummy preview + dummy_preview = np.zeros((64, 64, 3), dtype=np.uint8) + initial_progress_data = { + 'preview': dummy_preview, + 'desc': 'Starting job...', + 'html': make_progress_bar_html(0, 'Starting job...') + } + + # Store initial progress data in the job object if using a job stream + if job_stream is not None: + try: + from __main__ import job_queue + job = job_queue.get_job(job_id) + if job: + job.progress_data = initial_progress_data + except Exception as e: + print(f"Error storing initial progress data: {e}") + + # Push initial progress update to both streams + stream_to_use.output_queue.push(('progress', (dummy_preview, 'Starting job...', make_progress_bar_html(0, 'Starting job...')))) + + # Push job ID to stream to ensure monitoring connection + stream_to_use.output_queue.push(('job_id', job_id)) + stream_to_use.output_queue.push(('monitor_job', job_id)) + + # Always push to the main stream to ensure the UI is updated + from __main__ import stream as main_stream + if main_stream: # Always push to main stream regardless of whether it's the same as stream_to_use + print(f"Pushing initial progress update to main stream for job 
{job_id}") + main_stream.output_queue.push(('progress', (dummy_preview, 'Starting job...', make_progress_bar_html(0, 'Starting job...')))) + + # Push job ID to main stream to ensure monitoring connection + main_stream.output_queue.push(('job_id', job_id)) + main_stream.output_queue.push(('monitor_job', job_id)) + + try: + # Create a settings dictionary for the pipeline + pipeline_settings = { + "output_dir": output_dir, + "metadata_dir": metadata_dir, + "input_files_dir": input_files_dir, + "save_metadata": settings.get("save_metadata", True), + "gpu_memory_preservation": settings.get("gpu_memory_preservation", 6), + "mp4_crf": settings.get("mp4_crf", 16), + "clean_up_videos": settings.get("clean_up_videos", True), + "gradio_temp_dir": settings.get("gradio_temp_dir", "./gradio_temp"), + "high_vram": high_vram + } + + # Create the appropriate pipeline for the model type + pipeline = create_pipeline(model_type, pipeline_settings) + + # Create job parameters dictionary + job_params = { + 'model_type': model_type, + 'input_image': input_image, + 'end_frame_image': end_frame_image, + 'end_frame_strength': end_frame_strength, + 'prompt_text': prompt_text, + 'n_prompt': n_prompt, + 'seed': seed, + 'total_second_length': total_second_length, + 'latent_window_size': latent_window_size, + 'steps': steps, + 'cfg': cfg, + 'gs': gs, + 'rs': rs, + 'blend_sections': blend_sections, + 'latent_type': latent_type, + 'use_teacache': use_teacache, + 'teacache_num_steps': teacache_num_steps, + 'teacache_rel_l1_thresh': teacache_rel_l1_thresh, + 'use_magcache': use_magcache, + 'magcache_threshold': magcache_threshold, + 'magcache_max_consecutive_skips': magcache_max_consecutive_skips, + 'magcache_retention_ratio': magcache_retention_ratio, + 'selected_loras': selected_loras, + 'has_input_image': has_input_image, + 'lora_values': lora_values, + 'resolutionW': resolutionW, + 'resolutionH': resolutionH, + 'lora_loaded_names': lora_loaded_names, + 'input_image_path': input_image_path, + 'end_frame_image_path': end_frame_image_path, + 'combine_with_source': combine_with_source, + 'num_cleaned_frames': num_cleaned_frames, + 'save_metadata_checked': save_metadata_checked # Ensure it's in job_params for internal use + } + + # Validate parameters + is_valid, error_message = pipeline.validate_parameters(job_params) + if not is_valid: + raise ValueError(f"Invalid parameters: {error_message}") + + # Prepare parameters + job_params = pipeline.prepare_parameters(job_params) + + if not high_vram: + # Unload everything *except* the potentially active transformer + unload_complete_models(text_encoder, text_encoder_2, image_encoder, vae) + if studio_module.current_generator is not None and studio_module.current_generator.transformer is not None: + offload_model_from_device_for_memory_preservation(studio_module.current_generator.transformer, target_device=gpu, preserved_memory_gb=8) + + + # --- Model Loading / Switching --- + print(f"Worker starting for model type: {model_type}") + print(f"Worker: Before model assignment, studio_module.current_generator is {type(studio_module.current_generator)}, id: {id(studio_module.current_generator)}") + + # Create the appropriate model generator + new_generator = create_model_generator( + model_type, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + vae=vae, + image_encoder=image_encoder, + feature_extractor=feature_extractor, + high_vram=high_vram, + prompt_embedding_cache=prompt_embedding_cache, + offline=args.offline, + 
settings=settings + ) + + # Update the global generator + # This modifies the 'current_generator' attribute OF THE '__main__' MODULE OBJECT + studio_module.current_generator = new_generator + print(f"Worker: AFTER model assignment, studio_module.current_generator is {type(studio_module.current_generator)}, id: {id(studio_module.current_generator)}") + if studio_module.current_generator: + print(f"Worker: studio_module.current_generator.transformer is {type(studio_module.current_generator.transformer)}") + + # Load the transformer model + studio_module.current_generator.load_model() + + # Ensure the model has no LoRAs loaded + print(f"Ensuring {model_type} model has no LoRAs loaded") + studio_module.current_generator.unload_loras() + + # Preprocess inputs + stream_to_use.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Preprocessing inputs...')))) + processed_inputs = pipeline.preprocess_inputs(job_params) + + # Update job_params with processed inputs + job_params.update(processed_inputs) + + # Save the starting image directly to the output directory with full metadata + # Check both global settings and job-specific save_metadata_checked parameter + if settings.get("save_metadata") and job_params.get('save_metadata_checked', True) and job_params.get('input_image') is not None: + try: + # Import the save_job_start_image function from metadata_utils + from modules.pipelines.metadata_utils import save_job_start_image, create_metadata + + # Create comprehensive metadata for the job + metadata_dict = create_metadata(job_params, job_id, settings) + + # Save the starting image with metadata + save_job_start_image(job_params, job_id, settings) + + print(f"Saved metadata and starting image for job {job_id}") + except Exception as e: + print(f"Error saving starting image and metadata: {e}") + traceback.print_exc() + + # Pre-encode all prompts + stream_to_use.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Text encoding all prompts...')))) + + # THE FOLLOWING CODE SHOULD BE INSIDE THE TRY BLOCK + if not high_vram: + fake_diffusers_current_device(text_encoder, gpu) + load_model_as_complete(text_encoder_2, target_device=gpu) + + # PROMPT BLENDING: Pre-encode all prompts and store in a list in order + unique_prompts = [] + for section in prompt_sections: + if section.prompt not in unique_prompts: + unique_prompts.append(section.prompt) + + encoded_prompts = {} + for prompt in unique_prompts: + # Use the helper function for caching and encoding + llama_vec, llama_attention_mask, clip_l_pooler = get_cached_or_encode_prompt( + prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2, gpu, prompt_embedding_cache + ) + encoded_prompts[prompt] = (llama_vec, llama_attention_mask, clip_l_pooler) + + # PROMPT BLENDING: Build a list of (start_section_idx, prompt) for each prompt + prompt_change_indices = [] + last_prompt = None + for idx, section in enumerate(prompt_sections): + if section.prompt != last_prompt: + prompt_change_indices.append((idx, section.prompt)) + last_prompt = section.prompt + + # Encode negative prompt + if cfg == 1: + llama_vec_n, llama_attention_mask_n, clip_l_pooler_n = ( + torch.zeros_like(encoded_prompts[prompt_sections[0].prompt][0]), + torch.zeros_like(encoded_prompts[prompt_sections[0].prompt][1]), + torch.zeros_like(encoded_prompts[prompt_sections[0].prompt][2]) + ) + else: + # Use the helper function for caching and encoding negative prompt + # Ensure n_prompt is a string + n_prompt_str = str(n_prompt) if n_prompt is not None else "" + 
llama_vec_n, llama_attention_mask_n, clip_l_pooler_n = get_cached_or_encode_prompt( + n_prompt_str, text_encoder, text_encoder_2, tokenizer, tokenizer_2, gpu, prompt_embedding_cache + ) + + end_of_input_video_embedding = None # Video model end frame CLIP Vision embedding + # Process input image or video based on model type + if model_type == "Video" or model_type == "Video F1": + stream_to_use.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Video processing ...')))) + + # Encode the video using the VideoModelGenerator + start_latent, input_image_np, video_latents, fps, height, width, input_video_pixels, end_of_input_video_image_np, input_frames_resized_np = studio_module.current_generator.video_encode( + video_path=job_params['input_image'], # For Video model, input_image contains the video path + resolution=job_params['resolutionW'], + no_resize=False, + vae_batch_size=16, + device=gpu, + input_files_dir=job_params['input_files_dir'] + ) + + if end_of_input_video_image_np is not None: + try: + from modules.pipelines.metadata_utils import save_last_video_frame + save_last_video_frame(job_params, job_id, settings, end_of_input_video_image_np) + except Exception as e: + print(f"Error saving last video frame: {e}") + traceback.print_exc() + + # RT_BORG: retained only until we make our final decisions on how to handle combining videos + # Only necessary to retain resized frames to produce a combined video with source frames of the right dimensions + #if combine_with_source: + # # Store input_frames_resized_np in job_params for later use + # job_params['input_frames_resized_np'] = input_frames_resized_np + + # CLIP Vision encoding for the first frame + stream_to_use.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'CLIP Vision encoding ...')))) + + if not high_vram: + load_model_as_complete(image_encoder, target_device=gpu) + + from diffusers_helper.clip_vision import hf_clip_vision_encode + image_encoder_output = hf_clip_vision_encode(input_image_np, feature_extractor, image_encoder) + image_encoder_last_hidden_state = image_encoder_output.last_hidden_state + + end_of_input_video_embedding = hf_clip_vision_encode(end_of_input_video_image_np, feature_extractor, image_encoder).last_hidden_state + + # Store the input video pixels and latents for later use + input_video_pixels = input_video_pixels.cpu() + video_latents = video_latents.cpu() + + # Store the full video latents in the generator instance for preparing clean latents + if hasattr(studio_module.current_generator, 'set_full_video_latents'): + studio_module.current_generator.set_full_video_latents(video_latents.clone()) + print(f"Stored full input video latents in VideoModelGenerator. Shape: {video_latents.shape}") + + # For Video model, history_latents is initialized with the video_latents + history_latents = video_latents + + # Store the last frame of the video latents as start_latent for the model + start_latent = video_latents[:, :, -1:].cpu() + print(f"Using last frame of input video as start_latent. Shape: {start_latent.shape}") + print(f"Placed last frame of video at position 0 in history_latents") + + print(f"Initialized history_latents with video context. 
Shape: {history_latents.shape}") + + # Store the number of frames in the input video for later use + input_video_frame_count = video_latents.shape[2] + else: + # Regular image processing + height = job_params['height'] + width = job_params['width'] + + if not has_input_image and job_params.get('latent_type') == 'Noise': + # print("************************************************") + # print("** Using 'Noise' latent type for T2V workflow **") + # print("************************************************") + + # Create a random latent to serve as the initial VAE context anchor. + # This provides a random starting point without visual bias. + start_latent = torch.randn( + (1, 16, 1, height // 8, width // 8), + generator=random_generator, device=random_generator.device + ).to(device=gpu, dtype=torch.float32) + + # Create a neutral black image to generate a valid "null" CLIP Vision embedding. + # This provides the model with a valid, in-distribution unconditional image prompt. + # RT_BORG: Clip doesn't understand noise at all. I also tried using + # image_encoder_last_hidden_state = torch.zeros((1, 257, 1152), device=gpu, dtype=studio_module.current_generator.transformer.dtype) + # to represent a "null" CLIP Vision embedding in the shape for the CLIP encoder, + # but the Video model wasn't trained to handle zeros, so using a neutral black image for CLIP. + + black_image_np = np.zeros((height, width, 3), dtype=np.uint8) + + if not high_vram: + load_model_as_complete(image_encoder, target_device=gpu) + + from diffusers_helper.clip_vision import hf_clip_vision_encode + image_encoder_output = hf_clip_vision_encode(black_image_np, feature_extractor, image_encoder) + image_encoder_last_hidden_state = image_encoder_output.last_hidden_state + + else: + input_image_np = job_params['input_image'] + + input_image_pt = torch.from_numpy(input_image_np).float() / 127.5 - 1 + input_image_pt = input_image_pt.permute(2, 0, 1)[None, :, None] + + # Start image encoding with VAE + stream_to_use.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'VAE encoding ...')))) + + if not high_vram: + load_model_as_complete(vae, target_device=gpu) + + from diffusers_helper.hunyuan import vae_encode + start_latent = vae_encode(input_image_pt, vae) + + # CLIP Vision + stream_to_use.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'CLIP Vision encoding ...')))) + + if not high_vram: + load_model_as_complete(image_encoder, target_device=gpu) + + from diffusers_helper.clip_vision import hf_clip_vision_encode + image_encoder_output = hf_clip_vision_encode(input_image_np, feature_extractor, image_encoder) + image_encoder_last_hidden_state = image_encoder_output.last_hidden_state + + # VAE encode end_frame_image if provided + end_frame_latent = None + # VAE encode end_frame_image resized to output dimensions, if provided + end_frame_output_dimensions_latent = None + end_clip_embedding = None # Video model end frame CLIP Vision embedding + + # Models with end_frame_image processing + if (model_type == "Original with Endframe" or model_type == "Video") and job_params.get('end_frame_image') is not None: + print(f"Processing end frame for {model_type} model...") + end_frame_image = job_params['end_frame_image'] + + if not isinstance(end_frame_image, np.ndarray): + print(f"Warning: end_frame_image is not a numpy array (type: {type(end_frame_image)}). 
Attempting conversion or skipping.") + try: + end_frame_image = np.array(end_frame_image) + except Exception as e_conv: + print(f"Could not convert end_frame_image to numpy array: {e_conv}. Skipping end frame.") + end_frame_image = None + + if end_frame_image is not None: + # Use the main job's target width/height (bucket dimensions) for the end frame + end_frame_np = job_params['end_frame_image'] + + if settings.get("save_metadata"): + Image.fromarray(end_frame_np).save(os.path.join(metadata_dir, f'{job_id}_end_frame_processed.png')) + + end_frame_pt = torch.from_numpy(end_frame_np).float() / 127.5 - 1 + end_frame_pt = end_frame_pt.permute(2, 0, 1)[None, :, None] # VAE expects [B, C, F, H, W] + + if not high_vram: load_model_as_complete(vae, target_device=gpu) # Ensure VAE is loaded + from diffusers_helper.hunyuan import vae_encode + end_frame_latent = vae_encode(end_frame_pt, vae) + + # end_frame_output_dimensions_latent is sized like the start_latent and generated latents + end_frame_output_dimensions_np = resize_and_center_crop(end_frame_np, width, height) + end_frame_output_dimensions_pt = torch.from_numpy(end_frame_output_dimensions_np).float() / 127.5 - 1 + end_frame_output_dimensions_pt = end_frame_output_dimensions_pt.permute(2, 0, 1)[None, :, None] # VAE expects [B, C, F, H, W] + end_frame_output_dimensions_latent = vae_encode(end_frame_output_dimensions_pt, vae) + + print("End frame VAE encoded.") + + # Video Mode CLIP Vision encoding for end frame + if model_type == "Video": + if not high_vram: # Ensure image_encoder is on GPU for this operation + load_model_as_complete(image_encoder, target_device=gpu) + from diffusers_helper.clip_vision import hf_clip_vision_encode + end_clip_embedding = hf_clip_vision_encode(end_frame_np, feature_extractor, image_encoder).last_hidden_state + end_clip_embedding = end_clip_embedding.to(studio_module.current_generator.transformer.dtype) + # Need that dtype conversion for end_clip_embedding? I don't think so, but it was in the original PR. 
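The start-frame and end-frame preprocessing above repeatedly maps a uint8 HWC image into the [B, C, F, H, W] float tensor in [-1, 1] that vae_encode expects. A layout-only sketch of that conversion and its inverse, independent of the actual VAE call; the helper names are illustrative:

import numpy as np
import torch

def image_to_vae_input(image_np: np.ndarray) -> torch.Tensor:
    # uint8 (H, W, C) -> float32 (1, C, 1, H, W) scaled to [-1, 1]
    t = torch.from_numpy(image_np).float() / 127.5 - 1.0
    return t.permute(2, 0, 1)[None, :, None]

def vae_input_to_image(t: torch.Tensor) -> np.ndarray:
    # Inverse mapping back to a uint8 (H, W, C) image
    t = (t[0, :, 0].permute(1, 2, 0) + 1.0) * 127.5
    return t.clamp(0, 255).byte().cpu().numpy()

frame = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)
assert image_to_vae_input(frame).shape == (1, 3, 1, 480, 640)
assert vae_input_to_image(image_to_vae_input(frame)).shape == frame.shape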
+ + if not high_vram: # Offload VAE and image_encoder if they were loaded + offload_model_from_device_for_memory_preservation(vae, target_device=gpu, preserved_memory_gb=settings.get("gpu_memory_preservation")) + offload_model_from_device_for_memory_preservation(image_encoder, target_device=gpu, preserved_memory_gb=settings.get("gpu_memory_preservation")) + + # Dtype + for prompt_key in encoded_prompts: + llama_vec, llama_attention_mask, clip_l_pooler = encoded_prompts[prompt_key] + llama_vec = llama_vec.to(studio_module.current_generator.transformer.dtype) + clip_l_pooler = clip_l_pooler.to(studio_module.current_generator.transformer.dtype) + encoded_prompts[prompt_key] = (llama_vec, llama_attention_mask, clip_l_pooler) + + llama_vec_n = llama_vec_n.to(studio_module.current_generator.transformer.dtype) + clip_l_pooler_n = clip_l_pooler_n.to(studio_module.current_generator.transformer.dtype) + image_encoder_last_hidden_state = image_encoder_last_hidden_state.to(studio_module.current_generator.transformer.dtype) + + # Sampling + stream_to_use.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Start sampling ...')))) + + num_frames = latent_window_size * 4 - 3 + + # Initialize total_generated_latent_frames for Video model + total_generated_latent_frames = 0 # Default initialization for all model types + + # Initialize history latents based on model type + if model_type != "Video" and model_type != "Video F1": # Skip for Video models as we already initialized it + history_latents = studio_module.current_generator.prepare_history_latents(height, width) + + # For F1 model, initialize with start latent + if model_type == "F1": + history_latents = studio_module.current_generator.initialize_with_start_latent(history_latents, start_latent, has_input_image) + # If we had a real start image, it was just added to the history_latents + total_generated_latent_frames = 1 if has_input_image else 0 + elif model_type == "Original" or model_type == "Original with Endframe": + total_generated_latent_frames = 0 + + history_pixels = None + + # Get latent paddings from the generator + latent_paddings = studio_module.current_generator.get_latent_paddings(total_latent_sections) + + # PROMPT BLENDING: Track section index + section_idx = 0 + + # Load LoRAs if selected + if selected_loras: + lora_folder_from_settings = settings.get("lora_dir") + studio_module.current_generator.load_loras(selected_loras, lora_folder_from_settings, lora_loaded_names, lora_values) + + # --- Callback for progress --- + def callback(d): + nonlocal last_step_time, step_durations + + # Check for cancellation signal + if stream_to_use.input_queue.top() == 'end': + print("Cancellation signal detected in callback") + return 'cancel' # Return a signal that will be checked in the sampler + + now_time = time.time() + # Record duration between diffusion steps (skip first where duration may include setup) + if last_step_time is not None: + step_delta = now_time - last_step_time + if step_delta > 0: + step_durations.append(step_delta) + if len(step_durations) > 30: # Keep only recent 30 steps + step_durations.pop(0) + last_step_time = now_time + avg_step = sum(step_durations) / len(step_durations) if step_durations else 0.0 + + preview = d['denoised'] + from diffusers_helper.hunyuan import vae_decode_fake + preview = vae_decode_fake(preview) + preview = (preview * 255.0).detach().cpu().numpy().clip(0, 255).astype(np.uint8) + preview = einops.rearrange(preview, 'b c t h w -> (b h) (t w) c') + + # --- Progress & ETA logic --- + # 
Current segment progress + current_step = d['i'] + 1 + percentage = int(100.0 * current_step / steps) + + # Total progress + total_steps_done = section_idx * steps + current_step + total_percentage = int(100.0 * total_steps_done / total_steps) + + # ETA calculations + def fmt_eta(sec): + try: + return str(datetime.timedelta(seconds=int(sec))) + except Exception: + return "--:--" + + segment_eta = (steps - current_step) * avg_step if avg_step else 0 + total_eta = (total_steps - total_steps_done) * avg_step if avg_step else 0 + + segment_hint = f'Sampling {current_step}/{steps} ETA {fmt_eta(segment_eta)}' + total_hint = f'Total {total_steps_done}/{total_steps} ETA {fmt_eta(total_eta)}' + + # For Video model, add the input video frame count when calculating current position + if model_type == "Video": + # Calculate the time position including the input video frames + input_video_time = input_video_frame_count * 4 / 30 # Convert latent frames to time + current_pos = input_video_time + (total_generated_latent_frames * 4 - 3) / 30 + # Original position is the remaining time to generate + original_pos = total_second_length - (total_generated_latent_frames * 4 - 3) / 30 + else: + # For other models, calculate as before + current_pos = (total_generated_latent_frames * 4 - 3) / 30 + original_pos = total_second_length - current_pos + + # Ensure positions are not negative + if current_pos < 0: current_pos = 0 + if original_pos < 0: original_pos = 0 + + hint = segment_hint # deprecated variable kept to minimise other code changes + desc = studio_module.current_generator.format_position_description( + total_generated_latent_frames, + current_pos, + original_pos, + current_prompt + ) + + # Create progress data dictionary + progress_data = { + 'preview': preview, + 'desc': desc, + 'html': make_progress_bar_html(percentage, segment_hint) + make_progress_bar_html(total_percentage, total_hint) + } + + # Store progress data in the job object if using a job stream + if job_stream is not None: + try: + from __main__ import job_queue + job = job_queue.get_job(job_id) + if job: + job.progress_data = progress_data + except Exception as e: + print(f"Error updating job progress data: {e}") + + # Always push to the job-specific stream + stream_to_use.output_queue.push(('progress', (preview, desc, make_progress_bar_html(percentage, segment_hint) + make_progress_bar_html(total_percentage, total_hint)))) + + # Always push to the main stream to ensure the UI is updated + # This is especially important for resumed jobs + from __main__ import stream as main_stream + if main_stream: # Always push to main stream regardless of whether it's the same as stream_to_use + main_stream.output_queue.push(('progress', (preview, desc, make_progress_bar_html(percentage, segment_hint) + make_progress_bar_html(total_percentage, total_hint)))) + + # Also push job ID to main stream to ensure monitoring connection + if main_stream: + main_stream.output_queue.push(('job_id', job_id)) + main_stream.output_queue.push(('monitor_job', job_id)) + + # MagCache / TeaCache Initialization Logic + magcache = None + # RT_BORG: I cringe at this, but refactoring to introduce an actual model class will fix it. 
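The progress callback above keeps a rolling window of recent step durations and extrapolates both the per-segment and total ETA from the average. A condensed, self-contained sketch of the same arithmetic (update_eta is an illustrative helper, not the callback itself):

import datetime

def update_eta(step_durations, step_delta, current_step, steps, total_steps_done, total_steps, window=30):
    # Rolling-average ETA over the last `window` diffusion steps, as in the callback above.
    if step_delta > 0:
        step_durations.append(step_delta)
        if len(step_durations) > window:
            step_durations.pop(0)
    avg_step = sum(step_durations) / len(step_durations) if step_durations else 0.0
    segment_eta = (steps - current_step) * avg_step
    total_eta = (total_steps - total_steps_done) * avg_step
    fmt = lambda sec: str(datetime.timedelta(seconds=int(sec)))
    return (f"Sampling {current_step}/{steps} ETA {fmt(segment_eta)}",
            f"Total {total_steps_done}/{total_steps} ETA {fmt(total_eta)}")

# Example: 10 of 25 steps done in this section, 60 of 125 overall, ~1.8 s per step
durations = [1.8] * 10
print(update_eta(durations, 1.8, 10, 25, 60, 125))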
+ model_family = "F1" if "F1" in model_type else "Original" + + if settings.get("calibrate_magcache"): # Calibration mode (forces MagCache on) + print("Setting Up MagCache for Calibration") + is_calibrating = settings.get("calibrate_magcache") + studio_module.current_generator.transformer.initialize_teacache(enable_teacache=False) # Ensure TeaCache is off + magcache = MagCache(model_family=model_family, height=height, width=width, num_steps=steps, is_calibrating=is_calibrating, threshold=magcache_threshold, max_consectutive_skips=magcache_max_consecutive_skips, retention_ratio=magcache_retention_ratio) + studio_module.current_generator.transformer.install_magcache(magcache) + elif use_magcache: # User selected MagCache + print("Setting Up MagCache") + magcache = MagCache(model_family=model_family, height=height, width=width, num_steps=steps, is_calibrating=False, threshold=magcache_threshold, max_consectutive_skips=magcache_max_consecutive_skips, retention_ratio=magcache_retention_ratio) + studio_module.current_generator.transformer.initialize_teacache(enable_teacache=False) # Ensure TeaCache is off + studio_module.current_generator.transformer.install_magcache(magcache) + elif use_teacache: + print("Setting Up TeaCache") + studio_module.current_generator.transformer.initialize_teacache(enable_teacache=True, num_steps=teacache_num_steps, rel_l1_thresh=teacache_rel_l1_thresh) + studio_module.current_generator.transformer.uninstall_magcache() + else: + print("No Transformer Cache in use") + studio_module.current_generator.transformer.initialize_teacache(enable_teacache=False) + studio_module.current_generator.transformer.uninstall_magcache() + + # --- Main generation loop --- + # `i_section_loop` will be our loop counter for applying end_frame_latent + for i_section_loop, latent_padding in enumerate(latent_paddings): # Existing loop structure + is_last_section = latent_padding == 0 + latent_padding_size = latent_padding * latent_window_size + + if stream_to_use.input_queue.top() == 'end': + stream_to_use.output_queue.push(('end', None)) + return + + # Calculate the current time position + if model_type == "Video": + # For Video model, add the input video time to the current position + input_video_time = input_video_frame_count * 4 / 30 # Convert latent frames to time + current_time_position = (total_generated_latent_frames * 4 - 3) / 30 # in seconds + if current_time_position < 0: + current_time_position = 0.01 + else: + # For other models, calculate as before + current_time_position = (total_generated_latent_frames * 4 - 3) / 30 # in seconds + if current_time_position < 0: + current_time_position = 0.01 + + # Find the appropriate prompt for this section + current_prompt = prompt_sections[0].prompt # Default to first prompt + for section in prompt_sections: + if section.start_time <= current_time_position and (section.end_time is None or current_time_position < section.end_time): + current_prompt = section.prompt + break + + # PROMPT BLENDING: Find if we're in a blend window + blend_alpha = None + prev_prompt = current_prompt + next_prompt = current_prompt + + # Only try to blend if blend_sections > 0 and we have prompt change indices and multiple sections + try: + blend_sections_int = int(blend_sections) + except ValueError: + blend_sections_int = 0 # Default to 0 if conversion fails, effectively disabling blending + print(f"Warning: blend_sections ('{blend_sections}') is not a valid integer. 
Disabling prompt blending for this section.") + if blend_sections_int > 0 and prompt_change_indices and len(prompt_sections) > 1: + for i, (change_idx, prompt) in enumerate(prompt_change_indices): + if section_idx < change_idx: + prev_prompt = prompt_change_indices[i - 1][1] if i > 0 else prompt + next_prompt = prompt + blend_start = change_idx + blend_end = change_idx + blend_sections + if section_idx >= change_idx and section_idx < blend_end: + blend_alpha = (section_idx - change_idx + 1) / blend_sections + break + elif section_idx == change_idx: + # At the exact change, start blending + if i > 0: + prev_prompt = prompt_change_indices[i - 1][1] + next_prompt = prompt + blend_alpha = 1.0 / blend_sections + else: + prev_prompt = prompt + next_prompt = prompt + blend_alpha = None + break + else: + # After last change, no blending + prev_prompt = current_prompt + next_prompt = current_prompt + blend_alpha = None + + # Get the encoded prompt for this section + if blend_alpha is not None and prev_prompt != next_prompt: + # Blend embeddings + prev_llama_vec, prev_llama_attention_mask, prev_clip_l_pooler = encoded_prompts[prev_prompt] + next_llama_vec, next_llama_attention_mask, next_clip_l_pooler = encoded_prompts[next_prompt] + llama_vec = (1 - blend_alpha) * prev_llama_vec + blend_alpha * next_llama_vec + llama_attention_mask = prev_llama_attention_mask # usually same + clip_l_pooler = (1 - blend_alpha) * prev_clip_l_pooler + blend_alpha * next_clip_l_pooler + print(f"Blending prompts: '{prev_prompt[:30]}...' -> '{next_prompt[:30]}...', alpha={blend_alpha:.2f}") + else: + llama_vec, llama_attention_mask, clip_l_pooler = encoded_prompts[current_prompt] + + original_time_position = total_second_length - current_time_position + if original_time_position < 0: + original_time_position = 0 + + print(f'latent_padding_size = {latent_padding_size}, is_last_section = {is_last_section}, ' + f'time position: {current_time_position:.2f}s (original: {original_time_position:.2f}s), ' + f'using prompt: {current_prompt[:60]}...') + + # Apply end_frame_latent to history_latents for models with Endframe support + if (model_type == "Original with Endframe") and i_section_loop == 0 and end_frame_latent is not None: + print(f"Applying end_frame_latent to history_latents with strength: {end_frame_strength}") + actual_end_frame_latent_for_history = end_frame_latent.clone() + if end_frame_strength != 1.0: # Only multiply if not full strength + actual_end_frame_latent_for_history = actual_end_frame_latent_for_history * end_frame_strength + + # Ensure history_latents is on the correct device (usually CPU for this kind of modification if it's init'd there) + # and that the assigned tensor matches its dtype. + # The `studio_module.current_generator.prepare_history_latents` initializes it on CPU with float32. 
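The prompt-blending branch above linearly interpolates the two prompts' embeddings over the blend window, with blend_alpha ramping from 1/blend_sections up to 1. The core is a plain lerp, sketched here with dummy tensor shapes (the shapes are assumptions for illustration only):

import torch

def blend_prompt_embeddings(prev_vec, next_vec, prev_pooler, next_pooler, blend_alpha):
    # Linear interpolation used while transitioning between prompt sections.
    llama_vec = (1 - blend_alpha) * prev_vec + blend_alpha * next_vec
    clip_l_pooler = (1 - blend_alpha) * prev_pooler + blend_alpha * next_pooler
    return llama_vec, clip_l_pooler

# Example: three blend sections -> alpha steps of 1/3, 2/3, 3/3
prev_vec, next_vec = torch.zeros(1, 512, 4096), torch.ones(1, 512, 4096)
prev_pool, next_pool = torch.zeros(1, 768), torch.ones(1, 768)
for section_offset in range(3):
    alpha = (section_offset + 1) / 3
    vec, pool = blend_prompt_embeddings(prev_vec, next_vec, prev_pool, next_pool, alpha)
    print(f"alpha={alpha:.2f}, blended mean={vec.mean().item():.2f}")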
+ if history_latents.shape[2] >= 1: # Check if the 'Depth_slots' dimension is sufficient + if model_type == "Original with Endframe": + # For Original model, apply to the beginning (position 0) + history_latents[:, :, 0:1, :, :] = actual_end_frame_latent_for_history.to( + device=history_latents.device, # Assign to history_latents' current device + dtype=history_latents.dtype # Match history_latents' dtype + ) + elif model_type == "F1 with Endframe": + # For F1 model, apply to the end (last position) + history_latents[:, :, -1:, :, :] = actual_end_frame_latent_for_history.to( + device=history_latents.device, # Assign to history_latents' current device + dtype=history_latents.dtype # Match history_latents' dtype + ) + print(f"End frame latent applied to history for {model_type} model.") + else: + print("Warning: history_latents not shaped as expected for end_frame application.") + + + # Video models use combined methods to prepare clean latents and indices + if model_type == "Video": + # Get num_cleaned_frames from job_params if available, otherwise use default value of 5 + num_cleaned_frames = job_params.get('num_cleaned_frames', 5) + clean_latent_indices, latent_indices, clean_latent_2x_indices, clean_latent_4x_indices, clean_latents, clean_latents_2x, clean_latents_4x = \ + studio_module.current_generator.video_prepare_clean_latents_and_indices(end_frame_output_dimensions_latent, end_frame_strength, end_clip_embedding, end_of_input_video_embedding, latent_paddings, latent_padding, latent_padding_size, latent_window_size, video_latents, history_latents, num_cleaned_frames) + elif model_type == "Video F1": + # Get num_cleaned_frames from job_params if available, otherwise use default value of 5 + num_cleaned_frames = job_params.get('num_cleaned_frames', 5) + clean_latent_indices, latent_indices, clean_latent_2x_indices, clean_latent_4x_indices, clean_latents, clean_latents_2x, clean_latents_4x = \ + studio_module.current_generator.video_f1_prepare_clean_latents_and_indices(latent_window_size, video_latents, history_latents, num_cleaned_frames) + else: + # Prepare indices using the generator + clean_latent_indices, latent_indices, clean_latent_2x_indices, clean_latent_4x_indices = studio_module.current_generator.prepare_indices(latent_padding_size, latent_window_size) + + # Prepare clean latents using the generator + clean_latents, clean_latents_2x, clean_latents_4x = studio_module.current_generator.prepare_clean_latents(start_latent, history_latents) + + # Print debug info + print(f"{model_type} model section {section_idx+1}/{total_latent_sections}, latent_padding={latent_padding}") + + if not high_vram: + # Unload VAE etc. 
before loading transformer + unload_complete_models(vae, text_encoder, text_encoder_2, image_encoder) + move_model_to_device_with_memory_preservation(studio_module.current_generator.transformer, target_device=gpu, preserved_memory_gb=settings.get("gpu_memory_preservation")) + if selected_loras: + studio_module.current_generator.move_lora_adapters_to_device(gpu) + + + from diffusers_helper.pipelines.k_diffusion_hunyuan import sample_hunyuan + generated_latents = sample_hunyuan( + transformer=studio_module.current_generator.transformer, + width=width, + height=height, + frames=num_frames, + real_guidance_scale=cfg, + distilled_guidance_scale=gs, + guidance_rescale=rs, + num_inference_steps=steps, + generator=random_generator, + prompt_embeds=llama_vec, + prompt_embeds_mask=llama_attention_mask, + prompt_poolers=clip_l_pooler, + negative_prompt_embeds=llama_vec_n, + negative_prompt_embeds_mask=llama_attention_mask_n, + negative_prompt_poolers=clip_l_pooler_n, + device=gpu, + dtype=torch.bfloat16, + image_embeddings=image_encoder_last_hidden_state, + latent_indices=latent_indices, + clean_latents=clean_latents, + clean_latent_indices=clean_latent_indices, + clean_latents_2x=clean_latents_2x, + clean_latent_2x_indices=clean_latent_2x_indices, + clean_latents_4x=clean_latents_4x, + clean_latent_4x_indices=clean_latent_4x_indices, + callback=callback, + ) + + # RT_BORG: Observe the MagCache skip patterns during dev. + # RT_BORG: We need to use a real logger soon! + # if magcache is not None and magcache.is_enabled: + # print(f"MagCache skipped: {len(magcache.steps_skipped_list)} of {steps} steps: {magcache.steps_skipped_list}") + + total_generated_latent_frames += int(generated_latents.shape[2]) + # Update history latents using the generator + history_latents = studio_module.current_generator.update_history_latents(history_latents, generated_latents) + + if not high_vram: + if selected_loras: + studio_module.current_generator.move_lora_adapters_to_device(cpu) + offload_model_from_device_for_memory_preservation(studio_module.current_generator.transformer, target_device=gpu, preserved_memory_gb=8) + load_model_as_complete(vae, target_device=gpu) + + # Get real history latents using the generator + real_history_latents = studio_module.current_generator.get_real_history_latents(history_latents, total_generated_latent_frames) + + if history_pixels is None: + history_pixels = vae_decode(real_history_latents, vae).cpu() + else: + section_latent_frames = studio_module.current_generator.get_section_latent_frames(latent_window_size, is_last_section) + overlapped_frames = latent_window_size * 4 - 3 + + # Get current pixels using the generator + current_pixels = studio_module.current_generator.get_current_pixels(real_history_latents, section_latent_frames, vae) + + # Update history pixels using the generator + history_pixels = studio_module.current_generator.update_history_pixels(history_pixels, current_pixels, overlapped_frames) + + print(f"{model_type} model section {section_idx+1}/{total_latent_sections}, history_pixels shape: {history_pixels.shape}") + + if not high_vram: + unload_complete_models() + + output_filename = os.path.join(output_dir, f'{job_id}_{total_generated_latent_frames}.mp4') + save_bcthw_as_mp4(history_pixels, output_filename, fps=30, crf=settings.get("mp4_crf")) + print(f'Decoded. 
Current latent shape {real_history_latents.shape}; pixel shape {history_pixels.shape}') + stream_to_use.output_queue.push(('file', output_filename)) + + if is_last_section: + break + + section_idx += 1 # PROMPT BLENDING: increment section index + + # We'll handle combining the videos after the entire generation is complete + # This section intentionally left empty to remove the in-process combination + # --- END Main generation loop --- + + magcache = studio_module.current_generator.transformer.magcache + if magcache is not None: + if magcache.is_calibrating: + output_file = os.path.join(settings.get("output_dir"), "magcache_configuration.txt") + print(f"MagCache calibration job complete. Appending stats to configuration file: {output_file}") + magcache.append_calibration_to_file(output_file) + elif magcache.is_enabled: + print(f"MagCache ({100.0 * magcache.total_cache_hits / magcache.total_cache_requests:.2f}%) skipped {magcache.total_cache_hits} of {magcache.total_cache_requests} steps.") + studio_module.current_generator.transformer.uninstall_magcache() + magcache = None + + # Handle the results + result = pipeline.handle_results(job_params, output_filename) + + # Unload all LoRAs after generation completed + if selected_loras: + print("Unloading all LoRAs after generation completed") + studio_module.current_generator.unload_loras() + import gc + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + except Exception as e: + traceback.print_exc() + # Unload all LoRAs after error + if studio_module.current_generator is not None and selected_loras: + print("Unloading all LoRAs after error") + studio_module.current_generator.unload_loras() + import gc + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + stream_to_use.output_queue.push(('error', f"Error during generation: {traceback.format_exc()}")) + if not high_vram: + # Ensure all models including the potentially active transformer are unloaded on error + unload_complete_models( + text_encoder, text_encoder_2, image_encoder, vae, + studio_module.current_generator.transformer if studio_module.current_generator else None + ) + finally: + # This finally block is associated with the main try block (starts around line 154) + if settings.get("clean_up_videos"): + try: + video_files = [ + f for f in os.listdir(output_dir) + if f.startswith(f"{job_id}_") and f.endswith(".mp4") + ] + print(f"Video files found for cleanup: {video_files}") + if video_files: + def get_frame_count(filename): + try: + # Handles filenames like jobid_123.mp4 + return int(filename.replace(f"{job_id}_", "").replace(".mp4", "")) + except Exception: + return -1 + video_files_sorted = sorted(video_files, key=get_frame_count) + print(f"Sorted video files: {video_files_sorted}") + final_video = video_files_sorted[-1] + for vf in video_files_sorted[:-1]: + full_path = os.path.join(output_dir, vf) + try: + os.remove(full_path) + print(f"Deleted intermediate video: {full_path}") + except Exception as e: + print(f"Failed to delete {full_path}: {e}") + except Exception as e: + print(f"Error during video cleanup: {e}") + + # Check if the user wants to combine the source video with the generated video + # This is done after the video cleanup routine to ensure the combined video is not deleted + # RT_BORG: Retain (but suppress) this original way to combine videos until the new combiner is proven. 
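The cleanup block in the finally clause above keeps only the longest {job_id}_{frames}.mp4 render and deletes the intermediate segments. The filename parsing and sort can be sketched in isolation as follows; this is a simplified sketch that omits the logging and the clean_up_videos setting guard:

import os

def cleanup_intermediate_videos(output_dir: str, job_id: str) -> None:
    # Keep the video with the highest latent-frame count, delete earlier partial renders.
    def frame_count(filename: str) -> int:
        try:
            return int(filename.replace(f"{job_id}_", "").replace(".mp4", ""))
        except ValueError:
            return -1  # unparsable names sort first and are never kept as the final video

    videos = [f for f in os.listdir(output_dir)
              if f.startswith(f"{job_id}_") and f.endswith(".mp4")]
    if not videos:
        return
    videos.sort(key=frame_count)
    for name in videos[:-1]:  # everything except the longest render
        os.remove(os.path.join(output_dir, name))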
+ combine_v1 = False + if combine_v1 and (model_type == "Video" or model_type == "Video F1") and combine_with_source and job_params.get('input_image_path'): + print("Creating combined video with source and generated content...") + try: + input_video_path = job_params.get('input_image_path') + if input_video_path and os.path.exists(input_video_path): + final_video_path_for_combine = None # Use a different variable name to avoid conflict + video_files_for_combine = [ + f for f in os.listdir(output_dir) + if f.startswith(f"{job_id}_") and f.endswith(".mp4") and "combined" not in f + ] + + if video_files_for_combine: + def get_frame_count_for_combine(filename): # Renamed to avoid conflict + try: + return int(filename.replace(f"{job_id}_", "").replace(".mp4", "")) + except Exception: + return float('inf') + + video_files_sorted_for_combine = sorted(video_files_for_combine, key=get_frame_count_for_combine) + if video_files_sorted_for_combine: # Check if the list is not empty + final_video_path_for_combine = os.path.join(output_dir, video_files_sorted_for_combine[-1]) + + if final_video_path_for_combine and os.path.exists(final_video_path_for_combine): + combined_output_filename = os.path.join(output_dir, f'{job_id}_combined_v1.mp4') + combined_result = None + try: + if hasattr(studio_module.current_generator, 'combine_videos'): + print(f"Using VideoModelGenerator.combine_videos to create side-by-side comparison") + combined_result = studio_module.current_generator.combine_videos( + source_video_path=input_video_path, + generated_video_path=final_video_path_for_combine, # Use the correct variable + output_path=combined_output_filename + ) + + if combined_result: + print(f"Combined video saved to: {combined_result}") + stream_to_use.output_queue.push(('file', combined_result)) + else: + print("Failed to create combined video, falling back to direct ffmpeg method") + combined_result = None + else: + print("VideoModelGenerator does not have combine_videos method. 
Using fallback method.") + except Exception as e_combine: # Use a different exception variable name + print(f"Error in combine_videos method: {e_combine}") + print("Falling back to direct ffmpeg method") + combined_result = None + + if not combined_result: + print("Using fallback method to combine videos") + from modules.toolbox.toolbox_processor import VideoProcessor + from modules.toolbox.message_manager import MessageManager + + message_manager = MessageManager() + # Pass settings.settings if it exists, otherwise pass the settings object + video_processor_settings = settings.settings if hasattr(settings, 'settings') else settings + video_processor = VideoProcessor(message_manager, video_processor_settings) + ffmpeg_exe = video_processor.ffmpeg_exe + + if ffmpeg_exe: + print(f"Using ffmpeg at: {ffmpeg_exe}") + import subprocess + temp_list_file = os.path.join(output_dir, f'{job_id}_filelist.txt') + with open(temp_list_file, 'w') as f: + f.write(f"file '{input_video_path}'\n") + f.write(f"file '{final_video_path_for_combine}'\n") # Use the correct variable + + ffmpeg_cmd = [ + ffmpeg_exe, "-y", "-f", "concat", "-safe", "0", + "-i", temp_list_file, "-c", "copy", combined_output_filename + ] + print(f"Running ffmpeg command: {' '.join(ffmpeg_cmd)}") + subprocess.run(ffmpeg_cmd, check=True, capture_output=True, text=True) + if os.path.exists(temp_list_file): + os.remove(temp_list_file) + print(f"Combined video saved to: {combined_output_filename}") + stream_to_use.output_queue.push(('file', combined_output_filename)) + else: + print("FFmpeg executable not found. Cannot combine videos.") + else: + print(f"Final video not found for combining with source: {final_video_path_for_combine}") + else: + print(f"Input video path not found: {input_video_path}") + except Exception as e_combine_outer: # Use a different exception variable name + print(f"Error combining videos: {e_combine_outer}") + traceback.print_exc() + + # Combine input frames (resized and center cropped if needed) with final generated history_pixels tensor sequentially --- + # This creates ID_combined.mp4 + # RT_BORG: Be sure to add this check if we decide to retain the processed input frames for "small" input videos + # and job_params.get('input_frames_resized_np') is not None + if (model_type == "Video" or model_type == "Video F1") and combine_with_source and history_pixels is not None: + print(f"Creating combined video ({job_id}_combined.mp4) with processed input frames and generated history_pixels tensor...") + try: + # input_frames_resized_np = job_params.get('input_frames_resized_np') + + # RT_BORG: I cringe calliing methods on BaseModelGenerator that only exist on VideoBaseGenerator, until we refactor + input_frames_resized_np, fps, target_height, target_width = studio_module.current_generator.extract_video_frames( + is_for_encode=False, + video_path=job_params['input_image'], + resolution=job_params['resolutionW'], + no_resize=False, + input_files_dir=job_params['input_files_dir'] + ) + + # history_pixels is (B, C, T, H, W), float32, [-1,1], on CPU + if input_frames_resized_np is not None and history_pixels.numel() > 0 : # Check if history_pixels is not empty + combined_sequential_output_filename = os.path.join(output_dir, f'{job_id}_combined.mp4') + + # fps variable should be from the video_encode call earlier. 
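+                        # Note: in this path fps comes from the extract_video_frames() call just above
+                        # (presumably the source video's native frame rate) and is reused as target_fps
+                        # for the combined output.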
+ input_video_fps_for_combine = fps + current_crf = settings.get("mp4_crf", 16) + + # Call the new function from video_tools.py + combined_sequential_result_path = combine_videos_sequentially_from_tensors( + processed_input_frames_np=input_frames_resized_np, + generated_frames_pt=history_pixels, + output_path=combined_sequential_output_filename, + target_fps=input_video_fps_for_combine, + crf_value=current_crf + ) + if combined_sequential_result_path: + stream_to_use.output_queue.push(('file', combined_sequential_result_path)) + except Exception as e: + print(f"Error creating combined video ({job_id}_combined.mp4): {e}") + traceback.print_exc() + + # Final verification of LoRA state + if studio_module.current_generator and studio_module.current_generator.transformer: + # Verify LoRA state + has_loras = False + if hasattr(studio_module.current_generator.transformer, 'peft_config'): + adapter_names = list(studio_module.current_generator.transformer.peft_config.keys()) if studio_module.current_generator.transformer.peft_config else [] + if adapter_names: + has_loras = True + print(f"Transformer has LoRAs: {', '.join(adapter_names)}") + else: + print(f"Transformer has no LoRAs in peft_config") + else: + print(f"Transformer has no peft_config attribute") + + # Check for any LoRA modules + for name, module in studio_module.current_generator.transformer.named_modules(): + if hasattr(module, 'lora_A') and module.lora_A: + has_loras = True + if hasattr(module, 'lora_B') and module.lora_B: + has_loras = True + + if not has_loras: + print(f"No LoRA components found in transformer") + + stream_to_use.output_queue.push(('end', None)) + return diff --git a/modules/prompt_handler.py b/modules/prompt_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..d2ce826b890ea1d3733173c9b6d5d8c7d2e71306 --- /dev/null +++ b/modules/prompt_handler.py @@ -0,0 +1,164 @@ +import re +from dataclasses import dataclass +from typing import List, Optional + + +@dataclass +class PromptSection: + """Represents a section of the prompt with specific timing information""" + prompt: str + start_time: float = 0 # in seconds + end_time: Optional[float] = None # in seconds, None means until the end + + +def snap_to_section_boundaries(prompt_sections: List[PromptSection], latent_window_size: int, fps: int = 30) -> List[PromptSection]: + """ + Adjust timestamps to align with model's internal section boundaries + + Args: + prompt_sections: List of PromptSection objects + latent_window_size: Size of the latent window used in the model + fps: Frames per second (default: 30) + + Returns: + List of PromptSection objects with aligned timestamps + """ + section_duration = (latent_window_size * 4 - 3) / fps # Duration of one section in seconds + + aligned_sections = [] + for section in prompt_sections: + # Snap start time to nearest section boundary + aligned_start = round(section.start_time / section_duration) * section_duration + + # Snap end time to nearest section boundary + aligned_end = None + if section.end_time is not None: + aligned_end = round(section.end_time / section_duration) * section_duration + + # Ensure minimum section length + if aligned_end is not None and aligned_end <= aligned_start: + aligned_end = aligned_start + section_duration + + aligned_sections.append(PromptSection( + prompt=section.prompt, + start_time=aligned_start, + end_time=aligned_end + )) + + return aligned_sections + + +def parse_timestamped_prompt(prompt_text: str, total_duration: float, latent_window_size: int = 9, 
generation_type: str = "Original") -> List[PromptSection]: + """ + Parse a prompt with timestamps in the format [0s-2s: text] or [3s: text] + + Args: + prompt_text: The input prompt text with optional timestamp sections + total_duration: Total duration of the video in seconds + latent_window_size: Size of the latent window used in the model + generation_type: Type of generation ("Original" or "F1") + + Returns: + List of PromptSection objects with timestamps aligned to section boundaries + and reversed to account for reverse generation (only for Original type) + """ + # Default prompt for the entire duration if no timestamps are found + if "[" not in prompt_text or "]" not in prompt_text: + return [PromptSection(prompt=prompt_text.strip())] + + sections = [] + # Find all timestamp sections [time: text] + timestamp_pattern = r'\[(\d+(?:\.\d+)?s)(?:-(\d+(?:\.\d+)?s))?\s*:\s*(.*?)\]' + regular_text = prompt_text + + for match in re.finditer(timestamp_pattern, prompt_text): + start_time_str = match.group(1) + end_time_str = match.group(2) + section_text = match.group(3).strip() + + # Convert time strings to seconds + start_time = float(start_time_str.rstrip('s')) + end_time = float(end_time_str.rstrip('s')) if end_time_str else None + + sections.append(PromptSection( + prompt=section_text, + start_time=start_time, + end_time=end_time + )) + + # Remove the processed section from regular_text + regular_text = regular_text.replace(match.group(0), "") + + # If there's any text outside of timestamp sections, use it as a default for the entire duration + regular_text = regular_text.strip() + if regular_text: + sections.append(PromptSection( + prompt=regular_text, + start_time=0, + end_time=None + )) + + # Sort sections by start time + sections.sort(key=lambda x: x.start_time) + + # Fill in end times if not specified + for i in range(len(sections) - 1): + if sections[i].end_time is None: + sections[i].end_time = sections[i+1].start_time + + # Set the last section's end time to the total duration if not specified + if sections and sections[-1].end_time is None: + sections[-1].end_time = total_duration + + # Snap timestamps to section boundaries + sections = snap_to_section_boundaries(sections, latent_window_size) + + # Only reverse timestamps for Original generation type + if generation_type in ("Original", "Original with Endframe", "Video"): + # Now reverse the timestamps to account for reverse generation + reversed_sections = [] + for section in sections: + reversed_start = total_duration - section.end_time if section.end_time is not None else 0 + reversed_end = total_duration - section.start_time + reversed_sections.append(PromptSection( + prompt=section.prompt, + start_time=reversed_start, + end_time=reversed_end + )) + + # Sort the reversed sections by start time + reversed_sections.sort(key=lambda x: x.start_time) + return reversed_sections + + return sections + + +def get_section_boundaries(latent_window_size: int = 9, count: int = 10) -> str: + """ + Calculate and format section boundaries for UI display + + Args: + latent_window_size: Size of the latent window used in the model + count: Number of boundaries to display + + Returns: + Formatted string of section boundaries + """ + section_duration = (latent_window_size * 4 - 3) / 30 + return ", ".join([f"{i*section_duration:.1f}s" for i in range(count)]) + + +def get_quick_prompts() -> List[List[str]]: + """ + Get a list of example timestamped prompts + + Returns: + List of example prompts formatted for Gradio Dataset + """ + prompts = [ + 
'[0s: The person waves hello] [2s: The person jumps up and down] [4s: The person does a spin]', + '[0s: The person raises both arms slowly] [2s: The person claps hands enthusiastically]', + '[0s: Person gives thumbs up] [1.1s: Person smiles and winks] [2.2s: Person shows two thumbs down]', + '[0s: Person looks surprised] [1.1s: Person raises arms above head] [2.2s-3.3s: Person puts hands on hips]' + ] + return [[x] for x in prompts] diff --git a/modules/settings.py b/modules/settings.py new file mode 100644 index 0000000000000000000000000000000000000000..e3bd25df7b505c993e6b81a853a15894c3777f6b --- /dev/null +++ b/modules/settings.py @@ -0,0 +1,88 @@ +import json +from pathlib import Path +from typing import Dict, Any, Optional +import os + +class Settings: + def __init__(self): + # Get the project root directory (where settings.py is located) + project_root = Path(__file__).parent.parent + + self.settings_file = project_root / ".framepack" / "settings.json" + self.settings_file.parent.mkdir(parents=True, exist_ok=True) + + # Set default paths relative to project root + self.default_settings = { + "save_metadata": True, + "gpu_memory_preservation": 6, + "output_dir": str(project_root / "outputs"), + "metadata_dir": str(project_root / "outputs"), + "lora_dir": str(project_root / "loras"), + "gradio_temp_dir": str(project_root / "temp"), + "input_files_dir": str(project_root / "input_files"), # New setting for input files + "auto_save_settings": True, + "gradio_theme": "default", + "mp4_crf": 16, + "clean_up_videos": True, + "override_system_prompt": False, + "auto_cleanup_on_startup": False, # ADDED: New setting for startup cleanup + "latents_display_top": False, # NEW: Control latents preview position (False = right column, True = top of interface) + "system_prompt_template": "{\"template\": \"<|start_header_id|>system<|end_header_id|>\\n\\nDescribe the video by detailing the following aspects: 1. The main content and theme of the video.2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects.3. Actions, events, behaviors temporal relationships, physical movement changes of the objects.4. background environment, light, style and atmosphere.5. camera angles, movements, and transitions used in the video:<|eot_id|><|start_header_id|>user<|end_header_id|>\\n\\n{}<|eot_id|>\", \"crop_start\": 95}", + "startup_model_type": "None", + "startup_preset_name": None, + "enhancer_prompt_template": """You are a creative assistant for a text-to-video generator. Your task is to take a user's prompt and make it more descriptive, vivid, and detailed. Focus on visual elements. Do not change the core action, but embellish it. + +User prompt: "{text_to_enhance}" + +Enhanced prompt:""" + } + self.settings = self.load_settings() + + def load_settings(self) -> Dict[str, Any]: + """Load settings from file or return defaults""" + if self.settings_file.exists(): + try: + with open(self.settings_file, 'r') as f: + loaded_settings = json.load(f) + # Merge with defaults to ensure all settings exist + settings = self.default_settings.copy() + settings.update(loaded_settings) + return settings + except Exception as e: + print(f"Error loading settings: {e}") + return self.default_settings.copy() + return self.default_settings.copy() + + def save_settings(self, **kwargs): + """Save settings to file. 
Accepts keyword arguments for any settings to update.""" + # Update self.settings with any provided keyword arguments + self.settings.update(kwargs) + # Ensure all default fields are present + for k, v in self.default_settings.items(): + self.settings.setdefault(k, v) + + # Ensure directories exist for relevant fields + for dir_key in ["output_dir", "metadata_dir", "lora_dir", "gradio_temp_dir"]: + dir_path = self.settings.get(dir_key) + if dir_path: + os.makedirs(dir_path, exist_ok=True) + + # Save to file + with open(self.settings_file, 'w') as f: + json.dump(self.settings, f, indent=4) + + def get(self, key: str, default: Any = None) -> Any: + """Get a setting value""" + return self.settings.get(key, default) + + def set(self, key: str, value: Any) -> None: + """Set a setting value""" + self.settings[key] = value + if self.settings.get("auto_save_settings", True): + self.save_settings() + + def update(self, settings: Dict[str, Any]) -> None: + """Update multiple settings at once""" + self.settings.update(settings) + if self.settings.get("auto_save_settings", True): + self.save_settings() diff --git a/modules/toolbox/RIFE/IFNet_HDv3.py b/modules/toolbox/RIFE/IFNet_HDv3.py new file mode 100644 index 0000000000000000000000000000000000000000..e79853ceddaa762d8fdb343689ad6dc099a29c27 --- /dev/null +++ b/modules/toolbox/RIFE/IFNet_HDv3.py @@ -0,0 +1,136 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from .warplayer import warp + + +def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): + return nn.Sequential( + nn.Conv2d( + in_planes, + out_planes, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + bias=True, + ), + nn.PReLU(out_planes), + ) + + +def conv_bn(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): + return nn.Sequential( + nn.Conv2d( + in_planes, + out_planes, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + bias=False, + ), + nn.BatchNorm2d(out_planes), + nn.PReLU(out_planes), + ) + + +class IFBlock(nn.Module): + def __init__(self, in_planes, c=64): + super(IFBlock, self).__init__() + self.conv0 = nn.Sequential( + conv(in_planes, c // 2, 3, 2, 1), + conv(c // 2, c, 3, 2, 1), + ) + self.convblock0 = nn.Sequential(conv(c, c), conv(c, c)) + self.convblock1 = nn.Sequential(conv(c, c), conv(c, c)) + self.convblock2 = nn.Sequential(conv(c, c), conv(c, c)) + self.convblock3 = nn.Sequential(conv(c, c), conv(c, c)) + self.conv1 = nn.Sequential( + nn.ConvTranspose2d(c, c // 2, 4, 2, 1), + nn.PReLU(c // 2), + nn.ConvTranspose2d(c // 2, 4, 4, 2, 1), + ) + self.conv2 = nn.Sequential( + nn.ConvTranspose2d(c, c // 2, 4, 2, 1), + nn.PReLU(c // 2), + nn.ConvTranspose2d(c // 2, 1, 4, 2, 1), + ) + + def forward(self, x, flow, scale=1): + x = F.interpolate( + x, scale_factor=1.0 / scale, mode="bilinear", align_corners=False, recompute_scale_factor=False + ) + flow = ( + F.interpolate( + flow, scale_factor=1.0 / scale, mode="bilinear", align_corners=False, recompute_scale_factor=False + ) + * 1.0 + / scale + ) + feat = self.conv0(torch.cat((x, flow), 1)) + feat = self.convblock0(feat) + feat + feat = self.convblock1(feat) + feat + feat = self.convblock2(feat) + feat + feat = self.convblock3(feat) + feat + flow = self.conv1(feat) + mask = self.conv2(feat) + flow = ( + F.interpolate(flow, scale_factor=scale, mode="bilinear", align_corners=False, recompute_scale_factor=False) + * scale + ) + mask = F.interpolate( + mask, scale_factor=scale, mode="bilinear", 
align_corners=False, recompute_scale_factor=False + ) + return flow, mask + + +class IFNet(nn.Module): + def __init__(self): + super(IFNet, self).__init__() + self.block0 = IFBlock(7 + 4, c=90) + self.block1 = IFBlock(7 + 4, c=90) + self.block2 = IFBlock(7 + 4, c=90) + self.block_tea = IFBlock(10 + 4, c=90) + # self.contextnet = Contextnet() + # self.unet = Unet() + + def forward(self, x, scale_list=[4, 2, 1], training=False): + if training == False: + channel = x.shape[1] // 2 + img0 = x[:, :channel] + img1 = x[:, channel:] + flow_list = [] + merged = [] + mask_list = [] + warped_img0 = img0 + warped_img1 = img1 + flow = (x[:, :4]).detach() * 0 + mask = (x[:, :1]).detach() * 0 + loss_cons = 0 + block = [self.block0, self.block1, self.block2] + for i in range(3): + f0, m0 = block[i](torch.cat((warped_img0[:, :3], warped_img1[:, :3], mask), 1), flow, scale=scale_list[i]) + f1, m1 = block[i]( + torch.cat((warped_img1[:, :3], warped_img0[:, :3], -mask), 1), + torch.cat((flow[:, 2:4], flow[:, :2]), 1), + scale=scale_list[i], + ) + flow = flow + (f0 + torch.cat((f1[:, 2:4], f1[:, :2]), 1)) / 2 + mask = mask + (m0 + (-m1)) / 2 + mask_list.append(mask) + flow_list.append(flow) + warped_img0 = warp(img0, flow[:, :2]) + warped_img1 = warp(img1, flow[:, 2:4]) + merged.append((warped_img0, warped_img1)) + """ + c0 = self.contextnet(img0, flow[:, :2]) + c1 = self.contextnet(img1, flow[:, 2:4]) + tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1) + res = tmp[:, 1:4] * 2 - 1 + """ + for i in range(3): + mask_list[i] = torch.sigmoid(mask_list[i]) + merged[i] = merged[i][0] * mask_list[i] + merged[i][1] * (1 - mask_list[i]) + # merged[i] = torch.clamp(merged[i] + res, 0, 1) + return flow_list, mask_list[2], merged diff --git a/modules/toolbox/RIFE/RIFE_HDv3.py b/modules/toolbox/RIFE/RIFE_HDv3.py new file mode 100644 index 0000000000000000000000000000000000000000..11c0001e16282c88795e9a971216841bd2f4d044 --- /dev/null +++ b/modules/toolbox/RIFE/RIFE_HDv3.py @@ -0,0 +1,98 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.optim import AdamW +import numpy as np +import itertools +from .warplayer import warp +from torch.nn.parallel import DistributedDataParallel as DDP +from .IFNet_HDv3 import * +from .loss import * +import devicetorch +device = devicetorch.get(torch) + + +class Model: + def __init__(self, local_rank=-1): + self.flownet = IFNet() + self.device() + self.optimG = AdamW(self.flownet.parameters(), lr=1e-6, weight_decay=1e-4) + self.epe = EPE() + # self.vgg = VGGPerceptualLoss().to(device) + self.sobel = SOBEL() + if local_rank != -1: + self.flownet = DDP(self.flownet, device_ids=[local_rank], output_device=local_rank) + + def train(self): + self.flownet.train() + + def eval(self): + self.flownet.eval() + + def device(self): + self.flownet.to(device) + + def load_model(self, path, rank=0): + def convert(param): + if rank == -1: + return {k.replace("module.", ""): v for k, v in param.items() if "module." in k} + else: + return param + + if rank <= 0: + model_path = "{}/flownet.pkl".format(path) + # Check PyTorch version to safely use weights_only + from packaging import version + use_weights_only = version.parse(torch.__version__) >= version.parse("1.13") + + load_kwargs = {} + if not torch.cuda.is_available(): + load_kwargs['map_location'] = "cpu" + + if use_weights_only: + # For modern PyTorch, be explicit and safe + load_kwargs['weights_only'] = True + # print(f"PyTorch >= 1.13 detected. 
Loading RIFE model with weights_only=True.") + state_dict = torch.load(model_path, **load_kwargs) + else: + # For older PyTorch, load the old way + print(f"PyTorch < 1.13 detected. Loading RIFE model using legacy method.") + state_dict = torch.load(model_path, **load_kwargs) + + self.flownet.load_state_dict(convert(state_dict)) + + def inference(self, img0, img1, scale=1.0): + imgs = torch.cat((img0, img1), 1) + scale_list = [4 / scale, 2 / scale, 1 / scale] + flow, mask, merged = self.flownet(imgs, scale_list) + return merged[2] + + def update(self, imgs, gt, learning_rate=0, mul=1, training=True, flow_gt=None): + for param_group in self.optimG.param_groups: + param_group["lr"] = learning_rate + img0 = imgs[:, :3] + img1 = imgs[:, 3:] + if training: + self.train() + else: + self.eval() + scale = [4, 2, 1] + flow, mask, merged = self.flownet(torch.cat((imgs, gt), 1), scale=scale, training=training) + loss_l1 = (merged[2] - gt).abs().mean() + loss_smooth = self.sobel(flow[2], flow[2] * 0).mean() + # loss_vgg = self.vgg(merged[2], gt) + if training: + self.optimG.zero_grad() + loss_G = loss_cons + loss_smooth * 0.1 + loss_G.backward() + self.optimG.step() + else: + flow_teacher = flow[2] + return merged[2], { + "mask": mask, + "flow": flow[2][:, :2], + "loss_l1": loss_l1, + "loss_cons": loss_cons, + "loss_smooth": loss_smooth, + } diff --git a/modules/toolbox/RIFE/__int__.py b/modules/toolbox/RIFE/__int__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/modules/toolbox/RIFE/loss.py b/modules/toolbox/RIFE/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..c0141ec18848ee037cb20deb00e5416aa5d78190 --- /dev/null +++ b/modules/toolbox/RIFE/loss.py @@ -0,0 +1,130 @@ +import torch +import numpy as np +import torch.nn as nn +import torch.nn.functional as F +import torchvision.models as models +import devicetorch +device = devicetorch.get(torch) + + +class EPE(nn.Module): + def __init__(self): + super(EPE, self).__init__() + + def forward(self, flow, gt, loss_mask): + loss_map = (flow - gt.detach()) ** 2 + loss_map = (loss_map.sum(1, True) + 1e-6) ** 0.5 + return loss_map * loss_mask + + +class Ternary(nn.Module): + def __init__(self): + super(Ternary, self).__init__() + patch_size = 7 + out_channels = patch_size * patch_size + self.w = np.eye(out_channels).reshape((patch_size, patch_size, 1, out_channels)) + self.w = np.transpose(self.w, (3, 2, 0, 1)) + self.w = torch.tensor(self.w).float().to(device) + + def transform(self, img): + patches = F.conv2d(img, self.w, padding=3, bias=None) + transf = patches - img + transf_norm = transf / torch.sqrt(0.81 + transf**2) + return transf_norm + + def rgb2gray(self, rgb): + r, g, b = rgb[:, 0:1, :, :], rgb[:, 1:2, :, :], rgb[:, 2:3, :, :] + gray = 0.2989 * r + 0.5870 * g + 0.1140 * b + return gray + + def hamming(self, t1, t2): + dist = (t1 - t2) ** 2 + dist_norm = torch.mean(dist / (0.1 + dist), 1, True) + return dist_norm + + def valid_mask(self, t, padding): + n, _, h, w = t.size() + inner = torch.ones(n, 1, h - 2 * padding, w - 2 * padding).type_as(t) + mask = F.pad(inner, [padding] * 4) + return mask + + def forward(self, img0, img1): + img0 = self.transform(self.rgb2gray(img0)) + img1 = self.transform(self.rgb2gray(img1)) + return self.hamming(img0, img1) * self.valid_mask(img0, 1) + + +class SOBEL(nn.Module): + def __init__(self): + super(SOBEL, self).__init__() + self.kernelX = torch.tensor( + [ + [1, 0, -1], + [2, 0, -2], + [1, 0, -1], + ] + 
).float() + self.kernelY = self.kernelX.clone().T + self.kernelX = self.kernelX.unsqueeze(0).unsqueeze(0).to(device) + self.kernelY = self.kernelY.unsqueeze(0).unsqueeze(0).to(device) + + def forward(self, pred, gt): + N, C, H, W = pred.shape[0], pred.shape[1], pred.shape[2], pred.shape[3] + img_stack = torch.cat([pred.reshape(N * C, 1, H, W), gt.reshape(N * C, 1, H, W)], 0) + sobel_stack_x = F.conv2d(img_stack, self.kernelX, padding=1) + sobel_stack_y = F.conv2d(img_stack, self.kernelY, padding=1) + pred_X, gt_X = sobel_stack_x[: N * C], sobel_stack_x[N * C :] + pred_Y, gt_Y = sobel_stack_y[: N * C], sobel_stack_y[N * C :] + + L1X, L1Y = torch.abs(pred_X - gt_X), torch.abs(pred_Y - gt_Y) + loss = L1X + L1Y + return loss + + +class MeanShift(nn.Conv2d): + def __init__(self, data_mean, data_std, data_range=1, norm=True): + c = len(data_mean) + super(MeanShift, self).__init__(c, c, kernel_size=1) + std = torch.Tensor(data_std) + self.weight.data = torch.eye(c).view(c, c, 1, 1) + if norm: + self.weight.data.div_(std.view(c, 1, 1, 1)) + self.bias.data = -1 * data_range * torch.Tensor(data_mean) + self.bias.data.div_(std) + else: + self.weight.data.mul_(std.view(c, 1, 1, 1)) + self.bias.data = data_range * torch.Tensor(data_mean) + self.requires_grad = False + + +class VGGPerceptualLoss(torch.nn.Module): + def __init__(self, rank=0): + super(VGGPerceptualLoss, self).__init__() + blocks = [] + pretrained = True + # self.vgg_pretrained_features = models.vgg19(pretrained=pretrained).features + self.normalize = MeanShift([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], norm=True).to(device) + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X, Y, indices=None): + X = self.normalize(X) + Y = self.normalize(Y) + indices = [2, 7, 12, 21, 30] + weights = [1.0 / 2.6, 1.0 / 4.8, 1.0 / 3.7, 1.0 / 5.6, 10 / 1.5] + k = 0 + loss = 0 + for i in range(indices[-1]): + X = self.vgg_pretrained_features[i](X) + Y = self.vgg_pretrained_features[i](Y) + if (i + 1) in indices: + loss += weights[k] * (X - Y.detach()).abs().mean() * 0.1 + k += 1 + return loss + + +if __name__ == "__main__": + img0 = torch.zeros(3, 3, 256, 256).float().to(device) + img1 = torch.tensor(np.random.normal(0, 1, (3, 3, 256, 256))).float().to(device) + ternary_loss = Ternary() + print(ternary_loss(img0, img1).shape) diff --git a/modules/toolbox/RIFE/warplayer.py b/modules/toolbox/RIFE/warplayer.py new file mode 100644 index 0000000000000000000000000000000000000000..38b14451d010b85d01ff8d7cc6b8005baba233eb --- /dev/null +++ b/modules/toolbox/RIFE/warplayer.py @@ -0,0 +1,50 @@ +import torch +import torch.nn as nn +# import devicetorch +# device = devicetorch.get(torch) + +backwarp_tenGrid = {} # Cache for grid tensors + +def warp(tenInput, tenFlow): + # The key for caching should be based on tenFlow's properties, including its device + k = (str(tenFlow.device), str(tenFlow.size())) + + if k not in backwarp_tenGrid: + # Create grid tensors on the same device as tenFlow + flow_device = tenFlow.device + + tenHorizontal = ( + torch.linspace(-1.0, 1.0, tenFlow.shape[3], device=flow_device) + .view(1, 1, 1, tenFlow.shape[3]) + .expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1) + ) + tenVertical = ( + torch.linspace(-1.0, 1.0, tenFlow.shape[2], device=flow_device) + .view(1, 1, tenFlow.shape[2], 1) + .expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3]) + ) + # Concatenate; the result will be on flow_device if inputs are. + # No need for an extra .to(device) if flow_device is used for components. 
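+        # The grid holds normalized sampling coordinates in [-1, 1] for grid_sample (x from tenHorizontal,
+        # y from tenVertical). The flow below is rescaled from pixel units into the same range by dividing
+        # by (W - 1) / 2 and (H - 1) / 2, so e.g. a horizontal displacement of (W - 1) / 2 pixels
+        # corresponds to an offset of 1.0 in normalized coordinates.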
+ backwarp_tenGrid[k] = torch.cat([tenHorizontal, tenVertical], 1) + + # tenFlow is already on its correct device. + # backwarp_tenGrid[k] is now also on that same device. + + tenFlow = torch.cat( + [ + tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0), + tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0), + ], + 1, + ) + + # Ensure grid used by grid_sample is on the same device as tenInput and tenFlow. + # backwarp_tenGrid[k] is already on tenFlow.device. + # tenFlow is on tenFlow.device. + # If tenInput can be on a different device than tenFlow, that's a separate issue. + # Assuming tenInput and tenFlow are on the same device for grid_sample. + g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1) + + return torch.nn.functional.grid_sample( + input=tenInput, grid=g, mode="bilinear", padding_mode="border", align_corners=True + ) \ No newline at end of file diff --git a/modules/toolbox/esrgan_core.py b/modules/toolbox/esrgan_core.py new file mode 100644 index 0000000000000000000000000000000000000000..9055eeb9916b5cf25ff0630be795945bcaa43036 --- /dev/null +++ b/modules/toolbox/esrgan_core.py @@ -0,0 +1,483 @@ +import os +import torch +import gc +import devicetorch +import warnings +import traceback + +from pathlib import Path +from huggingface_hub import snapshot_download +from basicsr.archs.rrdbnet_arch import RRDBNet +from realesrgan import RealESRGANer +from realesrgan.archs.srvgg_arch import SRVGGNetCompact +from basicsr.utils.download_util import load_file_from_url # Import for direct downloads + +# Conditional import for GFPGAN +try: + from gfpgan import GFPGANer + GFPGAN_AVAILABLE = True +except ImportError: + GFPGAN_AVAILABLE = False + +from .message_manager import MessageManager + +_MODULE_DIR = Path(os.path.dirname(os.path.abspath(__file__))) +MODEL_ESRGAN_PATH = _MODULE_DIR / "model_esrgan" +# Define a path for GFPGAN models, can be within MODEL_ESRGAN_PATH or separate +MODEL_GFPGAN_PATH = _MODULE_DIR / "model_gfpgan" + +class ESRGANUpscaler: + def __init__(self, message_manager: MessageManager, device: torch.device): + self.message_manager = message_manager + self.device = device + self.model_dir = Path(MODEL_ESRGAN_PATH) + self.gfpgan_model_dir = Path(MODEL_GFPGAN_PATH) # GFPGAN model directory + os.makedirs(self.model_dir, exist_ok=True) + os.makedirs(self.gfpgan_model_dir, exist_ok=True) # Ensure GFPGAN model dir exists + + self.supported_models = { + + "RealESRGAN_x2plus": { + "filename": "RealESRGAN_x2plus.pth", + "file_url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth", + "hf_repo_id": None, + "scale": 2, + "model_class": RRDBNet, + "model_params": dict(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2), + "description": "General purpose. Faster than x4 models due to smaller native output. Good for moderate upscaling." + }, + "RealESRGAN_x4plus": { + "filename": "RealESRGAN_x4plus.pth", + "file_url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth", + "hf_repo_id": None, + "scale": 4, + "model_class": RRDBNet, + "model_params": dict(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4), + "description": "General purpose. Prioritizes sharpness & detail. Good default for most videos." 
+ }, + "RealESRNet_x4plus": { + "filename": "RealESRNet_x4plus.pth", + "file_url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth", + "hf_repo_id": None, + "scale": 4, + "model_class": RRDBNet, + "model_params": dict(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4), + "description": "Similar to RealESRGAN_x4plus, but trained for higher fidelity, often yielding smoother results." + }, + "RealESR-general-x4v3": { + "filename": "realesr-general-x4v3.pth", # Main model + "file_url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth", + "wdn_filename": "realesr-general-wdn-x4v3.pth", # Companion WDN model + "wdn_file_url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth", + "scale": 4, "model_class": SRVGGNetCompact, + "model_params": dict(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu'), + "description": "Versatile SRVGG-based. Balances detail & naturalness. Has adjustable denoise strength." # Updated description + }, + "RealESRGAN_x4plus_anime_6B": { + "filename": "RealESRGAN_x4plus_anime_6B.pth", + "file_url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth", + "hf_repo_id": None, + "scale": 4, + "model_class": RRDBNet, + "model_params": dict(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4), + "description": "Optimized for anime. Lighter 6-block version of x4plus for faster anime upscaling." + }, + "RealESR_AnimeVideo_v3": { + "filename": "realesr-animevideov3.pth", + "file_url": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth", + "hf_repo_id": None, + "scale": 4, + "model_class": SRVGGNetCompact, + "model_params": dict(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'), + "description": "Specialized SRVGG-based model for anime. Often excels with animated content." 
+ } + } + + self.upsamplers: dict[str, dict[str, RealESRGANer | int | None]] = {} + self.face_enhancer: GFPGANer | None = None # For GFPGAN + + def _ensure_model_downloaded(self, model_key: str, target_dir: Path | None = None, is_gfpgan: bool = False, is_wdn_companion: bool = False) -> bool: + # Modified to handle WDN companion model download for RealESR-general-x4v3 + if target_dir is None: + current_model_dir = self.gfpgan_model_dir if is_gfpgan else self.model_dir + else: + current_model_dir = target_dir + + model_info_source = {} + actual_model_filename = "" + + if is_gfpgan: + model_info_source = { + "filename": "GFPGANv1.4.pth", + "file_url": "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/GFPGANv1.4.pth", + "hf_repo_id": None + } + actual_model_filename = model_info_source["filename"] + else: + if model_key not in self.supported_models: + self.message_manager.add_error(f"ESRGAN model key '{model_key}' not supported.") + return False + + model_details = self.supported_models[model_key] + if is_wdn_companion: + if "wdn_filename" not in model_details or "wdn_file_url" not in model_details: + self.message_manager.add_error(f"WDN companion model details missing for '{model_key}'.") + return False + model_info_source = { + "filename": model_details["wdn_filename"], + "file_url": model_details["wdn_file_url"], + "hf_repo_id": None # Assuming direct URL for WDN for now + } + actual_model_filename = model_details["wdn_filename"] + else: # Regular ESRGAN model + model_info_source = model_details + actual_model_filename = model_details["filename"] + + model_path = current_model_dir / actual_model_filename + + if not model_path.exists(): + log_prefix = "WDN " if is_wdn_companion else "" + self.message_manager.add_message(f"{log_prefix}Model '{actual_model_filename}' not found. Downloading...") + try: + downloaded_successfully = False + if "file_url" in model_info_source and model_info_source["file_url"]: + urls_to_try = model_info_source["file_url"] + if isinstance(urls_to_try, str): urls_to_try = [urls_to_try] + + for url in urls_to_try: + self.message_manager.add_message(f"Attempting download from URL: {url}") + try: + load_file_from_url( + url=url, model_dir=str(current_model_dir), + progress=True, file_name=actual_model_filename + ) + if model_path.exists(): + downloaded_successfully = True + self.message_manager.add_success(f"{log_prefix}Model '{actual_model_filename}' downloaded from URL.") + break + except Exception as e_url: + self.message_manager.add_warning(f"Failed to download from {url}: {e_url}. 
Trying next source.") + continue + + if not downloaded_successfully and "hf_repo_id" in model_info_source and model_info_source["hf_repo_id"]: + self.message_manager.add_message(f"Attempting download from Hugging Face Hub: {model_info_source['hf_repo_id']}") + snapshot_download( + repo_id=model_info_source["hf_repo_id"], allow_patterns=[actual_model_filename], + local_dir=current_model_dir, local_dir_use_symlinks=False + ) + if model_path.exists(): + downloaded_successfully = True + self.message_manager.add_success(f"{log_prefix}Model '{actual_model_filename}' downloaded from Hugging Face Hub.") + + if not downloaded_successfully: + self.message_manager.add_error(f"All download attempts failed for '{actual_model_filename}'.") + return False + except Exception as e: + self.message_manager.add_error(f"Failed to download {log_prefix}model '{actual_model_filename}': {e}") + self.message_manager.add_error(traceback.format_exc()) + return False + return True + + def load_model(self, model_key: str, tile_size: int = 0, denoise_strength: float | None = None) -> RealESRGANer | None: + if model_key not in self.supported_models: + self.message_manager.add_error(f"ESRGAN model key '{model_key}' not supported.") + return None + + # Check if model is already loaded with the same configuration + current_config_signature = (tile_size, denoise_strength if model_key == "RealESR-general-x4v3" else None) + + if model_key in self.upsamplers: + existing_config = self.upsamplers[model_key] + existing_config_signature = ( + existing_config.get('tile_size', 0), + existing_config.get('denoise_strength') if model_key == "RealESR-general-x4v3" else None + ) + + if existing_config.get("upsampler") is not None and existing_config_signature == current_config_signature: + log_tile = f"Tile: {str(tile_size) if tile_size > 0 else 'Auto'}" + log_dni = f", DNI: {denoise_strength:.2f}" if denoise_strength is not None and model_key == "RealESR-general-x4v3" else "" + self.message_manager.add_message(f"ESRGAN model '{model_key}' ({log_tile}{log_dni}) already loaded.") + return existing_config["upsampler"] + elif existing_config.get("upsampler") is not None and existing_config_signature != current_config_signature: + self.message_manager.add_message( + f"ESRGAN model '{model_key}' config changed. Unloading to reload with new settings." + ) + self.unload_model(model_key) + + # Ensure main model is downloaded + if not self._ensure_model_downloaded(model_key): + return None + + model_info = self.supported_models[model_key] + model_path_for_upsampler = str(self.model_dir / model_info["filename"]) + dni_weight_for_upsampler = None + + log_msg_parts = [ + f"Loading ESRGAN model '{model_info['filename']}' (Key: {model_key}, Scale: {model_info['scale']}x", + f"Tile: {str(tile_size) if tile_size > 0 else 'Auto'}" + ] + + # Specific handling for RealESR-general-x4v3 with denoise_strength + if model_key == "RealESR-general-x4v3" and denoise_strength is not None and 0.0 <= denoise_strength < 1.0: + # Denoise strength 1.0 means use only the main model, so no DNI. + # Denoise strength < 0.0 is invalid. + if "wdn_filename" not in model_info or "wdn_file_url" not in model_info: + self.message_manager.add_error(f"WDN companion model details missing for '{model_key}'. Cannot apply denoise strength.") + return None # Or fallback to no DNI? For now, error. 
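+            # DNI (network interpolation) sketch for RealESR-general-x4v3: the upsampler receives both
+            # checkpoints and blends their weights as [denoise_strength, 1.0 - denoise_strength]
+            # (main model first, WDN companion second), e.g. denoise_strength=0.3 -> dni_weight=[0.3, 0.7].
+            # A strength of 1.0 skips DNI entirely and uses the main model on its own, which is why only
+            # 0.0 <= denoise_strength < 1.0 reaches this branch.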
+ + # Ensure WDN companion model is downloaded + if not self._ensure_model_downloaded(model_key, is_wdn_companion=True): + self.message_manager.add_error(f"Failed to download WDN companion for '{model_key}'. Cannot apply denoise strength.") + return None + + wdn_model_path_str = str(self.model_dir / model_info["wdn_filename"]) + model_path_for_upsampler = [model_path_for_upsampler, wdn_model_path_str] # Pass list of paths + dni_weight_for_upsampler = [denoise_strength, 1.0 - denoise_strength] # [main_model_strength, wdn_model_strength] + log_msg_parts.append(f"DNI Strength: {denoise_strength:.2f}") + + log_msg_parts.append(f") to device: {self.device}...") + self.message_manager.add_message(" ".join(log_msg_parts)) + + try: + model_params_with_correct_scale = model_info["model_params"].copy() + if "scale" in model_params_with_correct_scale: model_params_with_correct_scale["scale"] = model_info["scale"] + elif "upscale" in model_params_with_correct_scale: model_params_with_correct_scale["upscale"] = model_info["scale"] + else: model_params_with_correct_scale["scale"] = model_info["scale"] + + model_arch = model_info["model_class"](**model_params_with_correct_scale) + + gpu_id_for_realesrgan = self.device.index if self.device.type == 'cuda' and self.device.index is not None else None + use_half_precision = True if self.device.type == 'cuda' else False + + with warnings.catch_warnings(): + # Suppress the TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD warning from RealESRGANer/basicsr + warnings.filterwarnings( + "ignore", + category=UserWarning, + message=".*Environment variable TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD detected.*" + ) + # Suppress torchvision pretrained/weights warnings potentially triggered by basicsr + warnings.filterwarnings("ignore", category=UserWarning, message="The parameter 'pretrained' is deprecated.*") + warnings.filterwarnings("ignore", category=UserWarning, message="Arguments other than a weight enum or `None` for 'weights' are deprecated.*") + + upsampler = RealESRGANer( + scale=model_info["scale"], + model_path=model_path_for_upsampler, + dni_weight=dni_weight_for_upsampler, + model=model_arch, + tile=tile_size, + tile_pad=10, + pre_pad=0, + half=use_half_precision, + gpu_id=gpu_id_for_realesrgan + ) + + self.upsamplers[model_key] = { + "upsampler": upsampler, + "tile_size": tile_size, + "native_scale": model_info["scale"], + "denoise_strength": denoise_strength if model_key == "RealESR-general-x4v3" else None + } + self.message_manager.add_success(f"ESRGAN model '{model_info['filename']}' (Key: {model_key}) loaded successfully.") + return upsampler + except Exception as e: + self.message_manager.add_error(f"Failed to load ESRGAN model '{model_info['filename']}' (Key: {model_key}): {e}") + self.message_manager.add_error(traceback.format_exc()) + if model_key in self.upsamplers: del self.upsamplers[model_key] + return None + + def _load_face_enhancer(self, model_name="GFPGANv1.4.pth", bg_upsampler=None) -> bool: + if not GFPGAN_AVAILABLE: + self.message_manager.add_warning("GFPGAN library not available. Cannot load face enhancer.") + return False + if self.face_enhancer is not None: + # If bg_upsampler changed, we might need to re-init. For now, assume if loaded, it's fine or will be handled by caller. + if bg_upsampler is not None and hasattr(self.face_enhancer, 'bg_upsampler') and self.face_enhancer.bg_upsampler != bg_upsampler: + self.message_manager.add_message("GFPGAN face enhancer already loaded, but with a different background upsampler. 
Re-initializing GFPGAN...") + self._unload_face_enhancer() # Unload to reload with new bg_upsampler + else: + self.message_manager.add_message("GFPGAN face enhancer already loaded.") + return True + + + if not self._ensure_model_downloaded(model_key=model_name, is_gfpgan=True): + self.message_manager.add_error(f"Failed to download GFPGAN model '{model_name}'.") + return False + + gfpgan_model_path = str(self.gfpgan_model_dir / model_name) + self.message_manager.add_message(f"Loading GFPGAN face enhancer from {gfpgan_model_path}...") + try: + # --- ADDED: warnings.catch_warnings() context manager --- + with warnings.catch_warnings(): + # Suppress warnings from GFPGANer and its dependencies (facexlib) + warnings.filterwarnings( + "ignore", + category=UserWarning, + message=".*Environment variable TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD detected.*" + ) + warnings.filterwarnings("ignore", category=UserWarning, message="The parameter 'pretrained' is deprecated.*") + warnings.filterwarnings("ignore", category=UserWarning, message="Arguments other than a weight enum or `None` for 'weights' are deprecated.*") + + self.face_enhancer = GFPGANer( + model_path=gfpgan_model_path, + upscale=1, + arch='clean', + channel_multiplier=2, + bg_upsampler=bg_upsampler, + device=self.device + ) + self.message_manager.add_success("GFPGAN face enhancer loaded.") + return True + except Exception as e: + self.message_manager.add_error(f"Failed to load GFPGAN face enhancer: {e}") + self.message_manager.add_error(traceback.format_exc()) + self.face_enhancer = None + return False + + def _unload_face_enhancer(self): + if self.face_enhancer is not None: + self.message_manager.add_message("Unloading GFPGAN face enhancer...") + del self.face_enhancer + self.face_enhancer = None + gc.collect() + if self.device.type == 'cuda': + torch.cuda.empty_cache() + self.message_manager.add_success("GFPGAN face enhancer unloaded.") + else: + self.message_manager.add_message("GFPGAN face enhancer not loaded.") + + + def unload_model(self, model_key: str): + if model_key in self.upsamplers and self.upsamplers[model_key].get("upsampler") is not None: + config = self.upsamplers.pop(model_key) + upsampler_instance = config["upsampler"] + tile_s = config.get("tile_size", 0) + native_scale = config.get("native_scale", "N/A") # Get native_scale for logging + log_tile_size = str(tile_s) if tile_s > 0 else "Auto" + self.message_manager.add_message(f"Unloading ESRGAN model '{model_key}' (Scale: {native_scale}x, Tile: {log_tile_size})...") + + if self.face_enhancer and hasattr(self.face_enhancer, 'bg_upsampler') and self.face_enhancer.bg_upsampler == upsampler_instance: + self.message_manager.add_message("Unloading associated GFPGAN as its BG upsampler is being removed.") + self._unload_face_enhancer() + + del upsampler_instance + devicetorch.empty_cache(torch) + gc.collect() + self.message_manager.add_success(f"ESRGAN model '{model_key}' unloaded and memory cleared.") + else: + self.message_manager.add_message(f"ESRGAN model '{model_key}' not loaded, no need to unload.") + + def unload_all_models(self): + if not self.upsamplers and not self.face_enhancer: + self.message_manager.add_message("No ESRGAN or GFPGAN models currently loaded.") + return + + self.message_manager.add_message("Unloading all ESRGAN models...") + model_keys_to_unload = list(self.upsamplers.keys()) + for key in model_keys_to_unload: + if key in self.upsamplers: + config = self.upsamplers.pop(key) + upsampler_instance = config["upsampler"] + del upsampler_instance # type: ignore + + 
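+        # The GFPGAN face enhancer can hold one of these upsamplers as its bg_upsampler, so it is
+        # unloaded here as well rather than being left with a dangling reference.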
self._unload_face_enhancer() + + devicetorch.empty_cache(torch) + gc.collect() + self.message_manager.add_success("All ESRGAN and GFPGAN models unloaded and memory cleared.") + + def upscale_frame(self, frame_np_array, model_key: str, target_outscale_factor: float, enhance_face: bool = False): + """ + Upscales a single frame using the specified model and target output scale. + """ + config = self.upsamplers.get(model_key) + upsampler: RealESRGANer | None = None + current_tile_size = 0 + model_native_scale = 0 + + if config and config.get("upsampler"): + upsampler = config["upsampler"] # type: ignore + current_tile_size = config.get("tile_size", 0) # type: ignore + model_native_scale = config.get("native_scale", 0) # type: ignore + if model_native_scale == 0: + self.message_manager.add_error(f"Error: Native scale for model '{model_key}' is 0 or not found in config.") + return None + + if upsampler is None: + self.message_manager.add_warning( + f"ESRGAN model '{model_key}' not pre-loaded. Attempting to load now (with default Tile: Auto)..." + ) + tile_to_load_with = config.get("tile_size", 0) if config else 0 + upsampler = self.load_model(model_key, tile_size=tile_to_load_with) + if upsampler is None: + self.message_manager.add_error(f"Failed to auto-load ESRGAN model '{model_key}'. Cannot upscale.") + return None + + loaded_config = self.upsamplers.get(model_key) # Re-fetch config after load + if loaded_config: + current_tile_size = loaded_config.get("tile_size", 0) # type: ignore + model_native_scale = loaded_config.get("native_scale", 0) # type: ignore + if model_native_scale == 0: + self.message_manager.add_error(f"Error: Native scale for auto-loaded model '{model_key}' is 0.") + return None + else: + self.message_manager.add_error(f"Error: Config for auto-loaded model '{model_key}' not found.") + return None + + # Validate target_outscale_factor against model's native scale. + # Allow outscale from a small factor up to the model's native scale. + # You could allow slightly more (e.g., model_native_scale * 1.1) if you want to permit minor bicubic post-upscale. + # For now, strictly <= native_scale. + if not (0.25 <= target_outscale_factor <= model_native_scale): + self.message_manager.add_warning( + f"Target outscale factor {target_outscale_factor:.2f}x is outside the recommended range " + f"(0.25x to {model_native_scale:.2f}x) for model '{model_key}' (native {model_native_scale}x). " + f"Adjusting to model's native scale {model_native_scale:.2f}x." + ) + target_outscale_factor = float(model_native_scale) + + + if enhance_face: + if not self.face_enhancer or (hasattr(self.face_enhancer, 'bg_upsampler') and self.face_enhancer.bg_upsampler != upsampler): + self.message_manager.add_message("Face enhancement requested, loading/re-configuring GFPGAN...") + self._load_face_enhancer(bg_upsampler=upsampler) + + if not self.face_enhancer: + self.message_manager.add_warning("GFPGAN could not be loaded. Proceeding without face enhancement.") + enhance_face = False + + try: + img_bgr = frame_np_array[:, :, ::-1] + + outscale_for_enhance = float(target_outscale_factor) + + if enhance_face and self.face_enhancer: + if self.face_enhancer.upscale != 1: # Ensure GFPGAN is only cleaning, not upscaling itself in this pipeline path + self.message_manager.add_warning( + f"GFPGANer's internal upscale is {self.face_enhancer.upscale}, but for the 'Clean Face -> ESRGAN Upscale' pipeline, " + f"it should be 1. RealESRGAN will handle the main scaling to {target_outscale_factor:.2f}x." 
+ ) + + _, _, cleaned_img_bgr = self.face_enhancer.enhance(img_bgr, has_aligned=False, only_center_face=False, paste_back=True) + output_bgr, _ = upsampler.enhance(cleaned_img_bgr, outscale=outscale_for_enhance) + else: + output_bgr, _ = upsampler.enhance(img_bgr, outscale=outscale_for_enhance) + + output_rgb = output_bgr[:, :, ::-1] + return output_rgb + except Exception as e: + tile_size_msg_part = str(current_tile_size) if current_tile_size > 0 else 'Auto' + face_msg_part = " + Face Enhance" if enhance_face else "" + self.message_manager.add_error( + f"Error during ESRGAN frame upscaling (Model: {model_key}{face_msg_part}, " + f"Target Scale: {target_outscale_factor:.2f}x, Native: {model_native_scale}x, Tile: {tile_size_msg_part}): {e}" + ) + self.message_manager.add_error(traceback.format_exc()) + if "out of memory" in str(e).lower() and self.device.type == 'cuda': + self.message_manager.add_warning( + "CUDA OOM during upscaling. Emptying cache. " + f"Current model (Model: {model_key}, Tile: {tile_size_msg_part}) may need reloading. " + "Consider using a smaller tile size or a smaller input video if issues persist." + ) + devicetorch.empty_cache(torch) + return None \ No newline at end of file diff --git a/modules/toolbox/message_manager.py b/modules/toolbox/message_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..21f6a6d9ff9fd345f645c7c7f255477a80d60930 --- /dev/null +++ b/modules/toolbox/message_manager.py @@ -0,0 +1,78 @@ +from datetime import datetime +from typing import List, Optional +import queue +import threading + +class MessageManager: + def __init__(self, max_messages: int = 100): + self._messages: List[str] = [] + self._max_messages = max_messages + self._message_queue = queue.Queue() + self._lock = threading.Lock() + + # ANSI-style formatting for different message types + self._formats = { + "INFO": "ℹ️", # Info icon + "SUCCESS": "✅", # Checkmark + "WARNING": "⚠️", # Warning icon + "ERROR": "❌", # Error icon + } + + def add_message(self, message: str, message_type: str = "INFO") -> None: + """Add a new message with minimal timestamp and icon.""" + # Only show hours:minutes for timestamps + timestamp = datetime.now().strftime("%H:%M") + icon = self._formats.get(message_type, "•") + + # Format filename paths to be more readable + if "Processing file" in message or "Created batch folder" in message: + message = self._format_path(message) + + formatted_message = f"{icon} {message}" + + with self._lock: + self._messages.append(formatted_message) + if len(self._messages) > self._max_messages: + self._messages.pop(0) + + def _format_path(self, message: str) -> str: + """Format file paths to be more concise and readable.""" + if "GRADIO_TEMP_DIR" in message: + # Extract just the filename from temp path + filename = message.split("\\")[-1] + return message.split(":")[0] + ": " + filename + elif "batch_" in message: + # Shorten batch folder path + return message.replace("../outputs/", "") + return message + + def add_success(self, message: str) -> None: + """Add a success message.""" + self.add_message(message, "SUCCESS") + + def add_warning(self, message: str) -> None: + """Add a warning message.""" + self.add_message(message, "WARNING") + + def add_error(self, message: str) -> None: + """Add an error message.""" + self.add_message(message, "ERROR") + + def get_messages(self) -> str: + """Get all messages as a single string with spacing between different types.""" + with self._lock: + # Add a blank line between different message types for readability 
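+            # e.g. two INFO messages followed by an ERROR message come back as:
+            #   "ℹ️ Loading model...\nℹ️ Model loaded\n\n❌ Upscale failed"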
+ formatted = [] + last_type = None + for msg in self._messages: + current_type = next((t for t in self._formats if self._formats[t] in msg), None) + if last_type and current_type != last_type: + formatted.append("") # Add spacing between different types + formatted.append(msg) + last_type = current_type + return "\n".join(formatted) + + def clear(self) -> None: + """Clear all messages.""" + with self._lock: + self._messages.clear() diff --git a/modules/toolbox/rife_core.py b/modules/toolbox/rife_core.py new file mode 100644 index 0000000000000000000000000000000000000000..72714aa64f563b2fc4535a28545ed883df3ad767 --- /dev/null +++ b/modules/toolbox/rife_core.py @@ -0,0 +1,133 @@ +import torch +import numpy as np +from torchvision.transforms.functional import to_tensor, to_pil_image +from pathlib import Path +import os +import gc +from huggingface_hub import snapshot_download + +from .RIFE.RIFE_HDv3 import Model as RIFEBaseModel +from .message_manager import MessageManager +import devicetorch + +# Get the directory of the current script (rife_core.py) +_MODULE_DIR = Path(os.path.dirname(os.path.abspath(__file__))) # __file__ gives path to current script + +# MODEL_RIFE_PATH = "model_rife" # OLD - this is relative to CWD +MODEL_RIFE_PATH = _MODULE_DIR / "model_rife" # NEW - relative to this script's location +RIFE_MODEL_FILENAME = "flownet.pkl" + +class RIFEHandler: + def __init__(self, message_manager: MessageManager = None): + self.message_manager = message_manager if message_manager else MessageManager() + self.model_dir = Path(MODEL_RIFE_PATH) # Path() constructor handles Path objects correctly + self.model_file_path = self.model_dir / RIFE_MODEL_FILENAME + self.rife_model = None + + def _log(self, message, level="INFO"): + # Helper for logging using the MessageManager + if level.upper() == "ERROR": + self.message_manager.add_error(f"RIFEHandler: {message}") + elif level.upper() == "WARNING": + self.message_manager.add_warning(f"RIFEHandler: {message}") + else: + self.message_manager.add_message(f"RIFEHandler: {message}") + + def _ensure_model_downloaded_and_loaded(self) -> bool: + if self.rife_model is not None: + self._log("RIFE model already loaded.") + return True + + # self.model_dir is now an absolute path + if not self.model_dir.exists(): + os.makedirs(self.model_dir, exist_ok=True) + self._log(f"Created RIFE model directory: {self.model_dir}") + + # self.model_file_path is now an absolute path + if not self.model_file_path.exists(): + self._log("RIFE model weights not found. Downloading...") + try: + snapshot_download( + repo_id="AlexWortega/RIFE", + allow_patterns=["*.pkl", "*.pth"], + local_dir=self.model_dir, # Pass the absolute path + local_dir_use_symlinks=False + ) + if self.model_file_path.exists(): + self._log("RIFE model weights downloaded successfully.") + else: + self._log(f"RIFE model download completed, but {RIFE_MODEL_FILENAME} not found in {self.model_dir}. Check allow_patterns and repo structure.", "ERROR") + return False + except Exception as e: + self._log(f"Failed to download RIFE model weights: {e}", "ERROR") + return False + + if not self.model_file_path.exists(): + self._log(f"RIFE model file {self.model_file_path} does not exist. 
Cannot load model.", "ERROR") + return False + + try: + self._log(f"Loading RIFE model from {self.model_dir}...") # self.model_dir is absolute + current_device_str = devicetorch.get(torch) + self.rife_model = RIFEBaseModel(local_rank=-1) + + self.rife_model.load_model(str(self.model_dir), -1) # str(self.model_dir) is absolute + self.rife_model.eval() + self._log(f"RIFE model loaded successfully to its determined device.") + return True + except Exception as e: + self._log(f"Failed to load RIFE model: {e}", "ERROR") + import traceback + self._log(f"Traceback: {traceback.format_exc()}", "ERROR") + self.rife_model = None + return False + + def unload_model(self): + if self.rife_model is not None: + self._log("Unloading RIFE model...") + del self.rife_model + self.rife_model = None + devicetorch.empty_cache(torch) + gc.collect() + self._log("RIFE model unloaded and memory cleared.") + else: + self._log("RIFE model not loaded, no need to unload.") + + def interpolate_between_frames(self, frame1_np: np.ndarray, frame2_np: np.ndarray) -> np.ndarray | None: + if self.rife_model is None: + self._log("RIFE model not loaded. Call _ensure_model_downloaded_and_loaded() before interpolation.", "ERROR") + return None + + try: + img0_tensor = to_tensor(frame1_np).unsqueeze(0) + img1_tensor = to_tensor(frame2_np).unsqueeze(0) + + img0 = devicetorch.to(torch, img0_tensor) + img1 = devicetorch.to(torch, img1_tensor) + + + required_multiple = 32 + h_orig, w_orig = img0.shape[2], img0.shape[3] + pad_h = (required_multiple - h_orig % required_multiple) % required_multiple + pad_w = (required_multiple - w_orig % required_multiple) % required_multiple + + if pad_h > 0 or pad_w > 0: + img0 = torch.nn.functional.pad(img0, (0, pad_w, 0, pad_h), mode='replicate') + img1 = torch.nn.functional.pad(img1, (0, pad_w, 0, pad_h), mode='replicate') + + with torch.no_grad(): + middle_frame_tensor = self.rife_model.inference(img0, img1, scale=1.0) + + if pad_h > 0 or pad_w > 0: + middle_frame_tensor = middle_frame_tensor[:, :, :h_orig, :w_orig] + + middle_frame_pil = to_pil_image(middle_frame_tensor.squeeze(0).cpu()) + return np.array(middle_frame_pil) + + except Exception as e: + self._log(f"Error during RIFE frame interpolation: {e}", "ERROR") + import traceback + self._log(f"Traceback: {traceback.format_exc()}", "ERROR") + if "out of memory" in str(e).lower(): + devicetorch.empty_cache(torch) + return None \ No newline at end of file diff --git a/modules/toolbox/setup_ffmpeg.py b/modules/toolbox/setup_ffmpeg.py new file mode 100644 index 0000000000000000000000000000000000000000..08c6502c668cb490a07b4665d96155edbe85ee0b --- /dev/null +++ b/modules/toolbox/setup_ffmpeg.py @@ -0,0 +1,123 @@ +import os +import sys +import requests +import tarfile +import zipfile +import shutil +from tqdm import tqdm + +def setup_ffmpeg(): + """Download and set up a cross-platform, full build of FFmpeg and FFprobe.""" + # Get the directory of the current script, which is now inside 'modules/toolbox/' + script_dir = os.path.dirname(os.path.abspath(__file__)) + # The 'bin' directory is created directly inside this script's directory. 
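# Editor's note (illustrative, not part of the original diff): a worked example for the
# frame-padding logic in rife_core.RIFEHandler.interpolate_between_frames above. The handler
# pads H and W up to multiples of 32 before RIFE inference, so a 1280x720 frame becomes 1280x736:
#   pad_h = (32 - 720 % 32) % 32    # -> 16 extra rows (replicate-padded on the bottom edge)
#   pad_w = (32 - 1280 % 32) % 32   # -> 0, width is already a multiple of 32
# and the interpolated result is cropped back to the original 720x1280 afterwards.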
+ bin_dir = os.path.join(script_dir, 'bin') + os.makedirs(bin_dir, exist_ok=True) + + # --- Platform-specific configuration --- + if sys.platform == "win32": + platform = "windows" + ffmpeg_name = 'ffmpeg.exe' + ffprobe_name = 'ffprobe.exe' + download_url = "https://github.com/GyanD/codexffmpeg/releases/download/7.0/ffmpeg-7.0-full_build.zip" + archive_name = 'ffmpeg.zip' + # For Windows, the path is static and predictable + path_in_archive_to_bin = 'ffmpeg-7.0-full_build/bin' + elif sys.platform.startswith("linux"): + platform = "linux" + ffmpeg_name = 'ffmpeg' + ffprobe_name = 'ffprobe' + # This link always points to the latest static build + download_url = "https://johnvansickle.com/ffmpeg/releases/ffmpeg-release-amd64-static.tar.xz" + archive_name = 'ffmpeg.tar.xz' + # --- CHANGE: We no longer hardcode the path_in_archive_to_bin for Linux --- + else: + print(f"Unsupported platform: {sys.platform}") + print("Please download FFmpeg manually and place ffmpeg/ffprobe in the 'bin' directory.") + return + + ffmpeg_path = os.path.join(bin_dir, ffmpeg_name) + ffprobe_path = os.path.join(bin_dir, ffprobe_name) + + if os.path.exists(ffmpeg_path) and os.path.exists(ffprobe_path): + print(f"FFmpeg is already set up in: {bin_dir}") + return + + archive_path = os.path.join(bin_dir, archive_name) + + try: + print(f"FFmpeg not found. Downloading and setting up for {platform}...") + download_ffmpeg(download_url, archive_path) + + print("Download complete. Installing...") + temp_extract_dir = os.path.join(bin_dir, 'temp_ffmpeg_extract') + os.makedirs(temp_extract_dir, exist_ok=True) + + if archive_name.endswith('.zip'): + with zipfile.ZipFile(archive_path, 'r') as archive: + archive.extractall(path=temp_extract_dir) + elif archive_name.endswith('.tar.xz'): + with tarfile.open(archive_path, 'r:xz') as archive: + archive.extractall(path=temp_extract_dir) + + # --- ROBUSTNESS CHANGE FOR LINUX --- + # Dynamically find the path to the binaries instead of hardcoding it. + if platform == "linux": + # Find the single subdirectory inside the extraction folder + subdirs = [d for d in os.listdir(temp_extract_dir) if os.path.isdir(os.path.join(temp_extract_dir, d))] + if len(subdirs) != 1: + raise Exception(f"Expected one subdirectory in Linux FFmpeg archive, but found {len(subdirs)}.") + # The binaries are directly inside this discovered folder + source_bin_dir = os.path.join(temp_extract_dir, subdirs[0]) + else: # For Windows, we use the predefined path + source_bin_dir = os.path.join(temp_extract_dir, path_in_archive_to_bin) + + # Find the executables in the now correctly identified source folder and copy them + source_ffmpeg_path = os.path.join(source_bin_dir, ffmpeg_name) + source_ffprobe_path = os.path.join(source_bin_dir, ffprobe_name) + + if not os.path.exists(source_ffmpeg_path) or not os.path.exists(source_ffprobe_path): + raise FileNotFoundError(f"Could not find ffmpeg/ffprobe in the expected location: {source_bin_dir}") + + shutil.copy(source_ffmpeg_path, ffmpeg_path) + shutil.copy(source_ffprobe_path, ffprobe_path) + + if platform == "linux": + os.chmod(ffmpeg_path, 0o755) + os.chmod(ffprobe_path, 0o755) + + print(f"✅ FFmpeg setup complete. 
Binaries are in: {bin_dir}") + + except Exception as e: + print(f"\n❌ Error setting up FFmpeg: {e}") + import traceback + traceback.print_exc() + print("\nPlease download FFmpeg manually and place the 'ffmpeg' and 'ffprobe' executables in the 'bin' directory.") + print(f"Download for Windows: https://www.gyan.dev/ffmpeg/builds/") + print(f"Download for Linux: https://johnvansickle.com/ffmpeg/") + finally: + # Clean up + if os.path.exists(archive_path): + os.remove(archive_path) + if 'temp_extract_dir' in locals() and os.path.exists(temp_extract_dir): + shutil.rmtree(temp_extract_dir) + +def download_ffmpeg(url, destination): + """Download a file with progress bar""" + response = requests.get(url, stream=True) + response.raise_for_status() # Raise an exception for bad status codes + total_size = int(response.headers.get('content-length', 0)) + block_size = 1024 + + # The calling function now handles the initial "Downloading..." message. + # This keeps the download function focused on its single responsibility. + with open(destination, 'wb') as file, tqdm( + desc=os.path.basename(destination), # Use basename for a cleaner progress bar + total=total_size, + unit='iB', + unit_scale=True, + unit_divisor=1024, + ) as bar: + for data in response.iter_content(block_size): + size = file.write(data) + bar.update(size) \ No newline at end of file diff --git a/modules/toolbox/system_monitor.py b/modules/toolbox/system_monitor.py new file mode 100644 index 0000000000000000000000000000000000000000..1d3e8c1b4526ac0f1489617695addb079fe35103 --- /dev/null +++ b/modules/toolbox/system_monitor.py @@ -0,0 +1,312 @@ +import platform +import subprocess +import os +import psutil +import torch +from typing import Optional, Dict, Tuple, Union + +NumericValue = Union[int, float] +MetricsDict = Dict[str, NumericValue] + +class SystemMonitor: + @staticmethod + def get_nvidia_gpu_info() -> Tuple[str, MetricsDict, Optional[str]]: + """Get NVIDIA GPU information and metrics for GPU 0.""" + metrics = {} + gpu_name_from_torch = "NVIDIA GPU (name unavailable)" + warning_message = None # To indicate if nvidia-smi failed and we're using PyTorch fallback + + try: + gpu_name_from_torch = f"{torch.cuda.get_device_name(0)}" + except Exception: + # If even the name fails, nvidia-smi is highly likely to fail too. + # Prepare basic PyTorch metrics as the ultimate fallback. + metrics = { + 'memory_used_gb': torch.cuda.memory_allocated(0) / 1024**3 if torch.cuda.is_available() else 0, + 'memory_total_gb': torch.cuda.get_device_properties(0).total_memory / 1024**3 if torch.cuda.is_available() else 0, + # Add placeholders for other metrics to maintain UI symmetry if nvidia-smi fails + 'memory_reserved_gb': 0.0, + 'temperature': 0.0, + 'utilization': 0.0 + } + warning_message = "Could not get GPU name via PyTorch. nvidia-smi likely to fail or has failed. Displaying basic PyTorch memory (application-specific)." 
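# Editor's note (not part of the original diff): in this PyTorch-only fallback,
# torch.cuda.memory_allocated(0) counts only tensor memory allocated by the current process,
# so 'memory_used_gb' can read near 0.0 even when other applications occupy most of the VRAM;
# 'memory_total_gb' is still the card's full capacity from get_device_properties(0).total_memory.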
+ return gpu_name_from_torch, metrics, warning_message + + # Query for memory.used, memory.total, memory.reserved, temperature.gpu, utilization.gpu + nvidia_smi_common_args = [ + 'nvidia-smi', + '--query-gpu=memory.used,memory.total,memory.reserved,temperature.gpu,utilization.gpu', + '--format=csv,nounits,noheader' + ] + + smi_output_str = None + try: + # Attempt 1: Query specific GPU 0 + smi_output_str = subprocess.check_output( + nvidia_smi_common_args + ['--id=0'], + encoding='utf-8', timeout=1.5, stderr=subprocess.PIPE + ) + except (subprocess.SubprocessError, FileNotFoundError, ValueError) as e1: + # print(f"nvidia-smi with --id=0 failed: {type(e1).__name__}. Trying general query.") + try: + # Attempt 2: Query all GPUs and parse the first line + smi_output_str = subprocess.check_output( + nvidia_smi_common_args, # Without --id=0 + encoding='utf-8', timeout=1.5, stderr=subprocess.PIPE + ) + if smi_output_str: + smi_output_str = smi_output_str.strip().split('\n')[0] # Take the first line + except (subprocess.SubprocessError, FileNotFoundError, ValueError) as e2: + # print(f"nvidia-smi (general query) also failed: {type(e2).__name__}. Falling back to torch.cuda.") + # Fallback to basic CUDA info from PyTorch, plus placeholders for UI + metrics = { + 'memory_used_gb': torch.cuda.memory_allocated(0) / 1024**3 if torch.cuda.is_available() else 0, + 'memory_total_gb': torch.cuda.get_device_properties(0).total_memory / 1024**3 if torch.cuda.is_available() else 0, + 'memory_reserved_gb': 0.0, # Placeholder + 'temperature': 0.0, # Placeholder + 'utilization': 0.0 # Placeholder + } + warning_message = "nvidia-smi failed. GPU Memory Used is PyTorch specific (not total). Other GPU stats unavailable." + return gpu_name_from_torch, metrics, warning_message + + if smi_output_str: + parts = smi_output_str.strip().split(',') + if len(parts) == 5: # memory.used, memory.total, memory.reserved, temp, util + memory_used_mib, memory_total_mib, memory_reserved_mib, temp, util = map(float, parts) + metrics = { + 'memory_used_gb': memory_used_mib / 1024, + 'memory_total_gb': memory_total_mib / 1024, + 'memory_reserved_gb': memory_reserved_mib / 1024, # This is from nvidia-smi's memory.reserved + 'temperature': temp, + 'utilization': util + } + else: + # print(f"Unexpected nvidia-smi output format: {smi_output_str}. Parts: {len(parts)}") + warning_message = "nvidia-smi output format unexpected. Some GPU stats may be missing or inaccurate." + # Fallback with placeholders to maintain UI structure + metrics = { + 'memory_used_gb': torch.cuda.memory_allocated(0) / 1024**3 if torch.cuda.is_available() else 0, # PyTorch fallback + 'memory_total_gb': torch.cuda.get_device_properties(0).total_memory / 1024**3 if torch.cuda.is_available() else 0, # PyTorch fallback + 'memory_reserved_gb': 0.0, + 'temperature': 0.0, + 'utilization': 0.0 + } + if len(parts) >= 2: # Try to parse what we can if format is just partially off + try: metrics['memory_used_gb'] = float(parts[0]) / 1024 + except: pass + try: metrics['memory_total_gb'] = float(parts[1]) / 1024 + except: pass + else: # Should have been caught by try-except, but as a final safety + metrics = { + 'memory_used_gb': 0.0, 'memory_total_gb': 0.0, 'memory_reserved_gb': 0.0, + 'temperature': 0.0, 'utilization': 0.0 + } + warning_message = "Failed to get any output from nvidia-smi." 
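# Editor's note (illustrative, not part of the original diff): with --format=csv,nounits,noheader,
# the query above yields one comma-separated line per GPU, e.g. a hypothetical
# "2048, 16384, 128, 45, 17" (memory.used MiB, memory.total MiB, memory.reserved MiB,
# temperature.gpu in C, utilization.gpu in %), which the parsing above maps to:
#   {'memory_used_gb': 2.0, 'memory_total_gb': 16.0, 'memory_reserved_gb': 0.125,
#    'temperature': 45.0, 'utilization': 17.0}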
+ + + return gpu_name_from_torch, metrics, warning_message + + @staticmethod + def get_mac_gpu_info() -> Tuple[str, MetricsDict, Optional[str]]: # Add warning return for consistency + """Get Apple Silicon GPU information without requiring sudo.""" + metrics = {} + warning_message = None + try: + memory = psutil.virtual_memory() + metrics = { + 'memory_total_gb': memory.total / (1024**3), + 'memory_used_gb': memory.used / (1024**3), # This is system RAM, reported as "Unified Memory" + 'utilization': psutil.cpu_percent(), # Still CPU usage as proxy + # Placeholders for Mac to match NVIDIA's output structure for UI symmetry + 'memory_reserved_gb': 0.0, # N/A for unified memory in this context + 'temperature': 0.0 # Not easily available without sudo + } + if metrics['utilization'] == psutil.cpu_percent(): # Check if it's actually CPU util + warning_message = "Mac GPU Load is proxied by CPU Usage." + + except Exception as e: + # print(f"Error getting Mac info: {e}") + metrics = { + 'memory_total_gb': 0.0, 'memory_used_gb': 0.0, 'utilization': 0.0, + 'memory_reserved_gb': 0.0, 'temperature': 0.0 + } + warning_message = "Could not retrieve Mac system info." + + return "Apple Silicon GPU", metrics, warning_message # Changed name for clarity + + @staticmethod + def get_amd_gpu_info() -> Tuple[str, MetricsDict, Optional[str]]: # Add warning return + """Get AMD GPU information.""" + metrics = { # Initialize with placeholders for all expected keys for UI symmetry + 'memory_used_gb': 0.0, + 'memory_total_gb': 0.0, + 'memory_reserved_gb': 0.0, # Typically N/A or not reported by rocm-smi in a 'reserved' sense + 'temperature': 0.0, + 'utilization': 0.0 + } + warning_message = None + source = "unknown" + + try: + # Try rocm-smi first + try: + result = subprocess.check_output(['rocm-smi', '--showmeminfo', 'vram', '--showtemp', '--showuse'], encoding='utf-8', timeout=1.5, stderr=subprocess.PIPE) + # Example rocm-smi output parsing (highly dependent on actual output format) + # This needs to be robust or use a more structured output format like --json if rocm-smi supports it + # For VRAM Used/Total: + # GPU[0] VRAM Usage: 2020M/16368M + # For Temp: + # GPU[0] Temperature: 34c + # For Util: + # GPU[0] GPU Use: 0% + lines = result.strip().split('\n') + for line in lines: + if line.startswith("GPU[0]"): # Assuming card 0 + if "VRAM Usage:" in line: + mem_parts = line.split("VRAM Usage:")[1].strip().split('/') + metrics['memory_used_gb'] = float(mem_parts[0].replace('M', '')) / 1024 + metrics['memory_total_gb'] = float(mem_parts[1].replace('M', '')) / 1024 + source = "rocm-smi" + elif "Temperature:" in line: + metrics['temperature'] = float(line.split("Temperature:")[1].strip().replace('c', '')) + source = "rocm-smi" + elif "GPU Use:" in line: + metrics['utilization'] = float(line.split("GPU Use:")[1].strip().replace('%', '')) + source = "rocm-smi" + if source != "rocm-smi": # if parsing failed or fields were missing + warning_message = "rocm-smi ran but output parsing failed." + except (subprocess.SubprocessError, FileNotFoundError, ValueError) as e_rocm: + # print(f"rocm-smi failed: {e_rocm}. Trying sysfs.") + warning_message = "rocm-smi not found or failed. 
" + # Try sysfs as fallback on Linux + if platform.system() == "Linux": + base_path = "/sys/class/drm/card0/device" # This assumes card0 + sysfs_found_any = False + try: + with open(f"{base_path}/hwmon/hwmon0/temp1_input") as f: # Check for specific hwmon index + metrics['temperature'] = float(f.read().strip()) / 1000 + sysfs_found_any = True + except (FileNotFoundError, PermissionError, ValueError): pass # Ignore if specific file not found + + try: + with open(f"{base_path}/mem_info_vram_total") as f: + metrics['memory_total_gb'] = int(f.read().strip()) / (1024**3) # Bytes to GiB + with open(f"{base_path}/mem_info_vram_used") as f: + metrics['memory_used_gb'] = int(f.read().strip()) / (1024**3) # Bytes to GiB + sysfs_found_any = True + except (FileNotFoundError, PermissionError, ValueError): pass + + try: + with open(f"{base_path}/gpu_busy_percent") as f: + metrics['utilization'] = float(f.read().strip()) + sysfs_found_any = True + except (FileNotFoundError, PermissionError, ValueError): pass + + if sysfs_found_any: + source = "sysfs" + warning_message += "Using sysfs (may be incomplete)." + else: + warning_message += "sysfs also failed or provided no data." + else: + warning_message += "Not on Linux, no sysfs fallback." + + except Exception as e_amd_main: # Catch-all for unforeseen issues in AMD block + # print(f"General error in get_amd_gpu_info: {e_amd_main}") + warning_message = (warning_message or "") + " Unexpected error in AMD GPU info gathering." + + return f"AMD GPU ({source})", metrics, warning_message + + @staticmethod + def is_amd_gpu() -> bool: # No changes needed here + try: + # Check for rocm-smi first as it's more definitive + rocm_smi_exists = False + try: + subprocess.check_call(['rocm-smi', '-h'], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, timeout=0.5) + rocm_smi_exists = True + except (subprocess.SubprocessError, FileNotFoundError): + pass # rocm-smi not found or errored + + if rocm_smi_exists: + return True + + # Fallback to sysfs check if on Linux + if platform.system() == "Linux" and os.path.exists('/sys/class/drm/card0/device/vendor'): + with open('/sys/class/drm/card0/device/vendor', 'r') as f: + return '0x1002' in f.read() # AMD's PCI vendor ID + return False + except: + return False + + @classmethod + def get_system_info(cls) -> str: + """Get detailed system status with support for different GPU types.""" + gpu_name_display: Optional[str] = None + metrics: MetricsDict = {} + gpu_warning: Optional[str] = None + + try: + # Determine GPU type and get metrics + if torch.cuda.is_available(): # Implies NVIDIA usually + gpu_name_display, metrics, gpu_warning = cls.get_nvidia_gpu_info() + elif platform.system() == "Darwin" and platform.processor() == "arm": # Apple Silicon + gpu_name_display, metrics, gpu_warning = cls.get_mac_gpu_info() + elif cls.is_amd_gpu(): # Check for AMD (works on Linux, might need refinement for Windows if not using PyTorch ROCm) + gpu_name_display, metrics, gpu_warning = cls.get_amd_gpu_info() + else: # No specific GPU detected by these primary checks + # Could add a PyTorch ROCm check here if desired for AMD on Windows/Linux without rocm-smi + # if hasattr(torch, 'rocm_is_available') and torch.rocm_is_available(): + # gpu_name_display = "AMD GPU (via PyTorch ROCm)" + # metrics = { ... basic torch.rocm metrics ... 
} + pass + + + # Format GPU info based on available metrics + if gpu_name_display: + gpu_info_lines = [f"🎮 GPU: {gpu_name_display}"] + + # Standard memory reporting + if 'memory_used_gb' in metrics and 'memory_total_gb' in metrics: + mem_label = "GPU Memory" + if platform.system() == "Darwin" and platform.processor() == "arm": + mem_label = "Unified Memory" # For Apple Silicon + + gpu_info_lines.append( + f"📊 {mem_label}: {metrics.get('memory_used_gb', 0.0):.1f}GB / {metrics.get('memory_total_gb', 0.0):.1f}GB" + ) + + # VRAM Reserved (NVIDIA specific from nvidia-smi, or placeholder) + # if 'memory_reserved_gb' in metrics and torch.cuda.is_available() and not (platform.system() == "Darwin"): # Show for NVIDIA, not Mac + # gpu_info_lines.append(f"💾 VRAM Reserved: {metrics.get('memory_reserved_gb', 0.0):.1f}GB") + + if 'temperature' in metrics and metrics.get('temperature', 0.0) > 0: # Only show if temp is valid + gpu_info_lines.append(f"🌡️ GPU Temp: {metrics.get('temperature', 0.0):.0f}°C") + + if 'utilization' in metrics: + gpu_info_lines.append(f"⚡ GPU Load: {metrics.get('utilization', 0.0):.0f}%") + + if gpu_warning: # Display any warning from the GPU info functions + gpu_info_lines.append(f"⚠️ {gpu_warning}") + + gpu_section = "\n".join(gpu_info_lines) + "\n" + else: + gpu_section = "🎮 GPU: No dedicated GPU detected or supported\n" + + # Get CPU info + cpu_count = psutil.cpu_count(logical=False) # Physical cores + cpu_threads = psutil.cpu_count(logical=True) # Logical processors + cpu_info = f"💻 CPU: {cpu_count or 'N/A'} Cores, {cpu_threads or 'N/A'} Threads\n" + cpu_usage = f"⚡ CPU Usage: {psutil.cpu_percent()}%\n" + + # Get RAM info + ram = psutil.virtual_memory() + ram_used_gb = ram.used / (1024**3) + ram_total_gb = ram.total / (1024**3) + ram_info = f"🎯 System RAM: {ram_used_gb:.1f}GB / {ram_total_gb:.1f}GB ({ram.percent}%)" + + return f"{gpu_section}{cpu_info}{cpu_usage}{ram_info}" + + except Exception as e: + # print(f"Overall error in get_system_info: {e}") + # import traceback; print(traceback.format_exc()) + return f"Error collecting system info: {str(e)}" \ No newline at end of file diff --git a/modules/toolbox/toolbox_processor.py b/modules/toolbox/toolbox_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..8eb0684c354b3793596861e580e7ae2d0ff6f33e --- /dev/null +++ b/modules/toolbox/toolbox_processor.py @@ -0,0 +1,1831 @@ +import os +import gc +import sys +import re +import numpy as np +import torch +import imageio +import gradio as gr +import subprocess +import devicetorch +import json +import math +import shutil +import traceback + +from datetime import datetime +from pathlib import Path +from huggingface_hub import snapshot_download +from tqdm.auto import tqdm + +from torchvision.transforms.functional import to_tensor, to_pil_image + +from modules.toolbox.rife_core import RIFEHandler +from modules.toolbox.esrgan_core import ESRGANUpscaler +from modules.toolbox.message_manager import MessageManager + +device_name_str = devicetorch.get(torch) + +VIDEO_QUALITY = 8 # Used by imageio.mimwrite quality/quantizer + +class VideoProcessor: + def __init__(self, message_manager: MessageManager, settings): + self.message_manager = message_manager + self.rife_handler = RIFEHandler(message_manager) + self.device_obj = torch.device(device_name_str) # Store device_obj + self.esrgan_upscaler = ESRGANUpscaler(message_manager, self.device_obj) + self.settings = settings + self.project_root = Path(__file__).resolve().parents[2] + + # FFmpeg/FFprobe paths and status 
flags + self.ffmpeg_exe = None + self.ffprobe_exe = None + self.has_ffmpeg = False + self.has_ffprobe = False + self.ffmpeg_source = None + self.ffprobe_source = None + + self._tb_initialize_ffmpeg() # Finds executables and sets flags + + studio_output_dir = Path(self.settings.get("output_dir")) + self.postprocessed_output_root_dir = studio_output_dir / "postprocessed_output" + self._base_temp_output_dir = self.postprocessed_output_root_dir / "temp_processing" + self._base_permanent_save_dir = self.postprocessed_output_root_dir / "saved_videos" + + self.toolbox_video_output_dir = self._base_temp_output_dir + self.toolbox_permanent_save_dir = self._base_permanent_save_dir + + os.makedirs(self.postprocessed_output_root_dir, exist_ok=True) + os.makedirs(self._base_temp_output_dir, exist_ok=True) + os.makedirs(self._base_permanent_save_dir, exist_ok=True) + + # Note: Renamed to a more generic name as it holds more than just extracted frames now + self.frames_io_dir = self.postprocessed_output_root_dir / "frames" + self.extracted_frames_target_path = self.frames_io_dir / "extracted_frames" + os.makedirs(self.extracted_frames_target_path, exist_ok=True) + self.reassembled_video_target_path = self.frames_io_dir / "reassembled_videos" + os.makedirs(self.reassembled_video_target_path, exist_ok=True) + + # --- NEW BATCH PROCESSING FUNCTION --- + def tb_process_video_batch(self, video_paths: list, pipeline_config: dict, progress=gr.Progress()): + """ + Processes a batch of videos according to a defined pipeline of operations. + - Batch jobs are ALWAYS saved to a new, unique, timestamped folder in 'saved_videos'. + - Single video pipeline jobs respect the 'Autosave' setting for the FINAL output only. + - Intermediate files are always created in and cleaned from the temp directory. + - The very last successfully processed video (from single or batch) is kept for the UI. + """ + original_autosave_state = self.settings.get("toolbox_autosave_enabled", True) + is_batch_job = len(video_paths) > 1 + batch_output_dir = None + last_successful_video_path_for_ui = None + + try: + if is_batch_job: + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + batch_output_dir = self._base_permanent_save_dir / f"batch_process_{timestamp}" + os.makedirs(batch_output_dir, exist_ok=True) + self.message_manager.add_message(f"Batch outputs will be saved to: {batch_output_dir}", "SUCCESS") + + self.set_autosave_mode(False, silent=True) + + operations = pipeline_config.get("operations", []) + if not operations: + self.message_manager.add_warning("No operations were selected for the pipeline. Nothing to do.") + return None + + op_names = [op['name'].replace('_', ' ').title() for op in operations] + self.message_manager.add_message(f"🚀 Starting pipeline for {len(video_paths)} videos. 
Pipeline: {' -> '.join(op_names)}") + + total_videos = len(video_paths) + + for i, original_video_path in enumerate(video_paths): + progress(i / total_videos, desc=f"Video {i+1}/{total_videos}: {os.path.basename(original_video_path)}") + self.message_manager.add_message(f"\n--- Processing Video {i+1}/{total_videos}: {os.path.basename(original_video_path)} ---", "INFO") + + current_video_path = original_video_path + video_failed = False + path_to_clean = None + + for op_config in operations: + op_name = op_config["name"] + op_params = op_config["params"] + + self.message_manager.add_message(f" -> Step: Applying {op_name.replace('_', ' ')}...") + output_path = None + try: + if op_name == "upscale": output_path = self.tb_upscale_video(current_video_path, **op_params, progress=progress) + elif op_name == "frame_adjust": output_path = self.tb_process_frames(current_video_path, **op_params, progress=progress) + elif op_name == "filters": output_path = self.tb_apply_filters(current_video_path, **op_params, progress=progress) + elif op_name == "loop": output_path = self.tb_create_loop(current_video_path, **op_params, progress=progress) + elif op_name == "export": output_path = self.tb_export_video(current_video_path, **op_params, progress=progress) + + if output_path and os.path.exists(output_path): + self.message_manager.add_success(f" -> Step '{op_name}' completed. Output: {os.path.basename(output_path)}") + if path_to_clean and os.path.exists(path_to_clean): + try: + os.remove(path_to_clean) + self.message_manager.add_message(f" -> Cleaned intermediate file: {os.path.basename(path_to_clean)}", "DEBUG") + except OSError as e: + self.message_manager.add_warning(f"Could not clean intermediate file {path_to_clean}: {e}") + + current_video_path = output_path + path_to_clean = output_path + else: + video_failed = True; break + except Exception as e: + video_failed = True + self.message_manager.add_error(f"An unexpected error occurred during step '{op_name}': {e}") + self.message_manager.add_error(traceback.format_exc()) + break + + if not video_failed: + final_temp_path = current_video_path + is_last_video_in_batch = (i == total_videos - 1) + + if is_batch_job: + # For batch jobs, copy the final output to the permanent batch folder. + final_dest_path = batch_output_dir / os.path.basename(final_temp_path) + shutil.copy2(final_temp_path, final_dest_path) # Use copy2 to keep temp file for UI + self.message_manager.add_success(f"--- Successfully processed. Final output saved to: {final_dest_path} ---") + + if is_last_video_in_batch: + # This is the very last video of the whole batch, keep its temp path for the UI player. + last_successful_video_path_for_ui = final_temp_path + else: + # This is a completed video but not the last one in the batch, so we can clean its temp file. + try: os.remove(final_temp_path) + except OSError: pass + else: # Single video pipeline run. + if original_autosave_state: + final_dest_path = self._base_permanent_save_dir / os.path.basename(final_temp_path) + shutil.move(final_temp_path, final_dest_path) # Move, as it's saved permanently + self.message_manager.add_success(f"--- Successfully processed. Final output saved to: {final_dest_path} ---") + last_successful_video_path_for_ui = final_dest_path + else: + # Autosave off, so the final file remains in the temp folder for the UI. + self.message_manager.add_success(f"--- Successfully processed. 
Final output is in temp folder: {final_temp_path} ---") + last_successful_video_path_for_ui = final_temp_path + else: + self.message_manager.add_warning(f"--- Processing failed for {os.path.basename(original_video_path)} ---") + if path_to_clean and os.path.exists(path_to_clean): + try: os.remove(path_to_clean) + except OSError as e: self.message_manager.add_warning(f"Could not clean failed intermediate file {path_to_clean}: {e}") + + gc.collect() + devicetorch.empty_cache(torch) + + progress(1.0, desc="Pipeline complete.") + self.message_manager.add_message("\n✅ Pipeline processing finished.", "SUCCESS") + return last_successful_video_path_for_ui + + finally: + # Restore the user's original autosave setting silently. + self.set_autosave_mode(original_autosave_state, silent=True) + + def _tb_initialize_ffmpeg(self): + """Finds FFmpeg/FFprobe and sets status flags and sources.""" + ( + self.ffmpeg_exe, + self.ffmpeg_source, + self.ffprobe_exe, + self.ffprobe_source, + ) = self._tb_find_ffmpeg_executables() + + self.has_ffmpeg = bool(self.ffmpeg_exe) + self.has_ffprobe = bool(self.ffprobe_exe) + + self._report_ffmpeg_status() + + def _tb_find_ffmpeg_executables(self): + """ + Finds ffmpeg and ffprobe with a priority system. + Priority: 1. Bundled -> 2. System PATH -> 3. imageio-ffmpeg + Returns (ffmpeg_path, ffmpeg_source, ffprobe_path, ffprobe_source) + """ + ffmpeg_path, ffprobe_path = None, None + ffmpeg_source, ffprobe_source = None, None + ffmpeg_name = "ffmpeg.exe" if sys.platform == "win32" else "ffmpeg" + ffprobe_name = "ffprobe.exe" if sys.platform == "win32" else "ffprobe" + + try: + script_dir = os.path.dirname(os.path.abspath(__file__)) + bin_dir = os.path.join(script_dir, 'bin') + bundled_ffmpeg = os.path.join(bin_dir, ffmpeg_name) + bundled_ffprobe = os.path.join(bin_dir, ffprobe_name) + if os.path.exists(bundled_ffmpeg): + ffmpeg_path = bundled_ffmpeg + ffmpeg_source = "Bundled" + if os.path.exists(bundled_ffprobe): + ffprobe_path = bundled_ffprobe + ffprobe_source = "Bundled" + except Exception: + pass + + if not ffmpeg_path: + path_from_env = shutil.which(ffmpeg_name) + if path_from_env: + ffmpeg_path = path_from_env + ffmpeg_source = "System PATH" + if not ffprobe_path: + path_from_env = shutil.which(ffprobe_name) + if path_from_env: + ffprobe_path = path_from_env + ffprobe_source = "System PATH" + + if not ffmpeg_path: + try: + imageio_ffmpeg_exe = imageio.plugins.ffmpeg.get_exe() + if os.path.isfile(imageio_ffmpeg_exe): + ffmpeg_path = imageio_ffmpeg_exe + ffmpeg_source = "imageio-ffmpeg" + except Exception: + pass + + return ffmpeg_path, ffmpeg_source, ffprobe_path, ffprobe_source + + def _report_ffmpeg_status(self): + """Provides a summary of FFmpeg/FFprobe status based on what was found.""" + if self.ffmpeg_source == "Bundled" and self.ffprobe_source == "Bundled": + self.message_manager.add_message(f"Bundled FFmpeg found: {self.ffmpeg_exe}", "SUCCESS") + self.message_manager.add_message(f"Bundled FFprobe found: {self.ffprobe_exe}", "SUCCESS") + self.message_manager.add_message("All video and audio features are enabled.", "SUCCESS") + return + + if self.has_ffmpeg: + self.message_manager.add_message(f"FFmpeg found via {self.ffmpeg_source}: {self.ffmpeg_exe}", "SUCCESS") + else: + self.message_manager.add_error("Critical: FFmpeg executable could not be found. Most video processing operations will fail. 
Please try running the setup script.") + + if self.has_ffprobe: + self.message_manager.add_message(f"FFprobe found via {self.ffprobe_source}: {self.ffprobe_exe}", "SUCCESS") + else: + self.message_manager.add_warning("FFprobe not found. Audio detection and full video analysis will be limited.") + if self.ffmpeg_source != "Bundled": + self.message_manager.add_warning("For full functionality, please run the 'setup_ffmpeg.py' script.") + + def tb_get_frames_from_folder(self, folder_name: str) -> list: + """ + Gets a sorted list of image file paths from a given folder name. + This is the backend for the "Load Frames to Studio" button. + """ + if not folder_name: + return [] + + full_folder_path = os.path.join(self.extracted_frames_target_path, folder_name) + if not os.path.isdir(full_folder_path): + self.message_manager.add_error(f"Cannot load frames: Directory not found at {full_folder_path}") + return [] + + frame_files = [] + try: + for filename in os.listdir(full_folder_path): + if filename.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.webp')): + frame_files.append(os.path.join(full_folder_path, filename)) + + # Natural sort to handle frame_0, frame_1, ... frame_10 correctly + def natural_sort_key(s): + return [int(text) if text.isdigit() else text.lower() for text in re.split('([0-9]+)', s)] + + frame_files.sort(key=natural_sort_key) + return frame_files + + except Exception as e: + self.message_manager.add_error(f"Error reading frames from '{folder_name}': {e}") + return [] + + def tb_delete_single_frame(self, frame_path_to_delete: str) -> str: + """Deletes a single frame file from the disk, logs the action, and returns a status message.""" + if not frame_path_to_delete or not isinstance(frame_path_to_delete, str): + # This message is returned to the app's info box + msg_for_infobox = "Error: Invalid frame path provided for deletion." + # The message manager gets a more detailed log entry + self.message_manager.add_error("Could not delete frame: Invalid path provided to processor.") + return msg_for_infobox + + try: + filename = os.path.basename(frame_path_to_delete) + if os.path.isfile(frame_path_to_delete): + os.remove(frame_path_to_delete) + # Add a success message to the main log + self.message_manager.add_success(f"Deleted frame: {filename}") + # Return a concise status for the info box + return f"✅ Deleted: {filename}" + else: + self.message_manager.add_error(f"Could not delete frame. 
File not found: {frame_path_to_delete}") + return f"Error: Frame not found" + except OSError as e: + self.message_manager.add_error(f"Error deleting frame {filename}: {e}") + return f"Error deleting frame: {e}" + + def tb_save_single_frame(self, source_frame_path: str) -> str | None: + """Saves a copy of a single frame to the permanent 'saved_videos' directory.""" + if not source_frame_path or not os.path.isfile(source_frame_path): + self.message_manager.add_error("Source frame to save does not exist or is invalid.") + return None + + try: + source_path_obj = Path(source_frame_path) + parent_folder_name = source_path_obj.parent.name + frame_filename = source_path_obj.name + + # Create a descriptive filename to avoid collisions + timestamp = datetime.now().strftime("%y%m%d_%H%M%S") + dest_filename = f"saved_frame_{parent_folder_name}_{timestamp}_{frame_filename}" + + destination_path = os.path.join(self.toolbox_permanent_save_dir, dest_filename) + + os.makedirs(self.toolbox_permanent_save_dir, exist_ok=True) + shutil.copy2(source_frame_path, destination_path) + + self.message_manager.add_success(f"Saved frame to permanent storage: {destination_path}") + return destination_path + except Exception as e: + self.message_manager.add_error(f"Error saving frame to permanent storage: {e}") + self.message_manager.add_error(traceback.format_exc()) + return None + + + def set_autosave_mode(self, autosave_enabled: bool, silent: bool = False): + if autosave_enabled: + self.toolbox_video_output_dir = self._base_permanent_save_dir + if not silent: + self.message_manager.add_message("Autosave ENABLED: Processed videos will be saved to the permanent folder.", "SUCCESS") + else: + self.toolbox_video_output_dir = self._base_temp_output_dir + if not silent: + self.message_manager.add_message("Autosave DISABLED: Processed videos will be saved to the temporary folder.", "INFO") + + def _tb_log_ffmpeg_error(self, e_ffmpeg: subprocess.CalledProcessError, operation_description: str): + self.message_manager.add_error(f"FFmpeg failed during {operation_description}.") + ffmpeg_stderr_str = e_ffmpeg.stderr.strip() if e_ffmpeg.stderr else "" + ffmpeg_stdout_str = e_ffmpeg.stdout.strip() if e_ffmpeg.stdout else "" + + details_log = [] + if ffmpeg_stderr_str: details_log.append(f"FFmpeg Stderr: {ffmpeg_stderr_str}") + if ffmpeg_stdout_str: details_log.append(f"FFmpeg Stdout: {ffmpeg_stdout_str}") + + if details_log: + self.message_manager.add_message("FFmpeg Output:\n" + "\n".join(details_log), "INFO") + else: + self.message_manager.add_message(f"No specific output from FFmpeg. (Return code: {e_ffmpeg.returncode}, Command: '{e_ffmpeg.cmd}')", "INFO") + + def _tb_get_video_frame_count(self, video_path: str) -> int | None: + """ + Uses ffprobe to get an accurate frame count by requesting JSON output for robust parsing. + Tries a fast metadata read first, then falls back to a slower but more accurate full stream count. 
+ """ + if not self.has_ffprobe: + self.message_manager.add_message("Cannot get frame count: ffprobe not found.", "DEBUG") + return None + + # --- Tier 1: Fast metadata read using JSON output --- + try: + ffprobe_cmd_fast = [ + self.ffprobe_exe, + "-v", "error", + "-probesize", "5M", # Input option + "-i", video_path, # The input file + "-select_streams", "v:0", + "-show_entries", "stream=nb_frames", + "-of", "json" + ] + result = subprocess.run(ffprobe_cmd_fast, capture_output=True, text=True, check=True, errors='ignore') + data = json.loads(result.stdout) + frame_count_str = data.get("streams", [{}])[0].get("nb_frames", "N/A") + + if frame_count_str.isdigit() and int(frame_count_str) > 0: + self.message_manager.add_message(f"Frame count from metadata: {frame_count_str}", "DEBUG") + return int(frame_count_str) + else: + self.message_manager.add_warning(f"Fast metadata frame count was invalid ('{frame_count_str}'). Falling back to full count.") + except Exception as e: + self.message_manager.add_warning(f"Fast metadata read failed: {e}. Falling back to full count.") + + # --- Tier 2: Slow, accurate full-stream count using JSON output --- + try: + self.message_manager.add_message("Performing full, accurate frame count with ffprobe (this may take a moment)...", "INFO") + ffprobe_cmd_accurate = [ + self.ffprobe_exe, + "-v", "error", + "-probesize", "5M", # Input option + "-i", video_path, # The input file + "-count_frames", + "-select_streams", "v:0", + "-show_entries", "stream=nb_read_frames", + "-of", "json" + ] + result = subprocess.run(ffprobe_cmd_accurate, capture_output=True, text=True, check=True, errors='ignore') + data = json.loads(result.stdout) + frame_count_str = data.get("streams", [{}])[0].get("nb_read_frames", "N/A") + + if frame_count_str.isdigit() and int(frame_count_str) > 0: + self.message_manager.add_message(f"Accurate frame count from full scan: {frame_count_str}", "DEBUG") + return int(frame_count_str) + else: + self.message_manager.add_error(f"Full ffprobe scan returned invalid frame count: '{frame_count_str}'.") + return None + except Exception as e: + self.message_manager.add_error(f"Critical error during full ffprobe frame count: {e}") + self.message_manager.add_error(traceback.format_exc()) + return None + + def tb_extract_frames(self, video_path, extraction_rate, progress=gr.Progress()): + if video_path is None: + self.message_manager.add_warning("No input video for frame extraction.") + return None + if not isinstance(extraction_rate, int) or extraction_rate < 1: + self.message_manager.add_error("Extraction rate must be a positive integer (1 for all frames, N for every Nth frame).") + return None + + resolved_video_path = str(Path(video_path).resolve()) + output_folder_name = self._tb_generate_output_folder_path( + resolved_video_path, + suffix=f"extracted_every_{extraction_rate}") + os.makedirs(output_folder_name, exist_ok=True) + + self.message_manager.add_message( + f"Starting frame extraction for {os.path.basename(resolved_video_path)} (every {extraction_rate} frame(s))." + ) + self.message_manager.add_message(f"Outputting to: {output_folder_name}") + + reader = None + try: + total_frames = self._tb_get_video_frame_count(resolved_video_path) + + # If we know the total frames, we can provide accurate progress. + if total_frames: + progress(0, desc=f"Extracting 0 / {total_frames} frames...") + else: + self.message_manager.add_warning("Could not determine total frames. 
Progress will be indeterminate.") + progress(0, desc="Extracting frames (total unknown)...") + + reader = imageio.get_reader(resolved_video_path) + extracted_count = 0 + + # --- MANUAL PROGRESS LOOP --- + for i, frame in enumerate(reader): + # Update progress manually every few frames to avoid overwhelming the UI + if total_frames and i % 10 == 0: + progress(i / total_frames, desc=f"Extracting {i} / {total_frames} frames...") + + if i % extraction_rate == 0: + frame_filename = f"frame_{extracted_count:06d}.png" + output_frame_path = os.path.join(output_folder_name, frame_filename) + imageio.imwrite(output_frame_path, frame, format='PNG') + extracted_count += 1 + + # --- FINAL UPDATE --- + progress(1.0, desc="Extraction complete.") + self.message_manager.add_success(f"Successfully extracted {extracted_count} frames to: {output_folder_name}") + return output_folder_name + + except Exception as e: + self.message_manager.add_error(f"Error during frame extraction: {e}") + self.message_manager.add_error(traceback.format_exc()) + progress(1.0, desc="Error during extraction.") + return None + finally: + if reader: + reader.close() + gc.collect() + + def tb_get_extracted_frame_folders(self) -> list: + if not os.path.exists(self.extracted_frames_target_path): + self.message_manager.add_warning(f"Extracted frames directory not found: {self.extracted_frames_target_path}") + return [] + try: + folders = [ + d for d in os.listdir(self.extracted_frames_target_path) + if os.path.isdir(os.path.join(self.extracted_frames_target_path, d)) + ] + folders.sort() + return folders + except Exception as e: + self.message_manager.add_error(f"Error scanning for extracted frame folders: {e}") + return [] + + def tb_delete_extracted_frames_folder(self, folder_name_to_delete: str) -> bool: + if not folder_name_to_delete: + self.message_manager.add_warning("No folder selected for deletion.") + return False + + folder_path_to_delete = os.path.join(self.extracted_frames_target_path, folder_name_to_delete) + + if not os.path.exists(folder_path_to_delete) or not os.path.isdir(folder_path_to_delete): + self.message_manager.add_error(f"Folder not found or is not a directory: {folder_path_to_delete}") + return False + + try: + shutil.rmtree(folder_path_to_delete) + self.message_manager.add_success(f"Successfully deleted folder: {folder_name_to_delete}") + return True + except Exception as e: + self.message_manager.add_error(f"Error deleting folder '{folder_name_to_delete}': {e}") + self.message_manager.add_error(traceback.format_exc()) + return False + + def tb_reassemble_frames_to_video(self, frames_source, output_fps, output_base_name_override=None, progress=gr.Progress()): + if not frames_source: + self.message_manager.add_warning("No frames source (folder or files) provided for reassembly.") + return None + + try: + output_fps = int(output_fps) + if output_fps <= 0: + self.message_manager.add_error("Output FPS must be a positive number.") + return None + except ValueError: + self.message_manager.add_error("Invalid FPS value for reassembly.") + return None + + self.message_manager.add_message(f"Starting frame reassembly to video at {output_fps} FPS.") + + frame_info_list = [] + frames_data_prepared = False + + try: + # This logic now primarily handles a directory path string + if isinstance(frames_source, str) and os.path.isdir(frames_source): + self.message_manager.add_message(f"Processing frames from directory: {frames_source}") + # Use our existing function to get a sorted list of frame paths + sorted_frame_paths = 
self.tb_get_frames_from_folder(os.path.basename(frames_source)) + for full_path in sorted_frame_paths: + frame_info_list.append({ + 'original_like_filename': os.path.basename(full_path), + 'temp_path': full_path + }) + else: + self.message_manager.add_error("Invalid frames_source type for reassembly. Expected a directory path.") + return None + + if not frame_info_list: + self.message_manager.add_warning("No valid image files found in the provided source to reassemble.") + return None + + self.message_manager.add_message(f"Found {len(frame_info_list)} frames for reassembly.") + + output_file_basename = "reassembled_video" + if output_base_name_override and isinstance(output_base_name_override, str) and output_base_name_override.strip(): + sanitized_name = "".join(c if c.isalnum() or c in (' ', '_', '-') else '_' for c in output_base_name_override.strip()) + output_file_basename = Path(sanitized_name).stem + if not output_file_basename: output_file_basename = "reassembled_video" + self.message_manager.add_message(f"Using custom output video base name: {output_file_basename}") + + output_video_path = self._tb_generate_output_path( + input_material_name=output_file_basename, + suffix=f"{output_fps}fps_reassembled", + target_dir=self.reassembled_video_target_path, + ext=".mp4" + ) + + frames_data = [] + frames_data_prepared = True + + self.message_manager.add_message("Reading frame images (in sorted order)...") + + frame_iterator = frame_info_list + if frame_info_list and progress is not None and hasattr(progress, 'tqdm'): + frame_iterator = progress.tqdm(frame_info_list, desc="Reading frames") + + for frame_info in frame_iterator: + frame_actual_path = frame_info['temp_path'] + filename_for_log = frame_info['original_like_filename'] + try: + if not filename_for_log.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.webp')): + self.message_manager.add_warning(f"Skipping non-standard image file: {filename_for_log}.") + continue + frames_data.append(imageio.imread(frame_actual_path)) + except Exception as e_read_frame: + self.message_manager.add_warning(f"Could not read frame ({filename_for_log}): {e_read_frame}. Skipping.") + + if not frames_data: + self.message_manager.add_error("No valid frames could be successfully read for reassembly.") + return None + + self.message_manager.add_message(f"Writing {len(frames_data)} frames to video: {output_video_path}") + imageio.mimwrite(output_video_path, frames_data, fps=output_fps, quality=VIDEO_QUALITY, macro_block_size=None) + + self.message_manager.add_success(f"Successfully reassembled {len(frames_data)} frames into: {output_video_path}") + return output_video_path + + except Exception as e: + self.message_manager.add_error(f"Error during frame reassembly: {e}") + self.message_manager.add_error(traceback.format_exc()) + if "Could not find a backend" in str(e) or "No such file or directory: 'ffmpeg'" in str(e).lower(): + self.message_manager.add_error("This might indicate an issue with FFmpeg backend for imageio. 
Ensure 'imageio-ffmpeg' is installed or FFmpeg is in PATH.") + return None + finally: + if frames_data_prepared and 'frames_data' in locals(): + del frames_data + gc.collect() + + def _tb_get_video_duration(self, video_path: str) -> str | None: + """Uses ffprobe to get the duration of a video file as a string.""" + if not self.has_ffprobe: + return None + try: + ffprobe_cmd = [ + self.ffprobe_exe, "-v", "error", "-show_entries", + "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", video_path + ] + result = subprocess.run(ffprobe_cmd, capture_output=True, text=True, check=True, errors='ignore') + return result.stdout.strip() + except Exception: + return None + + def tb_join_videos(self, video_paths: list, output_base_name_override=None, progress=gr.Progress()): + if not video_paths or len(video_paths) < 2: + self.message_manager.add_warning("Please select at least two videos to join.") + return None + if not self.has_ffmpeg: + self.message_manager.add_error("FFmpeg is required for joining videos. This operation cannot proceed.") + return None + + self.message_manager.add_message(f"🚀 Starting video join process for {len(video_paths)} videos...") + progress(0.1, desc="Analyzing input videos...") + + # --- 1. STANDARDIZE DIMENSIONS --- + # Get dimensions of the first video to use as the standard for all others. + first_video_dims = self._tb_get_video_dimensions(video_paths[0]) + if not all(first_video_dims): + self.message_manager.add_error("Could not determine dimensions of the first video. Cannot proceed.") + return None + target_w, target_h = first_video_dims + self.message_manager.add_message(f"Standardizing all videos to {target_w}x{target_h} for joining.") + + # --- 2. BUILD THE FFMPEG COMMAND --- + ffmpeg_cmd = [self.ffmpeg_exe, "-y", "-loglevel", "error"] + filter_complex_parts = [] + video_stream_labels = [] + audio_stream_labels = [] + + # Loop through each input video to prepare its streams. + for i, path in enumerate(video_paths): + ffmpeg_cmd.extend(["-i", str(Path(path).resolve())]) + + # --- VIDEO STREAM PREPARATION --- + video_label = f"v{i}" + # Scale video, pad to fit, set aspect ratio, and ensure standard pixel format. + filter_complex_parts.append( + f"[{i}:v:0]scale={target_w}:{target_h}:force_original_aspect_ratio=decrease,pad={target_w}:{target_h}:-1:-1:color=black,setsar=1,format=yuv420p[{video_label}]" + ) + video_stream_labels.append(f"[{video_label}]") + + # --- AUDIO STREAM PREPARATION --- + audio_label = f"a{i}" + if self._tb_has_audio_stream(path): + # If audio exists, standardize it to a common format. + filter_complex_parts.append( + f"[{i}:a:0]aformat=sample_fmts=fltp:sample_rates=44100:channel_layouts=stereo[{audio_label}]" + ) + else: + # If no audio, get the video's duration first. + duration = self._tb_get_video_duration(path) + if duration: + # Then, generate a silent audio track of that exact duration. + self.message_manager.add_message(f"'{Path(path).name}' has no audio. Generating silent track of {float(duration):.2f}s.", "INFO") + filter_complex_parts.append( + f"anullsrc=channel_layout=stereo:sample_rate=44100,atrim=duration={duration}[{audio_label}]" + ) + else: + # If we can't get duration, we can't create a silent track, so we must skip it. + self.message_manager.add_warning(f"Could not get duration for '{Path(path).name}' to generate silent audio. This track's audio will be skipped.") + continue + audio_stream_labels.append(f"[{audio_label}]") + + # --- 3. 
CONCATENATE THE STREAMS --- + # Join all the prepared video and audio streams together into final output streams. + filter_complex_parts.append(f"{''.join(video_stream_labels)}concat=n={len(video_paths)}:v=1:a=0[outv]") + + # Only add the audio concat filter if we successfully prepared audio streams. + if audio_stream_labels: + filter_complex_parts.append(f"{''.join(audio_stream_labels)}concat=n={len(audio_stream_labels)}:v=0:a=1[outa]") + + final_filter_complex = ";".join(filter_complex_parts) + ffmpeg_cmd.extend(["-filter_complex", final_filter_complex]) + + # --- 4. MAP AND ENCODE THE FINAL VIDEO --- + # Map the final concatenated video stream to the output. + ffmpeg_cmd.extend(["-map", "[outv]"]) + # If we have a final audio stream, map that too. + if audio_stream_labels: + ffmpeg_cmd.extend(["-map", "[outa]"]) + + # Determine the output filename. + if output_base_name_override and isinstance(output_base_name_override, str) and output_base_name_override.strip(): + sanitized_name = "".join(c for c in output_base_name_override.strip() if c.isalnum() or c in (' ', '_', '-')).strip() + base_name_to_use = Path(sanitized_name).stem if sanitized_name else Path(video_paths[0]).stem + else: + base_name_to_use = Path(video_paths[0]).stem + + output_path = self._tb_generate_output_path( + base_name_to_use, + suffix=f"joined_{len(video_paths)}_videos", + target_dir=self.toolbox_video_output_dir + ) + + # Set standard, high-compatibility encoding options. + ffmpeg_cmd.extend([ + "-c:v", "libx264", "-preset", "medium", "-crf", "20", + "-c:a", "aac", "-b:a", "192k", output_path + ]) + + # --- 5. EXECUTE THE COMMAND --- + try: + self.message_manager.add_message("Running FFmpeg to join videos. This may take a while...") + progress(0.5, desc=f"Joining {len(video_paths)} videos...") + + subprocess.run(ffmpeg_cmd, check=True, capture_output=True, text=True, errors='ignore') + + progress(1.0, desc="Join complete.") + self.message_manager.add_success(f"✅ Videos successfully joined! Output: {output_path}") + return output_path + + except subprocess.CalledProcessError as e_join: + self._tb_log_ffmpeg_error(e_join, "video joining") + return None + except Exception as e: + self.message_manager.add_error(f"An unexpected error occurred during video joining: {e}") + self.message_manager.add_error(traceback.format_exc()) + return None + finally: + gc.collect() + + def _tb_clean_filename(self, filename): + filename = re.sub(r'_\d{6}_\d{6}', '', filename) # Example timestamp pattern + filename = re.sub(r'_\d{6}_\d{4}', '', filename) # Another example + return filename.strip('_') + + def tb_export_video(self, video_path: str, export_format: str, quality_slider: int, max_width: int, + output_base_name_override=None, progress=gr.Progress()): + if not video_path: + self.message_manager.add_warning("No input video for exporting.") + return None + if not self.has_ffmpeg: + self.message_manager.add_error("FFmpeg is required for exporting. This operation cannot proceed.") + return None + + self.message_manager.add_message(f"🚀 Starting export to {export_format.upper()}...") + progress(0, desc=f"Preparing to export to {export_format.upper()}...") + + resolved_video_path = str(Path(video_path).resolve()) + + # --- Base FFmpeg Command --- + ffmpeg_cmd = [self.ffmpeg_exe, "-y", "-loglevel", "error", "-i", resolved_video_path] + + # --- Video Filters (Resizing) --- + vf_parts = [] + # The scale filter resizes while maintaining aspect ratio. '-2' ensures the height is an even number for codec compatibility. 
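# Editor's note (illustrative, not part of the original diff): for example, with max_width=1024
# the filter "scale=1024:-2" turns a 1920x1080 source into 1024x576; the -2 keeps the aspect
# ratio while rounding the computed height to an even number, as required by yuv420p/libx264.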
+ vf_parts.append(f"scale={max_width}:-2") + + # --- Format-Specific Settings --- + ext = f".{export_format.lower()}" + + if export_format == "MP4": + # CRF (Constant Rate Factor) is the quality setting for x264. Lower is higher quality. + # We map our 0-100 slider to a good CRF range (e.g., 28 (low) to 18 (high)). + crf_value = int(28 - (quality_slider / 100) * 10) + self.message_manager.add_message(f"Exporting MP4 with CRF: {crf_value} (Quality: {quality_slider}%)") + ffmpeg_cmd.extend(["-c:v", "libx264", "-crf", str(crf_value), "-preset", "medium"]) + ffmpeg_cmd.extend(["-c:a", "aac", "-b:a", "128k"]) # Keep audio, but compress it + + elif export_format == "WebM": + # Similar to MP4, but for the VP9 codec. A good CRF range is ~35 (low) to 25 (high). + crf_value = int(35 - (quality_slider / 100) * 10) + self.message_manager.add_message(f"Exporting WebM with CRF: {crf_value} (Quality: {quality_slider}%)") + ffmpeg_cmd.extend(["-c:v", "libvpx-vp9", "-crf", str(crf_value), "-b:v", "0"]) + ffmpeg_cmd.extend(["-c:a", "libopus", "-b:a", "96k"]) # Use Opus for WebM audio + + elif export_format == "GIF": + # High-quality GIF generation is a two-pass process. + self.message_manager.add_message("Generating high-quality GIF (2-pass)...") + # Pass 1: Generate a color palette. + palette_path = os.path.join(self._base_temp_output_dir, f"palette_{Path(video_path).stem}.png") + vf_parts.append("split[s0][s1];[s0]palettegen[p];[s1][p]paletteuse") + ffmpeg_cmd.extend(["-an"]) # No audio in GIFs + + if vf_parts: + ffmpeg_cmd.extend(["-vf", ",".join(vf_parts)]) + + # --- Output Path --- + if output_base_name_override and isinstance(output_base_name_override, str) and output_base_name_override.strip(): + sanitized_name = "".join(c for c in output_base_name_override.strip() if c.isalnum() or c in (' ', '_', '-')).strip() + base_name_to_use = Path(sanitized_name).stem if sanitized_name else Path(video_path).stem + else: + base_name_to_use = Path(video_path).stem + + # --- SPECIAL HANDLING FOR GIF OUTPUT PATH --- + if export_format == "GIF": + # GIFs are always saved to the permanent directory to avoid being lost + # by Gradio's re-encoding for the video player preview. + target_dir_for_export = self._base_permanent_save_dir + self.message_manager.add_message("GIF export detected. Output will be forced to the permanent 'saved_videos' folder, ignoring Autosave setting.", "INFO") + else: + # For MP4/WebM, respect the current autosave setting + target_dir_for_export = self.toolbox_video_output_dir + + output_path = self._tb_generate_output_path( + base_name_to_use, + suffix=f"exported_{quality_slider}q_{max_width}w", + target_dir=target_dir_for_export, + ext=ext + ) + ffmpeg_cmd.append(output_path) + + # --- Execute --- + try: + progress(0.5, desc=f"Encoding to {export_format.upper()}...") + subprocess.run(ffmpeg_cmd, check=True, capture_output=True, text=True, errors='ignore') + progress(1.0, desc="Export complete!") + + # Add specific messaging for GIF vs other formats + if export_format == "GIF": + self.message_manager.add_success(f"✅ GIF successfully created and saved to: {output_path}") + self.message_manager.add_warning("⚠️ Note: The video player shows a re-encoded MP4 for preview. Your original GIF is in the output folder.") + else: + self.message_manager.add_success(f"✅ Successfully exported to {export_format.upper()}! 
Output: {output_path}") + + return output_path + + except subprocess.CalledProcessError as e: + self._tb_log_ffmpeg_error(e, f"export to {export_format.upper()}") + return None + except Exception as e: + self.message_manager.add_error(f"An unexpected error occurred during export: {e}") + self.message_manager.add_error(traceback.format_exc()) + return None + finally: + gc.collect() + + def _tb_generate_output_path(self, input_material_name, suffix, target_dir, ext=".mp4"): + base_name = Path(input_material_name).stem + if not base_name: base_name = "untitled_video" + cleaned_name = self._tb_clean_filename(base_name) + timestamp = datetime.now().strftime("%y%m%d_%H%M%S") + filename = f"{cleaned_name}_{suffix}_{timestamp}{ext}" + return os.path.join(target_dir, filename) + + def _tb_generate_output_folder_path(self, input_video_path, suffix): + base_name = Path(input_video_path).stem + if not base_name: base_name = "untitled_video_frames" + cleaned_name = self._tb_clean_filename(base_name) + timestamp = datetime.now().strftime("%y%m%d_%H%M%S") + folder_name = f"{cleaned_name}_{suffix}_{timestamp}" + return os.path.join(self.extracted_frames_target_path, folder_name) + + def tb_copy_video_to_permanent_storage(self, temp_video_path): + if not temp_video_path or not os.path.exists(temp_video_path): + self.message_manager.add_error("No video file provided or file does not exist to save.") + return temp_video_path + + try: + video_filename = Path(temp_video_path).name + permanent_video_path = os.path.join(self.toolbox_permanent_save_dir, video_filename) + os.makedirs(self.toolbox_permanent_save_dir, exist_ok=True) + self.message_manager.add_message(f"Copying '{video_filename}' to permanent storage: '{permanent_video_path}'") + shutil.copy2(temp_video_path, permanent_video_path) + self.message_manager.add_success(f"Video saved to: {permanent_video_path}") + return permanent_video_path + except Exception as e: + self.message_manager.add_error(f"Error saving video to permanent storage: {e}") + self.message_manager.add_error(traceback.format_exc()) + return temp_video_path + + def tb_analyze_video_input(self, video_path): + if video_path is None: + self.message_manager.add_warning("No video provided for analysis.") + return "Please upload a video." 
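# Editor's note — a minimal, hedged sketch of the quality-slider-to-CRF mapping used in
# tb_export_video above (the function name here is illustrative, not part of the app):
# the 0-100 slider is linearly mapped onto a CRF window where a LOWER CRF means HIGHER quality.

def quality_to_crf(quality_slider: int, best_crf: int, worst_crf: int) -> int:
    # e.g. x264 (MP4): best_crf=18, worst_crf=28; VP9 (WebM): best_crf=25, worst_crf=35
    return int(worst_crf - (quality_slider / 100) * (worst_crf - best_crf))

# Worked values: MP4 at quality 0/50/85/100 -> CRF 28/23/19/18;
#                WebM at quality 0/50/85/100 -> CRF 35/30/26/25.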
+ + resolved_video_path = str(Path(video_path).resolve()) + analysis_report_lines = [] # Use a list to build the report string + + file_size_bytes = 0 + file_size_display = "N/A" + try: + if os.path.exists(resolved_video_path): + file_size_bytes = os.path.getsize(resolved_video_path) + if file_size_bytes < 1024: + file_size_display = f"{file_size_bytes} B" + elif file_size_bytes < 1024**2: + file_size_display = f"{file_size_bytes/1024:.2f} KB" + elif file_size_bytes < 1024**3: + file_size_display = f"{file_size_bytes/1024**2:.2f} MB" + else: + file_size_display = f"{file_size_bytes/1024**3:.2f} GB" + except Exception as e: + self.message_manager.add_warning(f"Could not get file size: {e}") + + # Variables to hold parsed info, initialized to defaults + video_width, video_height = 0, 0 + num_frames_value = None # For the upscale warning + duration_display, fps_display, resolution_display, nframes_display, has_audio_str = "N/A", "N/A", "N/A", "N/A", "No" + analysis_source = "imageio" # Default analysis source + + if self.has_ffprobe: + self.message_manager.add_message(f"Analyzing video with ffprobe: {os.path.basename(video_path)}") + try: + probe_cmd = [ + self.ffprobe_exe, "-v", "error", "-show_format", "-show_streams", + "-of", "json", resolved_video_path + ] + result = subprocess.run(probe_cmd, capture_output=True, text=True, check=True, errors='ignore') + probe_data = json.loads(result.stdout) + + video_stream = next((s for s in probe_data.get("streams", []) if s.get("codec_type") == "video"), None) + audio_stream = next((s for s in probe_data.get("streams", []) if s.get("codec_type") == "audio"), None) + + if not video_stream: + self.message_manager.add_error("No video stream found in the file (ffprobe).") + + else: + analysis_source = "ffprobe" + duration_str = probe_data.get("format", {}).get("duration", "0") + duration = float(duration_str) if duration_str and duration_str.replace('.', '', 1).isdigit() else 0.0 + duration_display = f"{duration:.2f} seconds" + + r_frame_rate_str = video_stream.get("r_frame_rate", "0/0") + avg_frame_rate_str = video_stream.get("avg_frame_rate", "0/0") + calculated_fps = 0.0 + + def parse_fps(fps_s): + if isinstance(fps_s, (int, float)): return float(fps_s) + if isinstance(fps_s, str) and "/" in fps_s: + try: num, den = map(float, fps_s.split('/')); return num / den if den != 0 else 0.0 + except ValueError: return 0.0 + try: return float(fps_s) + except ValueError: return 0.0 + + r_fps_val = parse_fps(r_frame_rate_str); avg_fps_val = parse_fps(avg_frame_rate_str) + + if r_fps_val > 0: calculated_fps = r_fps_val; fps_display = f"{r_fps_val:.2f} FPS" + if avg_fps_val > 0 and abs(r_fps_val - avg_fps_val) > 0.01 : # Only show average if meaningfully different + calculated_fps = avg_fps_val # Prefer average if it's different and valid + fps_display = f"{avg_fps_val:.2f} FPS (Avg, r: {r_fps_val:.2f})" + elif avg_fps_val > 0 and r_fps_val <=0: + calculated_fps = avg_fps_val; fps_display = f"{avg_fps_val:.2f} FPS (Average)" + + video_width = video_stream.get("width", 0) + video_height = video_stream.get("height", 0) + resolution_display = f"{video_width}x{video_height}" if video_width and video_height else "N/A" + + nframes_str_probe = video_stream.get("nb_frames") + if nframes_str_probe and nframes_str_probe.isdigit(): + num_frames_value = int(nframes_str_probe) + nframes_display = str(num_frames_value) + elif duration > 0 and calculated_fps > 0: + num_frames_value = int(duration * calculated_fps) + nframes_display = f"{num_frames_value} (Calculated)" + + if 
audio_stream: + has_audio_str = (f"Yes (Codec: {audio_stream.get('codec_name', 'N/A')}, " + f"Channels: {audio_stream.get('channels', 'N/A')}, " + f"Rate: {audio_stream.get('sample_rate', 'N/A')} Hz)") + self.message_manager.add_success("Video analysis complete (using ffprobe).") + + except (subprocess.CalledProcessError, json.JSONDecodeError, Exception) as e_ffprobe: + self.message_manager.add_warning(f"ffprobe analysis failed ({type(e_ffprobe).__name__}). Trying imageio fallback.") + if isinstance(e_ffprobe, subprocess.CalledProcessError): + self._tb_log_ffmpeg_error(e_ffprobe, "video analysis with ffprobe") + analysis_source = "imageio" # Ensure fallback if ffprobe fails midway + + if analysis_source == "imageio": # Either ffprobe not available, or it failed + self.message_manager.add_message(f"Analyzing video with imageio: {os.path.basename(video_path)}") + reader = None + try: + reader = imageio.get_reader(resolved_video_path) + meta = reader.get_meta_data() + + duration_imgio_val = meta.get('duration') + duration_display = f"{float(duration_imgio_val):.2f} seconds" if duration_imgio_val is not None else "N/A" + + fps_val_imgio = meta.get('fps') + fps_display = f"{float(fps_val_imgio):.2f} FPS" if fps_val_imgio is not None else "N/A" + + size_imgio = meta.get('size') + if isinstance(size_imgio, tuple) and len(size_imgio) == 2: + video_width, video_height = int(size_imgio[0]), int(size_imgio[1]) + resolution_display = f"{video_width}x{video_height}" + else: + resolution_display = "N/A" + + nframes_val_imgio_meta = meta.get('nframes') + if nframes_val_imgio_meta not in [float('inf'), "N/A", None] and isinstance(nframes_val_imgio_meta, (int,float)): + num_frames_value = int(nframes_val_imgio_meta) + nframes_display = str(num_frames_value) + elif hasattr(reader, 'count_frames'): + try: + nframes_val_imgio_count = reader.count_frames() + if nframes_val_imgio_count != float('inf'): + num_frames_value = int(nframes_val_imgio_count) + nframes_display = f"{num_frames_value} (Counted)" + else: nframes_display = "Unknown (Stream or very long)" + except Exception: nframes_display = "Unknown (Frame count failed)" + + has_audio_str = "(Audio info not available via imageio)" + self.message_manager.add_success("Video analysis complete (using imageio).") + except Exception as e_imgio: + self.message_manager.add_error(f"Error analyzing video with imageio: {e_imgio}") + import traceback + self.message_manager.add_error(traceback.format_exc()) + return f"Error analyzing video: Both ffprobe (if attempted) and imageio failed." 
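# Editor's note — a minimal, hedged sketch of the frame-rate parsing performed on the
# ffprobe output above: r_frame_rate / avg_frame_rate arrive as rational strings such as
# "30000/1001", which are converted to floats before the two values are compared.

def parse_rational_fps(fps_str):
    """Parse ffprobe rational frame rates, e.g. '30000/1001' -> ~29.97."""
    if isinstance(fps_str, (int, float)):
        return float(fps_str)
    if isinstance(fps_str, str) and "/" in fps_str:
        try:
            num, den = map(float, fps_str.split("/"))
            return num / den if den != 0 else 0.0
        except ValueError:
            return 0.0
    try:
        return float(fps_str)
    except (TypeError, ValueError):
        return 0.0

# Examples: "30000/1001" -> 29.97..., "25/1" -> 25.0, "0/0" -> 0.0 (unknown), "n/a" -> 0.0.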
+ finally: + if reader: reader.close() + + # --- Construct Main Analysis Report --- + analysis_report_lines.append(f"Video Analysis ({analysis_source}):") + analysis_report_lines.append(f"File: {os.path.basename(video_path)}") + analysis_report_lines.append("------------------------------------") + analysis_report_lines.append(f"File Size: {file_size_display}") + analysis_report_lines.append(f"Duration: {duration_display}") + analysis_report_lines.append(f"Frame Rate: {fps_display}") + analysis_report_lines.append(f"Resolution: {resolution_display}") + analysis_report_lines.append(f"Frames: {nframes_display}") + analysis_report_lines.append(f"Audio: {has_audio_str}") + analysis_report_lines.append(f"Source: {video_path}") + + # --- Append UPSCALE ADVISORY Conditionally --- + if video_width > 0 and video_height > 0: # Ensure we have dimensions + HD_WIDTH_THRESHOLD = 1920 + FOUR_K_WIDTH_THRESHOLD = 3800 + + is_hd_or_larger = (video_width >= HD_WIDTH_THRESHOLD or video_height >= (HD_WIDTH_THRESHOLD * 9/16 * 0.95)) # Adjusted height for aspect ratios + is_4k_or_larger = (video_width >= FOUR_K_WIDTH_THRESHOLD or video_height >= (FOUR_K_WIDTH_THRESHOLD * 9/16 * 0.95)) + + upscale_warnings = [] + if is_4k_or_larger: + upscale_warnings.append( + "This video is 4K resolution or higher. Upscaling (e.g., to 8K+) will be very " + "slow, memory-intensive, and may cause issues. Proceed with caution." + ) + elif is_hd_or_larger: + upscale_warnings.append( + "This video is HD or larger. Upscaling (e.g., to 4K+) will be resource-intensive " + "and slow. Ensure your system is prepared." + ) + + if num_frames_value and num_frames_value > 900: # e.g., > 30 seconds at 30fps + upscale_warnings.append( + f"With {num_frames_value} frames, upscaling will also be very time-consuming." + ) + + if upscale_warnings: + analysis_report_lines.append("\n--- UPSCALE ADVISORY ---") + for warning_msg in upscale_warnings: + analysis_report_lines.append(f"⚠️ {warning_msg}") + # analysis_report_lines.append("------------------------") # Optional closing separator + + return "\n".join(analysis_report_lines) + + + def _tb_has_audio_stream(self, video_path_to_check): + if not self.has_ffprobe: # Critical check + self.message_manager.add_warning( + "FFprobe not available. Cannot reliably determine if video has audio. " + "Assuming no audio for operations requiring this check. " + "Install FFmpeg with ffprobe for full audio support." + ) + return False + try: + resolved_path = str(Path(video_path_to_check).resolve()) + ffprobe_cmd = [ + self.ffprobe_exe, "-v", "error", "-select_streams", "a:0", + "-show_entries", "stream=codec_type", "-of", "csv=p=0", resolved_path + ] + # check=False because a non-zero return often means no audio stream, which is a valid outcome here. + audio_check_result = subprocess.run(ffprobe_cmd, capture_output=True, text=True, check=False, errors='ignore') + + if audio_check_result.returncode == 0 and "audio" in audio_check_result.stdout.strip().lower(): + return True + else: + # Optionally log if ffprobe ran but found no audio, or if it errored for other reasons + # if audio_check_result.returncode != 0 and audio_check_result.stderr: + # self.message_manager.add_message(f"FFprobe check for audio stream in {os.path.basename(video_path_to_check)} completed. Stderr: {audio_check_result.stderr.strip()}", "DEBUG") + return False + except FileNotFoundError: + self.message_manager.add_warning("FFprobe executable not found during audio stream check (should have been caught by self.has_ffprobe). 
Assuming no audio.") + return False # Should ideally not happen if self.has_ffprobe is true and self.ffprobe_exe is set + except Exception as e: + self.message_manager.add_warning(f"Error checking for audio stream in {os.path.basename(video_path_to_check)}: {e}. Assuming no audio.") + return False + + def tb_process_frames(self, video_path, target_fps_mode, speed_factor, use_streaming: bool, progress=gr.Progress()): + if video_path is None: self.message_manager.add_warning("No input video for frame processing."); return None + + final_output_path = None + reader = None + writer = None + video_stream_output_path = None + + try: + interpolation_factor = 1 + if "2x" in target_fps_mode: interpolation_factor = 2 + elif "4x" in target_fps_mode: interpolation_factor = 4 + should_interpolate = interpolation_factor > 1 + + self.message_manager.add_message(f"Starting frame processing: FPS Mode: {target_fps_mode}, Speed: {speed_factor}x") + progress(0, desc="Initializing...") + + resolved_video_path = str(Path(video_path).resolve()) + reader = imageio.get_reader(resolved_video_path) + meta_data = reader.get_meta_data() + original_fps = meta_data.get('fps', 30.0) + output_fps = original_fps * interpolation_factor + + self.message_manager.add_message( + f"User selected {'Streaming (low memory)' if use_streaming else 'In-Memory (fast)'} mode for frame processing." + ) + if use_streaming and speed_factor != 1.0: + self.message_manager.add_warning("Speed factor is not applied in Streaming Interpolation mode. Processing at 1.0x speed.") + speed_factor = 1.0 + + op_suffix_parts = [] + if speed_factor != 1.0: op_suffix_parts.append(f"speed{speed_factor:.2f}x".replace('.',',')) + if should_interpolate: op_suffix_parts.append(f"RIFE{interpolation_factor}x") + op_suffix = "_".join(op_suffix_parts) if op_suffix_parts else "processed" + temp_video_suffix = f"{op_suffix}_temp_video" + video_stream_output_path = self._tb_generate_output_path(resolved_video_path, suffix=temp_video_suffix, target_dir=self.toolbox_video_output_dir) + final_muxed_output_path = video_stream_output_path.replace("_temp_video", "") + + # --- PROCESSING BLOCK --- + if use_streaming: + # --- STREAMING (LOW MEMORY) PATH - WITH FULL 4X LOGIC --- + if not should_interpolate: + self.message_manager.add_warning("Streaming mode selected but no interpolation chosen. Writing video without changes.") + writer = imageio.get_writer(video_stream_output_path, fps=original_fps, quality=VIDEO_QUALITY, macro_block_size=None) + for frame in reader: + writer.append_data(frame) + else: + writer = imageio.get_writer(video_stream_output_path, fps=output_fps, quality=VIDEO_QUALITY, macro_block_size=None) + self.message_manager.add_message(f"Attempting to load RIFE model for {interpolation_factor}x interpolation...") + if not self.rife_handler._ensure_model_downloaded_and_loaded(): + self.message_manager.add_error("RIFE model could not be loaded. Aborting."); return None + + n_frames = self._tb_get_video_frame_count(resolved_video_path) + if n_frames is None: + self.message_manager.add_error("Cannot determine video length for streaming progress. 
Aborting.") + return None + + num_passes = int(math.log2(interpolation_factor)) + desc = f"RIFE Pass 1/{num_passes} (Streaming)" + self.message_manager.add_message(desc) + + try: + frame1_np = next(iter(reader)) + except StopIteration: + self.message_manager.add_warning("Video has no frames."); return None + + # This list will only be used if we are doing a 4x (2-pass) interpolation + intermediate_frames_for_pass2 = [frame1_np] if num_passes > 1 else None + + # Loop for the first pass (2x) + for i, frame2_np in enumerate(reader, 1): + progress(i / (n_frames - 1), desc=desc) + + # For 2x mode (num_passes == 1), we write directly to the file. + if num_passes == 1: + writer.append_data(frame1_np) + + middle_frame_np = self.rife_handler.interpolate_between_frames(frame1_np, frame2_np) + + if middle_frame_np is not None: + if num_passes == 1: + writer.append_data(middle_frame_np) + # For 4x mode, we collect the 2x results in a list. + if intermediate_frames_for_pass2 is not None: + intermediate_frames_for_pass2.append(middle_frame_np) + else: # On failure, duplicate previous frame + if num_passes == 1: + writer.append_data(frame1_np) + if intermediate_frames_for_pass2 is not None: + intermediate_frames_for_pass2.append(frame1_np) + + # Add the "end" frame of the pair to our intermediate list for the next pass + if intermediate_frames_for_pass2 is not None: + intermediate_frames_for_pass2.append(frame2_np) + + frame1_np = frame2_np + + if num_passes == 1: + writer.append_data(frame1_np) + + if num_passes > 1 and intermediate_frames_for_pass2: + self.message_manager.add_message(f"RIFE Pass 2/{num_passes}: Interpolating 2x frames (in-memory)...") + + pass2_iterator = progress.tqdm( + range(len(intermediate_frames_for_pass2) - 1), + desc=f"RIFE Pass 2/{num_passes}" + ) + + # Loop through the 2x frames to create 4x frames, mirroring the IN-MEMORY logic. 
+ for i in pass2_iterator: + p2_frame1 = intermediate_frames_for_pass2[i] + p2_frame2 = intermediate_frames_for_pass2[i+1] + + # Write the "start" frame of the pair + writer.append_data(p2_frame1) + + # Interpolate and write the middle frame + p2_middle = self.rife_handler.interpolate_between_frames(p2_frame1, p2_frame2) + + if p2_middle is not None: + writer.append_data(p2_middle) + else: # On failure, duplicate + writer.append_data(p2_frame1) + + # After the loop, write the very last frame of the entire list + writer.append_data(intermediate_frames_for_pass2[-1]) + else: + # --- IN-MEMORY (FAST) PATH --- + self.message_manager.add_message("Reading all video frames into memory...") + video_frames = [frame for frame in reader] + + processed_frames = video_frames + if speed_factor != 1.0: + self.message_manager.add_message(f"Adjusting speed by {speed_factor}x (in-memory)...") + if speed_factor > 1.0: + indices = np.arange(0, len(video_frames), speed_factor).astype(int) + processed_frames = [video_frames[i] for i in indices if i < len(video_frames)] + else: + new_len = int(len(video_frames) / speed_factor) + indices = np.linspace(0, len(video_frames) - 1, new_len).astype(int) + processed_frames = [video_frames[i] for i in indices] + + if should_interpolate and len(processed_frames) > 1: + self.message_manager.add_message(f"Loading RIFE for {interpolation_factor}x interpolation (in-memory)...") + if not self.rife_handler._ensure_model_downloaded_and_loaded(): + self.message_manager.add_error("RIFE model could not be loaded."); return None + + num_passes = int(math.log2(interpolation_factor)) + for p in range(num_passes): + self.message_manager.add_message(f"RIFE Pass {p+1}/{num_passes} (in-memory)...") + interpolated_this_pass = [] + frame_iterator = progress.tqdm(range(len(processed_frames) - 1), desc=f"RIFE Pass {p+1}/{num_passes}") + for i in frame_iterator: + interpolated_this_pass.append(processed_frames[i]) + middle_frame = self.rife_handler.interpolate_between_frames(processed_frames[i], processed_frames[i+1]) + interpolated_this_pass.append(middle_frame if middle_frame is not None else processed_frames[i]) + interpolated_this_pass.append(processed_frames[-1]) + processed_frames = interpolated_this_pass + + self.message_manager.add_message(f"Writing {len(processed_frames)} frames to file...") + writer = imageio.get_writer(video_stream_output_path, fps=output_fps, quality=VIDEO_QUALITY, macro_block_size=None) + for frame in progress.tqdm(processed_frames, desc="Writing frames"): + writer.append_data(frame) + + # --- Universal Teardown & Muxing --- + if writer: writer.close() + if reader: reader.close() + writer, reader = None, None + + final_output_path = final_muxed_output_path + can_process_audio = self.has_ffmpeg + original_video_has_audio = self._tb_has_audio_stream(resolved_video_path) if can_process_audio else False + + if can_process_audio and original_video_has_audio: + self.message_manager.add_message("Original video has audio. 
Processing audio with FFmpeg...") + progress(0.9, desc="Processing audio...") + ffmpeg_mux_cmd = [self.ffmpeg_exe, "-y", "-loglevel", "error", "-i", video_stream_output_path] + audio_filters = [] + if speed_factor != 1.0: + if 0.5 <= speed_factor <= 100.0: + audio_filters.append(f"atempo={speed_factor:.4f}") + elif speed_factor < 0.5: # Needs multiple 0.5 steps + num_half_steps = int(np.ceil(np.log(speed_factor) / np.log(0.5))) + for _ in range(num_half_steps): audio_filters.append("atempo=0.5") + final_factor = speed_factor / (0.5**num_half_steps) + if abs(final_factor - 1.0) > 1e-4 and 0.5 <= final_factor <= 100.0: # Add final adjustment if needed + audio_filters.append(f"atempo={final_factor:.4f}") + elif speed_factor > 100.0: # Needs multiple 2.0 (or higher, like 100.0) steps + num_double_steps = int(np.ceil(np.log(speed_factor / 100.0) / np.log(2.0))) # Example for steps of 2 after 100 + audio_filters.append("atempo=100.0") # Max one step + remaining_factor = speed_factor / 100.0 + if abs(remaining_factor - 1.0) > 1e-4 and 0.5 <= remaining_factor <= 100.0: + audio_filters.append(f"atempo={remaining_factor:.4f}") + + + self.message_manager.add_message(f"Applying audio speed adjustment with atempo: {','.join(audio_filters) if audio_filters else 'None (speed_factor out of simple atempo range or 1.0)'}") + + ffmpeg_mux_cmd.extend(["-i", resolved_video_path]) # Input for audio + ffmpeg_mux_cmd.extend(["-c:v", "copy"]) + + if audio_filters: + ffmpeg_mux_cmd.extend(["-filter:a", ",".join(audio_filters)]) + # Always re-encode audio to AAC for MP4 compatibility, even if no speed change, + # as original audio might not be AAC. + ffmpeg_mux_cmd.extend(["-c:a", "aac", "-b:a", "192k"]) + ffmpeg_mux_cmd.extend(["-map", "0:v:0", "-map", "1:a:0?", "-shortest", final_muxed_output_path]) + + try: + subprocess.run(ffmpeg_mux_cmd, check=True, capture_output=True, text=True) + self.message_manager.add_success(f"Video saved with processed audio: {final_muxed_output_path}") + except subprocess.CalledProcessError as e_mux: + self._tb_log_ffmpeg_error(e_mux, "audio processing/muxing") + if os.path.exists(final_muxed_output_path): os.remove(final_muxed_output_path) + os.rename(video_stream_output_path, final_muxed_output_path) + else: + if original_video_has_audio and not can_process_audio: + self.message_manager.add_warning("Original video has audio, but FFmpeg is not available to process it. 
Output will be silent.") + if os.path.exists(final_muxed_output_path) and final_muxed_output_path != video_stream_output_path: + os.remove(final_muxed_output_path) + os.rename(video_stream_output_path, final_muxed_output_path) + + if os.path.exists(video_stream_output_path) and video_stream_output_path != final_muxed_output_path: + try: os.remove(video_stream_output_path) + except Exception as e_clean: self.message_manager.add_warning(f"Could not remove temp video file {video_stream_output_path}: {e_clean}") + + progress(1.0, desc="Complete.") + self.message_manager.add_success(f"Frame processing complete: {final_output_path}") + return final_output_path + + except Exception as e: + self.message_manager.add_error(f"Error during frame processing: {e}") + import traceback; self.message_manager.add_error(traceback.format_exc()) + progress(1.0, desc="Error.") + return None + finally: + if reader and not reader.closed: reader.close() + if writer and not writer.closed: writer.close() + if self.rife_handler: self.rife_handler.unload_model() + devicetorch.empty_cache(torch); gc.collect() + + def tb_create_loop(self, video_path, loop_type, num_loops, progress=gr.Progress()): + if video_path is None: self.message_manager.add_warning("No input video for loop creation."); return None + if not self.has_ffmpeg: # FFmpeg is essential for this function's stream_loop and complex filter + self.message_manager.add_error("FFmpeg is required for creating video loops. This operation cannot proceed.") + return video_path # Return original video path + if loop_type == "none": self.message_manager.add_message("Loop type 'none'. No action."); return video_path + + progress(0, desc="Initializing loop creation...") + resolved_video_path = str(Path(video_path).resolve()) + output_path = self._tb_generate_output_path( + resolved_video_path, + suffix=f"{loop_type}_{num_loops}x", + target_dir=self.toolbox_video_output_dir + ) + + self.message_manager.add_message(f"Creating {loop_type} ({num_loops}x) for {os.path.basename(resolved_video_path)}...") + + ping_pong_unit_path = None + original_video_has_audio = self._tb_has_audio_stream(resolved_video_path) # Check once + + try: + progress(0.2, desc=f"Preparing {loop_type} loop...") + if loop_type == "ping-pong": + ping_pong_unit_path = self._tb_generate_output_path( + resolved_video_path, + suffix="pingpong_unit_temp", + target_dir=self.toolbox_video_output_dir + ) + # Create video-only ping-pong unit first + ffmpeg_pp_unit_cmd = [ + self.ffmpeg_exe, "-y", "-loglevel", "error", + "-i", resolved_video_path, + "-vf", "split[main][tmp];[tmp]reverse[rev];[main][rev]concat=n=2:v=1:a=0", # Video only + "-an", ping_pong_unit_path + ] + subprocess.run(ffmpeg_pp_unit_cmd, check=True, capture_output=True, text=True) + self.message_manager.add_message(f"Created ping-pong unit (video-only): {ping_pong_unit_path}") + + ffmpeg_cmd = [ + self.ffmpeg_exe, "-y", "-loglevel", "error", + "-stream_loop", str(num_loops - 1), # Loop the video unit N-1 times (total N plays) + "-i", ping_pong_unit_path, + "-c:v", "copy" + ] + if original_video_has_audio: + self.message_manager.add_message("Original video has audio. Will loop audio for ping-pong.") + # Audio duration needs to match 2T * num_loops. FFmpeg aloop is complex. + # A simpler approach for ping-pong audio might be to create a 2T audio segment (original + reversed original) + # then loop that audio segment N times and mux with the N-times looped ping-pong video. 
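# Editor's note — a minimal, hedged sketch of the atempo chain construction used for audio
# speed changes in tb_process_frames above: FFmpeg's atempo filter only accepts a limited
# factor range (roughly 0.5-100.0), so very slow speeds are decomposed into repeated
# atempo=0.5 steps plus one final corrective factor.

import math

def build_atempo_chain(speed_factor):
    filters = []
    if abs(speed_factor - 1.0) < 1e-9:
        return filters                      # no audio tempo change needed
    if 0.5 <= speed_factor <= 100.0:
        filters.append(f"atempo={speed_factor:.4f}")
    elif speed_factor < 0.5:
        half_steps = int(math.ceil(math.log(speed_factor) / math.log(0.5)))
        filters.extend(["atempo=0.5"] * half_steps)
        remainder = speed_factor / (0.5 ** half_steps)
        if abs(remainder - 1.0) > 1e-4:
            filters.append(f"atempo={remainder:.4f}")
    return filters

# Examples: 0.25 -> ['atempo=0.5', 'atempo=0.5'];
#           0.3  -> ['atempo=0.5', 'atempo=0.5', 'atempo=1.2000'].
# The chain is joined with commas and passed to FFmpeg via -filter:a.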
+ # Current FFmpeg command for aloop: + audio_loop_count_for_ffmpeg = (num_loops * 2) -1 # Total plays of original audio needed, minus one for initial play by -i + ffmpeg_cmd.extend([ + "-i", resolved_video_path, # Original video for audio source + "-filter_complex", f"[1:a]areverse[areva];[1:a][areva]concat=n=2:v=0:a=1[ppa];[ppa]aloop=loop={num_loops-1}:size=2147483647[a_looped]", + "-map", "0:v:0", "-map", "[a_looped]", + "-c:a", "aac", "-b:a", "192k", "-shortest" + ]) + else: + self.message_manager.add_message("No audio in original or detection issue. Creating video-only ping-pong loop.") + ffmpeg_cmd.extend(["-an"]) + ffmpeg_cmd.append(output_path) + + else: # Regular 'loop' + + # Hi Gemini/Claude! + # The 'num_loops' parameter comes from the UI slider labeled "Number of Loops/Repeats". + # The user's intent is that this slider value (num_loops) represents the number of + # *additional* times the video should be repeated after its initial playthrough. + # For example, if num_loops = 1 (from slider), the video should play twice (original + 1 repeat). + # + # FFmpeg's -stream_loop option takes a value (let's call it X_ffmpeg), + # meaning the input is looped X_ffmpeg times *in addition* to the first play. + # So, X_ffmpeg should be equal to the slider value 'num_loops'. + + ffmpeg_stream_loop_value = num_loops + + # Ensure ffmpeg_stream_loop_value is non-negative. + # Given the UI slider minimum is typically 1, num_loops should always be >= 1. + # This check is for robustness if the input num_loops could ever be less than 0 + # (e.g., if UI constraints change or input comes from elsewhere). + if ffmpeg_stream_loop_value < 0: + ffmpeg_stream_loop_value = 0 # Should ideally not be hit if slider min is 1. + + # Total plays will be the original play + ffmpeg_stream_loop_value additional plays. + total_plays = ffmpeg_stream_loop_value + 1 + self.message_manager.add_message( + f"Regular loop: original video + {ffmpeg_stream_loop_value} additional repeat(s). Total {total_plays} plays." + ) + + ffmpeg_cmd = [ + self.ffmpeg_exe, "-y", "-loglevel", "error", + "-stream_loop", str(ffmpeg_stream_loop_value), # This now uses num_loops directly + "-i", resolved_video_path, + "-c:v", "copy" + ] + if original_video_has_audio: + self.message_manager.add_message("Original video has audio. Re-encoding to AAC for looped MP4 (if not already AAC).") + ffmpeg_cmd.extend(["-c:a", "aac", "-b:a", "192k", "-map", "0:v:0", "-map", "0:a:0?"]) + else: + self.message_manager.add_message("No audio in original or detection issue. 
Looped video will be silent.") + ffmpeg_cmd.extend(["-an", "-map", "0:v:0"]) + ffmpeg_cmd.append(output_path) + + self.message_manager.add_message(f"Processing video {loop_type} with FFmpeg...") + progress(0.5, desc=f"Running FFmpeg for {loop_type}...") + subprocess.run(ffmpeg_cmd, check=True, capture_output=True, text=True, errors='ignore') + + progress(1.0, desc=f"{loop_type.capitalize()} loop created successfully.") + self.message_manager.add_success(f"Loop creation complete: {output_path}") + return output_path + except subprocess.CalledProcessError as e_loop: + self._tb_log_ffmpeg_error(e_loop, f"{loop_type} creation") + progress(1.0, desc=f"Error creating {loop_type}.") + return None + except Exception as e: + self.message_manager.add_error(f"Error creating loop: {e}") + import traceback; self.message_manager.add_error(traceback.format_exc()) + progress(1.0, desc="Error creating loop.") + return None + finally: + if ping_pong_unit_path and os.path.exists(ping_pong_unit_path): + try: os.remove(ping_pong_unit_path) + except Exception as e_clean_pp: self.message_manager.add_warning(f"Could not remove temp ping-pong unit: {e_clean_pp}") + gc.collect() + + def _tb_get_video_dimensions(self, video_path): + video_width, video_height = 0, 0 + # Prefer ffprobe if available for dimensions + if self.has_ffprobe: + try: + probe_cmd = [self.ffprobe_exe, "-v", "error", "-select_streams", "v:0", + "-show_entries", "stream=width,height", "-of", "csv=s=x:p=0", video_path] + result = subprocess.run(probe_cmd, capture_output=True, text=True, check=True, errors='ignore') + w_str, h_str = result.stdout.strip().split('x') + video_width, video_height = int(w_str), int(h_str) + if video_width > 0 and video_height > 0: return video_width, video_height + except Exception as e_probe_dim: + self.message_manager.add_warning(f"ffprobe failed to get dimensions ({e_probe_dim}), trying imageio.") + + # Fallback to imageio + reader = None + try: + reader = imageio.get_reader(video_path) + meta = reader.get_meta_data() + size_imgio = meta.get('size') + if size_imgio and isinstance(size_imgio, tuple) and len(size_imgio) == 2: + video_width, video_height = int(size_imgio[0]), int(size_imgio[1]) + except Exception as e_meta: + self.message_manager.add_warning(f"Error getting video dimensions for vignette (imageio): {e_meta}. Defaulting aspect to 1/1.") + finally: + if reader: reader.close() + return video_width, video_height # Might be 0,0 if all failed + + def _tb_create_vignette_filter(self, strength_percent, width, height): + min_angle_rad = math.pi / 3.5; max_angle_rad = math.pi / 2 + normalized_strength = strength_percent / 100.0 + angle_rad = min_angle_rad + normalized_strength * (max_angle_rad - min_angle_rad) + vignette_aspect_ratio_val = "1/1" + if width > 0 and height > 0: vignette_aspect_ratio_val = f"{width/height:.4f}" + return f"vignette=angle={angle_rad:.4f}:mode=forward:eval=init:aspect={vignette_aspect_ratio_val}" + + def tb_apply_filters(self, video_path, brightness, contrast, saturation, temperature, + sharpen, blur, denoise, vignette, s_curve_contrast, film_grain_strength, + progress=gr.Progress()): + if video_path is None: self.message_manager.add_warning("No input video for filters."); return None + if not self.has_ffmpeg: # FFmpeg is essential for this function + self.message_manager.add_error("FFmpeg is required for applying video filters. 
This operation cannot proceed.") + return video_path + + progress(0, desc="Initializing filter application...") + resolved_video_path = str(Path(video_path).resolve()) + output_path = self._tb_generate_output_path(resolved_video_path, "filtered", self.toolbox_video_output_dir) + self.message_manager.add_message(f"🎨 Applying filters to {os.path.basename(resolved_video_path)}...") + + video_width, video_height = 0,0 + if vignette > 0: # Only get dimensions if vignette is used + video_width, video_height = self._tb_get_video_dimensions(resolved_video_path) + if video_width > 0 and video_height > 0: self.message_manager.add_message(f"Video dimensions for vignette: {video_width}x{video_height}", "DEBUG") + + filters, applied_filter_descriptions = [], [] + + # Filter definitions + if denoise > 0: filters.append(f"hqdn3d={denoise*0.8:.1f}:{denoise*0.6:.1f}:{denoise*0.7:.1f}:{denoise*0.5:.1f}"); applied_filter_descriptions.append(f"Denoise (hqdn3d)") + if temperature != 0: mid_shift = (temperature/100.0)*0.3; filters.append(f"colorbalance=rm={mid_shift:.2f}:bm={-mid_shift:.2f}"); applied_filter_descriptions.append(f"Color Temp") + eq_parts = []; desc_eq = [] + if brightness != 0: eq_parts.append(f"brightness={brightness/100.0:.2f}"); desc_eq.append(f"Brightness") + if contrast != 1: eq_parts.append(f"contrast={contrast:.2f}"); desc_eq.append(f"Contrast (Linear)") + if saturation != 1: eq_parts.append(f"saturation={saturation:.2f}"); desc_eq.append(f"Saturation") + if eq_parts: filters.append(f"eq={':'.join(eq_parts)}"); applied_filter_descriptions.append(" & ".join(desc_eq)) + if s_curve_contrast > 0: s = s_curve_contrast/100.0; y1 = 0.25-s*(0.25-0.10); y2 = 0.75+s*(0.90-0.75); filters.append(f"curves=all='0/0 0.25/{y1:.2f} 0.75/{y2:.2f} 1/1'"); applied_filter_descriptions.append(f"S-Curve Contrast") + if blur > 0: filters.append(f"gblur=sigma={blur*0.4:.1f}"); applied_filter_descriptions.append(f"Blur") + if sharpen > 0: filters.append(f"unsharp=luma_msize_x=5:luma_msize_y=5:luma_amount={sharpen*0.3:.2f}"); applied_filter_descriptions.append(f"Sharpen") + if film_grain_strength > 0: filters.append(f"noise=alls={film_grain_strength*0.5:.1f}:allf=t+u"); applied_filter_descriptions.append(f"Film Grain") + if vignette > 0: filters.append(self._tb_create_vignette_filter(vignette, video_width, video_height)); applied_filter_descriptions.append(f"Vignette") + + # --- CORRECTED LOGIC --- + if applied_filter_descriptions: + self.message_manager.add_message("🔧 Applying FFmpeg filters: " + ", ".join(applied_filter_descriptions)) + else: + self.message_manager.add_message("ℹ️ No filters selected. Passing video through (re-encoding).") + + progress(0.2, desc="Preparing filter command...") + original_video_has_audio = self._tb_has_audio_stream(resolved_video_path) + + try: + ffmpeg_cmd = [ + self.ffmpeg_exe, "-y", "-loglevel", "error", "-i", resolved_video_path + ] + # Conditionally add the video filter flag only if there are filters to apply + if filters: + ffmpeg_cmd.extend(["-vf", ",".join(filters)]) + + # Add the rest of the encoding options + ffmpeg_cmd.extend([ + "-c:v", "libx264", "-preset", "medium", "-crf", "20", + "-pix_fmt", "yuv420p", + "-map", "0:v:0" + ]) + + if original_video_has_audio: + self.message_manager.add_message("Original video has audio. Re-encoding to AAC for filtered video.", "INFO") + ffmpeg_cmd.extend(["-c:a", "aac", "-b:a", "192k", "-map", "0:a:0?"]) + else: + self.message_manager.add_message("No audio in original or detection issue. 
Filtered video will be silent.", "INFO") + ffmpeg_cmd.extend(["-an"]) + + ffmpeg_cmd.append(output_path) + + self.message_manager.add_message("🔄 Processing with FFmpeg...") + progress(0.5, desc="Running FFmpeg for filters...") + subprocess.run(ffmpeg_cmd, check=True, capture_output=True, text=True, errors='ignore') + + progress(1.0, desc="Filters applied successfully.") + self.message_manager.add_success(f"✅ Filter step complete! Output: {output_path}") + return output_path + except subprocess.CalledProcessError as e_filters: + self._tb_log_ffmpeg_error(e_filters, "filter application") + progress(1.0, desc="Error applying filters."); return None + except Exception as e: + self.message_manager.add_error(f"❌ An unexpected error occurred: {e}") + import traceback; self.message_manager.add_error(traceback.format_exc()) + progress(1.0, desc="Error applying filters."); return None + finally: gc.collect() + + + def tb_upscale_video(self, video_path, model_key: str, output_scale_factor_ui: float, + tile_size: int, enhance_face: bool, + denoise_strength_ui: float | None, + use_streaming: bool, # New parameter from UI + progress=gr.Progress()): + if video_path is None: self.message_manager.add_warning("No input video for upscaling."); return None + + reader = None + writer = None + final_output_path = None + video_stream_output_path = None + + try: + # --- Model Loading and Setup --- + if model_key not in self.esrgan_upscaler.supported_models: + self.message_manager.add_error(f"Upscale model key '{model_key}' not found in supported models."); return None + + model_native_scale = self.esrgan_upscaler.supported_models[model_key].get('scale', 0) + tile_size_str_for_log = str(tile_size) if tile_size > 0 else "Auto" + face_enhance_str_for_log = "+FaceEnhance" if enhance_face else "" + denoise_str_for_log = "" + if model_key == "RealESR-general-x4v3" and denoise_strength_ui is not None: + denoise_str_for_log = f", DNI: {denoise_strength_ui:.2f}" + + self.message_manager.add_message( + f"Preparing to load ESRGAN model '{model_key}' for {output_scale_factor_ui:.2f}x target upscale " + f"(Native: {model_native_scale}x, Tile: {tile_size_str_for_log}{face_enhance_str_for_log}{denoise_str_for_log})." + ) + progress(0.05, desc=f"Loading ESRGAN model '{model_key}'...") + + upsampler_instance = self.esrgan_upscaler.load_model( + model_key=model_key, + tile_size=tile_size, + denoise_strength=denoise_strength_ui if model_key == "RealESR-general-x4v3" else None + ) + if not upsampler_instance: + self.message_manager.add_error(f"Could not load ESRGAN model '{model_key}'. Aborting."); return None + + if enhance_face: + if not self.esrgan_upscaler._load_face_enhancer(bg_upsampler=upsampler_instance): + self.message_manager.add_warning("GFPGAN load failed. Proceeding without face enhancement.") + enhance_face = False + + self.message_manager.add_message(f"ESRGAN model '{model_key}' loaded. 
Initializing process...") + + resolved_video_path = str(Path(video_path).resolve()) + reader = imageio.get_reader(resolved_video_path) + meta_data = reader.get_meta_data() + original_fps = meta_data.get('fps', 30.0) + + # --- Define output paths --- + temp_video_suffix_base = f"upscaled_{model_key}{'_FaceEnhance' if enhance_face else ''}" + if model_key == "RealESR-general-x4v3" and denoise_strength_ui is not None: + temp_video_suffix_base += f"_dni{denoise_strength_ui:.2f}" + temp_video_suffix = temp_video_suffix_base.replace(".","p") + "_temp_video" + video_stream_output_path = self._tb_generate_output_path(resolved_video_path, temp_video_suffix, self.toolbox_video_output_dir) + final_muxed_output_path = video_stream_output_path.replace("_temp_video", "") + + self.message_manager.add_message( + f"User selected {'Streaming (low memory)' if use_streaming else 'In-Memory (fast)'} mode." + ) + + # --- PROCESSING BLOCK --- + if use_streaming: + # --- STREAMING (LOW MEMORY) PATH --- + self.message_manager.add_message("Processing frame-by-frame...") + n_frames = self._tb_get_video_frame_count(resolved_video_path) + + if n_frames is None: + self.message_manager.add_error("Cannot use streaming mode because the total number of frames could not be determined. Aborting.") + return None + + writer = imageio.get_writer(video_stream_output_path, fps=original_fps, quality=VIDEO_QUALITY, macro_block_size=None) + + # Use a range-based loop and get_data() instead of iterating the reader directly + for i in progress.tqdm(range(n_frames), desc="Upscaling Frames (Streaming)"): + frame_np = reader.get_data(i) # Explicitly get frame i + + upscaled_frame_np = self.esrgan_upscaler.upscale_frame(frame_np, model_key, float(output_scale_factor_ui), enhance_face) + if upscaled_frame_np is not None: + writer.append_data(upscaled_frame_np) + else: + self.message_manager.add_error(f"Failed to upscale frame {i}. Skipping.") + if "out of memory" in self.message_manager.get_recent_errors_as_str(count=1).lower(): + self.message_manager.add_error("CUDA OOM. Aborting video upscale."); return None + + # We can be more aggressive with GC in streaming mode + if (i + 1) % 10 == 0: + gc.collect() + + writer.close() + writer = None + else: + # --- IN-MEMORY (FAST) PATH --- + self.message_manager.add_message("Processing all frames in memory...") + all_frames = [frame for frame in progress.tqdm(reader, desc="Reading all frames")] + upscaled_frames = [] + frame_iterator = progress.tqdm(all_frames, desc="Upscaling Frames (In-Memory)") + + for frame_np in frame_iterator: + upscaled_frame_np = self.esrgan_upscaler.upscale_frame(frame_np, model_key, float(output_scale_factor_ui), enhance_face) + if upscaled_frame_np is not None: + upscaled_frames.append(upscaled_frame_np) + else: + if "out of memory" in self.message_manager.get_recent_errors_as_str(count=1).lower(): + self.message_manager.add_error("CUDA OOM. 
Aborting video upscale."); return None + + self.message_manager.add_message("Writing upscaled video file...") + imageio.mimwrite(video_stream_output_path, upscaled_frames, fps=original_fps, quality=VIDEO_QUALITY, macro_block_size=None) + + # --- Teardown and Audio Muxing --- + reader.close() + reader = None + self.message_manager.add_message(f"Upscaled video stream saved to: {video_stream_output_path}") + progress(0.85, desc="Upscaled video stream saved.") + + final_output_path = final_muxed_output_path + can_process_audio = self.has_ffmpeg + original_video_has_audio = self._tb_has_audio_stream(resolved_video_path) if can_process_audio else False + + if can_process_audio and original_video_has_audio: + progress(0.90, desc="Muxing audio...") + self.message_manager.add_message("Original video has audio. Muxing audio with FFmpeg...") + ffmpeg_mux_cmd = [ + self.ffmpeg_exe, "-y", "-loglevel", "error", + "-i", video_stream_output_path, "-i", resolved_video_path, + "-c:v", "copy", "-c:a", "aac", "-b:a", "192k", + "-map", "0:v:0", "-map", "1:a:0?", "-shortest", final_muxed_output_path + ] + try: + subprocess.run(ffmpeg_mux_cmd, check=True, capture_output=True, text=True) + self.message_manager.add_success(f"Upscaled video saved with audio: {final_muxed_output_path}") + except subprocess.CalledProcessError as e_mux: + self._tb_log_ffmpeg_error(e_mux, "audio muxing for upscaled video") + if os.path.exists(final_muxed_output_path): os.remove(final_muxed_output_path) + os.rename(video_stream_output_path, final_muxed_output_path) + else: + if original_video_has_audio and not can_process_audio: + self.message_manager.add_warning("Original video has audio, but FFmpeg is not available to process it. Upscaled output will be silent.") + if os.path.exists(final_muxed_output_path) and final_muxed_output_path != video_stream_output_path: + os.remove(final_muxed_output_path) + os.rename(video_stream_output_path, final_muxed_output_path) + + progress(1.0, desc="Upscaling complete.") + self.message_manager.add_success(f"Video upscaling complete: {final_output_path}") + return final_output_path + + except Exception as e: + self.message_manager.add_error(f"Error during video upscaling: {e}") + import traceback; self.message_manager.add_error(traceback.format_exc()) + progress(1.0, desc="Error during upscaling."); return None + finally: + if reader and not reader.closed: reader.close() + if writer and not writer.closed: writer.close() + if video_stream_output_path and os.path.exists(video_stream_output_path) and final_output_path and video_stream_output_path != final_output_path: + try: os.remove(video_stream_output_path) + except Exception as e_clean: self.message_manager.add_warning(f"Could not remove temp upscaled video: {e_clean}") + + if self.esrgan_upscaler: + self.esrgan_upscaler.unload_model(model_key) + if enhance_face: + self.esrgan_upscaler._unload_face_enhancer() + devicetorch.empty_cache(torch); gc.collect() + + def tb_open_output_folder(self): + folder_path = os.path.abspath(self.postprocessed_output_root_dir) + try: + os.makedirs(folder_path, exist_ok=True) + if sys.platform == 'win32': subprocess.run(['explorer', folder_path]) + elif sys.platform == 'darwin': subprocess.run(['open', folder_path]) + else: subprocess.run(['xdg-open', folder_path]) + self.message_manager.add_success(f"Opened postprocessed output folder: {folder_path}") + except Exception as e: + self.message_manager.add_error(f"Error opening folder {folder_path}: {e}") + + def _tb_clean_directory(self, dir_path, dir_description): + """ 
+ Helper to clean a single temp directory and return a single, formatted status line. + """ + LABEL_WIDTH = 32 # Width for the description label for alignment + status_icon = "✅" + status_text = "" + + # Make path relative for cleaner logging + try: + display_path = os.path.relpath(dir_path, self.project_root) if dir_path else "N/A" + except (ValueError, TypeError): + display_path = str(dir_path) # Fallback if path is weird + + if not dir_path or not os.path.exists(dir_path): + status_icon = "ℹ️" + status_text = "Path not found or not set." + return f"[{status_icon}] {dir_description:<{LABEL_WIDTH}} : {status_text}" + + try: + items = os.listdir(dir_path) + file_count = sum(1 for item in items if os.path.isfile(os.path.join(dir_path, item))) + dir_count = sum(1 for item in items if os.path.isdir(os.path.join(dir_path, item))) + + if file_count == 0 and dir_count == 0: + status_text = f"Already empty at '{display_path}'" + else: + shutil.rmtree(dir_path) + + # --- Dynamic String Building --- + summary_parts = [] + if file_count > 0: + summary_parts.append(f"{file_count} file{'s' if file_count != 1 else ''}") + if dir_count > 0: + summary_parts.append(f"{dir_count} folder{'s' if dir_count != 1 else ''}") + + status_text = f"Cleaned ({' and '.join(summary_parts)}) from '{display_path}'" + + os.makedirs(dir_path, exist_ok=True) + + except Exception as e: + status_icon = "❌" + status_text = f"ERROR cleaning '{display_path}': {e}" + + return f"[{status_icon}] {dir_description:<{LABEL_WIDTH}} : {status_text}" + + def tb_clear_temporary_files(self): + """ + Clears all temporary file locations and returns a formatted summary string. + """ + # 1. Clean Post-processing Temp Folder + postproc_temp_dir = self._base_temp_output_dir + postproc_summary_line = self._tb_clean_directory(postproc_temp_dir, "Post-processing temp folder") + + # 2. Clean Gradio Temp Folder + gradio_temp_dir = self.settings.get("gradio_temp_dir") + gradio_summary_line = self._tb_clean_directory(gradio_temp_dir, "Gradio temp folder") + + # Join the individual lines into a single string for printing + return f"{postproc_summary_line}\n{gradio_summary_line}" diff --git a/modules/toolbox_app.py b/modules/toolbox_app.py new file mode 100644 index 0000000000000000000000000000000000000000..2977470ccbae3d1d385a70ca7be836e1c036639c --- /dev/null +++ b/modules/toolbox_app.py @@ -0,0 +1,1863 @@ +import gc +import json # for preset loading/saving +import os +import psutil +import sys +import traceback +import types + +# --- Standalone Startup & Path Fix --- +# This block runs only when the script is executed directly. +# It sets up the environment for standalone operation. +if __name__ == '__main__': + # Adjust the Python path to include the project root, so local imports work. + modules_dir = os.path.dirname(os.path.abspath(__file__)) + project_root = os.path.dirname(modules_dir) + if project_root not in sys.path: + print(f"--- Running Toolbox in Standalone Mode ---") + print(f"Adding project root to sys.path: {project_root}") + sys.path.insert(0, project_root) + + # Set the GRADIO_TEMP_DIR *before* Gradio is imported. + # This forces the standalone app to use the same temp folder as the main app. 
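# Editor's note — a minimal, hedged sketch of the aligned status-line formatting used by
# _tb_clean_directory above: a fixed-width f-string field pads the description so the
# colons line up when several cleanup results are printed together (the labels and paths
# below are illustrative placeholders, not the toolbox's real directories).

LABEL_WIDTH = 32

def format_status_line(icon, description, text):
    return f"[{icon}] {description:<{LABEL_WIDTH}} : {text}"

print(format_status_line("✅", "Post-processing temp folder", "Cleaned (3 files) from 'temp/'"))
print(format_status_line("ℹ️", "Gradio temp folder", "Already empty at 'gradio_tmp/'"))
# Both lines print with the description padded to 32 characters, so the ':' columns align.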
+ from modules.settings import Settings + _settings_for_env = Settings() + _gradio_temp_dir = _settings_for_env.get("gradio_temp_dir") + if _gradio_temp_dir: + os.environ['GRADIO_TEMP_DIR'] = os.path.abspath(_gradio_temp_dir) + print(f"Set GRADIO_TEMP_DIR for standalone mode: {os.environ['GRADIO_TEMP_DIR']}") + del _settings_for_env, _gradio_temp_dir + + # Suppress persistent Windows asyncio proactor errors when running standalone. + if os.name == 'nt': + import asyncio + from functools import wraps + asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) + def silence_event_loop_closed(func): + @wraps(func) + def wrapper(self, *args, **kwargs): + try: + return func(self, *args, **kwargs) + except RuntimeError as e: + if str(e) != 'Event loop is closed': raise + return wrapper + if hasattr(asyncio.proactor_events._ProactorBasePipeTransport, '_call_connection_lost'): + asyncio.proactor_events._ProactorBasePipeTransport._call_connection_lost = silence_event_loop_closed( + asyncio.proactor_events._ProactorBasePipeTransport._call_connection_lost) + +# --- Third-Party Library Imports --- +import devicetorch +import gradio as gr +import imageio # Added for reading frame dimensions +import torch +from torchvision.transforms.functional import rgb_to_grayscale + +# --- Patch for basicsr (must run after torchvision import) --- +functional_tensor_mod = types.ModuleType('functional_tensor') +functional_tensor_mod.rgb_to_grayscale = rgb_to_grayscale +sys.modules.setdefault('torchvision.transforms.functional_tensor', functional_tensor_mod) + +# --- Local Application Imports --- +from modules.settings import Settings +from modules.toolbox.esrgan_core import ESRGANUpscaler +from modules.toolbox.message_manager import MessageManager +from modules.toolbox.rife_core import RIFEHandler +from modules.toolbox.setup_ffmpeg import setup_ffmpeg +from modules.toolbox.system_monitor import SystemMonitor +from modules.toolbox.toolbox_processor import VideoProcessor + +# Attempt to import helper, with a fallback if it's missing. +try: + from diffusers_helper.memory import cpu +except ImportError: + print("WARNING: Could not import cpu from diffusers_helper.memory. Falling back to torch.device('cpu')") + cpu = torch.device('cpu') + +# Check if FFmpeg is set up, if not, run the setup +script_dir = os.path.dirname(os.path.abspath(__file__)) +# Construct the correct path to the target bin directory. +bin_dir = os.path.join(script_dir, 'toolbox', 'bin') + +ffmpeg_exe_name = 'ffmpeg.exe' if sys.platform == "win32" else 'ffmpeg' +ffmpeg_full_path = os.path.join(bin_dir, ffmpeg_exe_name) + +# Check if the executable exists at the correct location. +if not os.path.exists(ffmpeg_full_path): + print(f"Bundled FFmpeg not found in '{bin_dir}'. 
Running one-time setup...") + setup_ffmpeg() + + +tb_message_mgr = MessageManager() +settings_instance = Settings() +tb_processor = VideoProcessor(tb_message_mgr, settings_instance) # Pass settings to VideoProcessor + +# --- Default Filter Values --- +TB_DEFAULT_FILTER_SETTINGS = { + "brightness": 0, "contrast": 1, "saturation": 1, "temperature": 0, + "sharpen": 0, "blur": 0, "denoise": 0, "vignette": 0, + "s_curve_contrast": 0, "film_grain_strength": 0 +} + +# --- Filter Presets Handling --- +TB_BUILT_IN_PRESETS_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)), "toolbox", "data", "filter_presets.json") +tb_filter_presets_data = {} # Will be populated by _initialize_presets + +def _initialize_presets(): + global tb_filter_presets_data + default_preset_map_for_creation = { + "none": TB_DEFAULT_FILTER_SETTINGS.copy(), + "cinematic": {"brightness": -5, "contrast": 1.3, "saturation": 0.9, "temperature": 20, "vignette": 10, "sharpen": 1.2, "blur": 0, "denoise": 0, "s_curve_contrast": 15, "film_grain_strength": 5}, + "vintage": {"brightness": 5, "contrast": 1.1, "saturation": 0.7, "temperature": 15, "vignette": 30, "sharpen": 0, "blur": 0.5, "denoise": 0, "s_curve_contrast": 10, "film_grain_strength": 10}, + "cool": {"brightness": 0, "contrast": 1.2, "saturation": 1.1, "temperature": -15, "vignette": 0, "sharpen": 1.0, "blur": 0, "denoise": 0, "s_curve_contrast": 5, "film_grain_strength": 0}, + "warm": {"brightness": 5, "contrast": 1.1, "saturation": 1.2, "temperature": 20, "vignette": 0, "sharpen": 0, "blur": 0, "denoise": 0, "s_curve_contrast": 5, "film_grain_strength": 0}, + "dramatic": {"brightness": -5, "contrast": 1.2, "saturation": 0.9, "temperature": -10, "vignette": 20, "sharpen": 1.2, "blur": 0, "denoise": 0, "s_curve_contrast": 20, "film_grain_strength": 8} + } + try: + os.makedirs(os.path.dirname(TB_BUILT_IN_PRESETS_FILE), exist_ok=True) + if not os.path.exists(TB_BUILT_IN_PRESETS_FILE): + tb_message_mgr.add_message(f"Presets file not found. Creating with default presets: {TB_BUILT_IN_PRESETS_FILE}", "INFO") + with open(TB_BUILT_IN_PRESETS_FILE, 'w') as f: + json.dump(default_preset_map_for_creation, f, indent=4) + tb_filter_presets_data = default_preset_map_for_creation + tb_message_mgr.add_success("Default presets file created.") + else: + with open(TB_BUILT_IN_PRESETS_FILE, 'r') as f: + tb_filter_presets_data = json.load(f) + # Ensure "none" preset always exists and uses TB_DEFAULT_FILTER_SETTINGS + if "none" not in tb_filter_presets_data or tb_filter_presets_data["none"] != TB_DEFAULT_FILTER_SETTINGS: + tb_filter_presets_data["none"] = TB_DEFAULT_FILTER_SETTINGS.copy() + # Optionally re-save if "none" was missing or incorrect, or just use in-memory fix + # with open(TB_BUILT_IN_PRESETS_FILE, 'w') as f: + # json.dump(tb_filter_presets_data, f, indent=4) + + tb_message_mgr.add_message(f"Filter presets loaded from {TB_BUILT_IN_PRESETS_FILE}.", "INFO") + except Exception as e: + tb_message_mgr.add_error(f"Error with filter presets file {TB_BUILT_IN_PRESETS_FILE}: {e}. Using in-memory defaults.") + tb_filter_presets_data = default_preset_map_for_creation +_initialize_presets() # Call once when the script module is loaded + +def tb_update_messages(): + return tb_message_mgr.get_messages() + +def tb_handle_update_monitor(monitor_enabled): # This updates the TOOLBOX TAB's monitor + if not monitor_enabled: + return gr.update() # Do nothing if disabled to save resources. 
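# Editor's note — a minimal, hedged sketch of the "load or create with defaults" pattern
# used by _initialize_presets above: if the JSON file is missing it is written from the
# built-in defaults, otherwise it is loaded and the protected 'none' entry is re-pinned to
# the canonical default values (the path and defaults here are illustrative placeholders).

import json
import os

_DEFAULTS = {"none": {"brightness": 0, "contrast": 1}}

def load_or_create_presets(path):
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    if not os.path.exists(path):
        with open(path, "w") as f:
            json.dump(_DEFAULTS, f, indent=4)
        return dict(_DEFAULTS)
    with open(path, "r") as f:
        presets = json.load(f)
    presets["none"] = dict(_DEFAULTS["none"])   # protected entry is always reset to defaults
    return presets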
+ return SystemMonitor.get_system_info() + +def tb_handle_analyze_video(video_path): + tb_message_mgr.clear() + analysis = tb_processor.tb_analyze_video_input(video_path) + # Return a third value to control the accordion's 'open' state + return tb_update_messages(), analysis, gr.update(open=True) + +def tb_handle_process_frames(video_path, fps_mode, speed_factor, use_streaming, progress=gr.Progress()): + tb_message_mgr.clear() + output_video = tb_processor.tb_process_frames(video_path, fps_mode, speed_factor, use_streaming, progress) + return output_video, tb_update_messages() + +def tb_handle_create_loop(video_path, loop_type, num_loops, progress=gr.Progress()): + tb_message_mgr.clear() + output_video = tb_processor.tb_create_loop(video_path, loop_type, num_loops, progress) + return output_video, tb_update_messages() + +def tb_update_filter_sliders_from_preset(preset_name): + preset_settings = tb_filter_presets_data.get(preset_name) + if not preset_settings: + tb_message_mgr.add_warning(f"Preset '{preset_name}' not found. Using 'none' settings.") + preset_settings = tb_filter_presets_data.get("none", TB_DEFAULT_FILTER_SETTINGS.copy()) + + final_settings = TB_DEFAULT_FILTER_SETTINGS.copy() + final_settings.update(preset_settings) + + ordered_values = [] + for key in TB_DEFAULT_FILTER_SETTINGS.keys(): + ordered_values.append(final_settings.get(key, TB_DEFAULT_FILTER_SETTINGS[key])) + + return tuple(ordered_values) + +def tb_handle_reset_all_filters(): + tb_message_mgr.add_message("Filter sliders reset to default 'none' values.") + none_settings_values = tb_update_filter_sliders_from_preset("none") + return "none", "", *none_settings_values, tb_update_messages() + +def tb_handle_save_user_preset(new_preset_name_str, *slider_values): + global tb_filter_presets_data; tb_message_mgr.clear() + if not new_preset_name_str or not new_preset_name_str.strip(): + tb_message_mgr.add_warning("Preset name cannot be empty."); return gr.update(), tb_update_messages(), gr.update() + + clean_preset_name = new_preset_name_str.strip() + + if clean_preset_name.lower() == "none": + tb_message_mgr.add_warning("'none' is a protected preset and cannot be overwritten.") + return gr.update(), tb_update_messages(), gr.update(value="") # Clear input box + + new_preset_values = dict(zip(TB_DEFAULT_FILTER_SETTINGS.keys(), slider_values)) + preset_existed = clean_preset_name in tb_filter_presets_data + tb_filter_presets_data[clean_preset_name] = new_preset_values + try: + with open(TB_BUILT_IN_PRESETS_FILE, 'w') as f: json.dump(tb_filter_presets_data, f, indent=4) + tb_message_mgr.add_success(f"Preset '{clean_preset_name}' {'updated' if preset_existed else 'saved'} successfully!") + + updated_choices = list(tb_filter_presets_data.keys()) + if "none" in updated_choices: updated_choices.remove("none"); updated_choices.sort(); updated_choices.insert(0, "none") + else: updated_choices.sort() + + return gr.update(choices=updated_choices, value=clean_preset_name), tb_update_messages(), "" + except Exception as e: + tb_message_mgr.add_error(f"Error saving preset '{clean_preset_name}': {e}") + _initialize_presets() + return gr.update(), tb_update_messages(), gr.update(value=new_preset_name_str) + +def tb_handle_delete_user_preset(preset_name_to_delete): + global tb_filter_presets_data; tb_message_mgr.clear() + if not preset_name_to_delete or not preset_name_to_delete.strip(): + tb_message_mgr.add_warning("No preset name to delete (select from dropdown or type)."); return gr.update(), tb_update_messages(), gr.update(), 
*tb_update_filter_sliders_from_preset("none") + + clean_preset_name = preset_name_to_delete.strip() + if clean_preset_name.lower() == "none": + tb_message_mgr.add_warning("'none' preset cannot be deleted."); return gr.update(), tb_update_messages(), gr.update(value="none"), *tb_update_filter_sliders_from_preset("none") + if clean_preset_name not in tb_filter_presets_data: + tb_message_mgr.add_warning(f"Preset '{clean_preset_name}' not found."); return gr.update(), tb_update_messages(), gr.update(), *tb_update_filter_sliders_from_preset("none") + + del tb_filter_presets_data[clean_preset_name] + try: + with open(TB_BUILT_IN_PRESETS_FILE, 'w') as f: json.dump(tb_filter_presets_data, f, indent=4) + tb_message_mgr.add_success(f"Preset '{clean_preset_name}' deleted.") + + updated_choices = list(tb_filter_presets_data.keys()) + if "none" in updated_choices: updated_choices.remove("none"); updated_choices.sort(); updated_choices.insert(0, "none") + else: updated_choices.sort() + + sliders_reset_values = tb_update_filter_sliders_from_preset("none") + return gr.update(choices=updated_choices, value="none"), tb_update_messages(), "", *sliders_reset_values + except Exception as e: + tb_message_mgr.add_error(f"Error deleting preset '{clean_preset_name}' from file: {e}") + _initialize_presets(); + current_choices = list(tb_filter_presets_data.keys()) + if "none" in current_choices: current_choices.remove("none"); current_choices.sort(); current_choices.insert(0, "none") + else: current_choices.sort() + selected_val_after_error = clean_preset_name if clean_preset_name in current_choices else "none" + sliders_after_error_values = tb_update_filter_sliders_from_preset(selected_val_after_error) + return gr.update(choices=current_choices, value=selected_val_after_error), tb_update_messages(), gr.update(value=selected_val_after_error), *sliders_after_error_values + +# --- Workflow Presets Handling --- +TB_WORKFLOW_PRESETS_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)), "toolbox", "data", "workflow_presets.json") +tb_workflow_presets_data = {} # Will be populated by _initialize_workflow_presets + +# This helper function creates a dictionary of all default parameter values +def _get_default_workflow_params(): + # Gets default values from filter settings and adds other op defaults + params = TB_DEFAULT_FILTER_SETTINGS.copy() + params.update({ + "upscale_model": list(tb_processor.esrgan_upscaler.supported_models.keys())[0] if tb_processor.esrgan_upscaler.supported_models else None, + "upscale_factor": 2.0, + "tile_size": 0, + "enhance_face": False, + "denoise_strength": 0.5, + "upscale_use_streaming": False, + "frames_use_streaming": False, + "fps_mode": "No Interpolation", + "speed_factor": 1.0, + "loop_type": "loop", + "num_loops": 1, + "export_format": "MP4", + "export_quality": 85, + "export_max_width": 1024, + }) + return params + +def _initialize_workflow_presets(): + global tb_workflow_presets_data + # The 'None' preset stores default values for all controls and no active steps + default_workflow_map = { + "None": { + "active_steps": [], + "params": _get_default_workflow_params() + } + } + try: + # Ensure the directory exists + os.makedirs(os.path.dirname(TB_WORKFLOW_PRESETS_FILE), exist_ok=True) + if not os.path.exists(TB_WORKFLOW_PRESETS_FILE): + tb_message_mgr.add_message(f"Workflow presets file not found. 
Creating with a default 'None' preset: {TB_WORKFLOW_PRESETS_FILE}", "INFO") + with open(TB_WORKFLOW_PRESETS_FILE, 'w') as f: + json.dump(default_workflow_map, f, indent=4) + tb_workflow_presets_data = default_workflow_map + else: + with open(TB_WORKFLOW_PRESETS_FILE, 'r') as f: + tb_workflow_presets_data = json.load(f) + # Ensure "None" preset always exists and is up-to-date + tb_workflow_presets_data["None"] = default_workflow_map["None"] + tb_message_mgr.add_message(f"Workflow presets loaded from {TB_WORKFLOW_PRESETS_FILE}.", "INFO") + except Exception as e: + tb_message_mgr.add_error(f"Error with workflow presets file {TB_WORKFLOW_PRESETS_FILE}: {e}. Using in-memory defaults.") + tb_workflow_presets_data = default_workflow_map + +_initialize_workflow_presets() # Call once when the script module is loaded + +def tb_handle_apply_filters(video_path, brightness, contrast, saturation, temperature, + sharpen, blur, denoise, vignette, + s_curve_contrast, film_grain_strength, + progress=gr.Progress()): + tb_message_mgr.clear() + output_video = tb_processor.tb_apply_filters(video_path, brightness, contrast, saturation, temperature, + sharpen, blur, denoise, vignette, + s_curve_contrast, film_grain_strength, progress) + return output_video, tb_update_messages() + +def tb_handle_reassemble_frames( + frames_source_folder, + output_fps, + output_video_name, + progress=gr.Progress() +): + tb_message_mgr.clear() + tb_message_mgr.add_message("Preparing to reassemble from Frames Studio...") + if not frames_source_folder: + tb_message_mgr.add_warning("No source folder selected in the Frames Studio dropdown.") + return None, tb_update_messages() + + frames_path_to_use = os.path.join(tb_processor.extracted_frames_target_path, frames_source_folder) + source_description = f"Frames Studio folder '{frames_source_folder}'" + + if not os.path.isdir(frames_path_to_use): + tb_message_mgr.add_error(f"Selected folder not found: {frames_path_to_use}") + return None, tb_update_messages() + + tb_message_mgr.add_message(f"Attempting to reassemble frames from {source_description}.") + output_video = tb_processor.tb_reassemble_frames_to_video( + frames_path_to_use, + output_fps, + output_base_name_override=output_video_name, + progress=progress + ) + return output_video, tb_update_messages() + +def tb_handle_extract_frames(video_path, extraction_rate, progress=gr.Progress()): + tb_message_mgr.clear() + tb_processor.tb_extract_frames(video_path, int(extraction_rate), progress) + return tb_update_messages() + +def tb_handle_refresh_extracted_folders(): + folders = tb_processor.tb_get_extracted_frame_folders() + clear_btn_update = gr.update(interactive=False) + # When refreshing, clear the gallery and info box + return gr.update(choices=folders, value=None), tb_update_messages(), clear_btn_update, None, "Select a folder and click 'Load'." + +def tb_handle_clear_selected_folder(selected_folder_to_delete): + tb_message_mgr.clear() + if not selected_folder_to_delete: + tb_message_mgr.add_warning("No folder selected from the dropdown to delete.") + return tb_update_messages(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update() + + success = tb_processor.tb_delete_extracted_frames_folder(selected_folder_to_delete) + updated_folders = tb_processor.tb_get_extracted_frame_folders() + + # Return updates for all components: messages, dropdown, gallery, info box, and the two frame action buttons. 
+ return ( + tb_update_messages(), + gr.update(choices=updated_folders, value=None), # Update dropdown + None, # Clear the gallery + None, # Clear the info box + gr.update(interactive=False), # Disable save button + gr.update(interactive=False) # Disable delete button + ) + +def tb_handle_load_frames_to_studio(selected_folder): + tb_message_mgr.clear() + if not selected_folder: + tb_message_mgr.add_warning("No folder selected to load into the studio.") + return tb_update_messages(), None, "Select a folder and click 'Load'." + + frame_files_list = tb_processor.tb_get_frames_from_folder(selected_folder) + + if not frame_files_list: + tb_message_mgr.add_warning(f"No image files found in '{selected_folder}'.") + return tb_update_messages(), None, "No frames found in this folder." + + tb_message_mgr.add_success(f"Loaded {len(frame_files_list)} frames from '{selected_folder}' into the studio.") + return tb_update_messages(), frame_files_list, "Select a frame from the gallery." + +def tb_handle_frame_select(evt: gr.SelectData): + """Handles frame selection in the gallery, providing more detailed info.""" + # The gallery's select event (evt.value) is a dictionary like: + # {'image': {'path': '...', 'url': '...'}, 'caption': None} + if evt.value and 'image' in evt.value and 'path' in evt.value['image']: + # CORRECT WAY to access the path string + selected_image_path = evt.value['image']['path'] + + # Now 'selected_image_path' is a string, so os.path.basename will work correctly. + filename = os.path.basename(selected_image_path) + + info_text = f"File: {filename}" + try: + # Add image dimensions to the info box for better context + img = imageio.imread(selected_image_path) + h, w, *_ = img.shape + info_text += f"\nDimensions: {w}x{h}" + except Exception as e: + tb_message_mgr.add_warning(f"Could not read dimensions for {filename}: {e}") + + # 1. Return the info string to the gr.Textbox component. + # 2 & 3. Return updates to enable the buttons. + return info_text, gr.update(interactive=True), gr.update(interactive=True) + else: + # This part handles deselection or malformed event data + return "Select a frame...", gr.update(interactive=False), gr.update(interactive=False) + +def _get_frame_path_from_ui(selected_folder, frame_info_str): + """Helper to safely parse UI components to get a full file path.""" + if not selected_folder or not frame_info_str: + return None, "Missing folder or frame selection." + + # Extract filename from the first line of the info string "File: frame_0000.png" + first_line = frame_info_str.splitlines()[0] + if not first_line.startswith("File: "): + return None, "Invalid frame info format." + + filename_to_process = first_line.replace("File: ", "").strip() + full_path = os.path.join(tb_processor.extracted_frames_target_path, selected_folder, filename_to_process) + return full_path, None + +def tb_handle_delete_and_refresh_gallery(selected_folder, frame_info_str): + """ + Deletes the selected frame, gets the updated frame list, and explicitly + determines the next frame to select to create a seamless workflow. 
+ """ + if not selected_folder: + tb_message_mgr.add_warning("Cannot delete frame: No folder selected.") + return gr.update(), gr.update(), gr.update(), gr.update(), tb_update_messages() + + full_path_to_delete, error = _get_frame_path_from_ui(selected_folder, frame_info_str) + if error: + tb_message_mgr.add_error(f"Could not identify frame to delete: {error}") + return gr.update(), gr.update(), gr.update(), gr.update(), tb_update_messages() + + old_frame_list = tb_processor.tb_get_frames_from_folder(selected_folder) + try: + deleted_index = old_frame_list.index(full_path_to_delete) + except ValueError: + tb_message_mgr.add_error(f"Consistency error: Frame '{os.path.basename(full_path_to_delete)}' not found in its folder's frame list before deletion.") + # As a fallback, just delete and refresh to a safe state + tb_processor.tb_delete_single_frame(full_path_to_delete) + updated_frame_list_fallback = tb_processor.tb_get_frames_from_folder(selected_folder) + return updated_frame_list_fallback, "Select a frame...", gr.update(interactive=False), gr.update(interactive=False), tb_update_messages() + + tb_processor.tb_delete_single_frame(full_path_to_delete) # This logs the success message + new_frame_list = tb_processor.tb_get_frames_from_folder(selected_folder) + + if not new_frame_list: + # The folder is now empty + info_text = "All frames have been deleted." + return [], info_text, gr.update(interactive=False), gr.update(interactive=False), tb_update_messages() + else: + # Determine the index of the next frame to highlight. + # This keeps the selection at the same position, or on the new last item if the old last item was deleted. + next_selection_index = min(deleted_index, len(new_frame_list) - 1) + + # Explicitly generate the info for the frame that will now be selected. + new_selected_path = new_frame_list[next_selection_index] + filename = os.path.basename(new_selected_path) + info_text = f"File: {filename}" + try: + img = imageio.imread(new_selected_path) + h, w, *_ = img.shape + info_text += f"\nDimensions: {w}x{h}" + except Exception as e: + tb_message_mgr.add_warning(f"Could not read dimensions for new selection {filename}: {e}") + + # Return all the new state information to update the UI in one go. + return new_frame_list, info_text, gr.update(interactive=True), gr.update(interactive=True), tb_update_messages() + +def tb_handle_save_selected_frame(selected_folder, frame_info_str): + """Handler for the save button. Calls the backend processor.""" + tb_message_mgr.clear() + full_path, error = _get_frame_path_from_ui(selected_folder, frame_info_str) + if error: + tb_message_mgr.add_error(error) + return error, tb_update_messages() + + status_message = tb_processor.tb_save_single_frame(full_path) + # The message is generated inside the processor, we just need to return it. 
+    return status_message, tb_update_messages()
+
+# --- END: Frames Studio Handlers ---
+
+# --- START: Workflow Preset Handlers ---
+
+def tb_handle_save_workflow_preset(preset_name, active_steps, *params):
+    global tb_workflow_presets_data
+    tb_message_mgr.clear()
+
+    if not preset_name or not preset_name.strip():
+        tb_message_mgr.add_warning("Workflow Preset name cannot be empty.")
+        return gr.update(), gr.update(value=preset_name), tb_update_messages()
+
+    clean_name = preset_name.strip()
+    # This key list MUST match the order of components in _ALL_PIPELINE_PARAMS_COMPONENTS_
+    param_keys = [
+        # Upscale
+        "upscale_model", "upscale_factor", "tile_size",
+        "enhance_face", "denoise_strength", "upscale_use_streaming",
+        # Frame Adjust
+        "fps_mode", "speed_factor", "frames_use_streaming",
+        # Loop
+        "loop_type", "num_loops",
+        # Filters (using the ordered keys from the constant)
+        *list(TB_DEFAULT_FILTER_SETTINGS.keys()),
+        # Export
+        "export_format", "export_quality", "export_max_width"
+    ]
+
+    # Pack the parameters into a dictionary
+    params_dict = dict(zip(param_keys, params))
+
+    new_preset_data = {
+        "active_steps": active_steps,
+        "params": params_dict
+    }
+
+    preset_existed = clean_name in tb_workflow_presets_data
+    tb_workflow_presets_data[clean_name] = new_preset_data
+
+    try:
+        with open(TB_WORKFLOW_PRESETS_FILE, 'w') as f:
+            json.dump(tb_workflow_presets_data, f, indent=4)
+
+        tb_message_mgr.add_success(f"Workflow Preset '{clean_name}' {'updated' if preset_existed else 'saved'} successfully!")
+
+        # Update dropdown choices
+        updated_choices = sorted([k for k in tb_workflow_presets_data.keys() if k != "None"])
+        updated_choices.insert(0, "None")
+
+        return gr.update(choices=updated_choices, value=clean_name), "", tb_update_messages()
+
+    except Exception as e:
+        tb_message_mgr.add_error(f"Error saving workflow preset '{clean_name}': {e}")
+        # Revert to last known good state
+        _initialize_workflow_presets()
+        return gr.update(), gr.update(value=preset_name), tb_update_messages()
+
+def tb_handle_load_workflow_preset(preset_name):
+    tb_message_mgr.clear()
+    preset_data = tb_workflow_presets_data.get(preset_name)
+
+    if not preset_data:
+        tb_message_mgr.add_warning(f"Workflow preset '{preset_name}' not found. 
Loading 'None' state.") + preset_data = tb_workflow_presets_data.get("None") + + # Get the default parameter structure to ensure all keys are present + final_params = _get_default_workflow_params() + # Update with the loaded preset's parameters + final_params.update(preset_data.get("params", {})) + + active_steps = preset_data.get("active_steps", []) + + # The order of values returned MUST match the order of components in the event handler's output list + ordered_values = [ + # Checkbox + active_steps, + # Upscale + final_params["upscale_model"], final_params["upscale_factor"], final_params["tile_size"], + final_params["enhance_face"], final_params["denoise_strength"], + final_params["upscale_use_streaming"], + # Frame Adjust + final_params["fps_mode"], final_params["speed_factor"], + final_params["frames_use_streaming"], + # Loop + final_params["loop_type"], final_params["num_loops"], + # Filters (must be in the same order as _ORDERED_FILTER_SLIDERS_) + final_params["brightness"], final_params["contrast"], final_params["saturation"], + final_params["temperature"], final_params["sharpen"], final_params["blur"], + final_params["denoise"], final_params["vignette"], final_params["s_curve_contrast"], + final_params["film_grain_strength"], + # Export + final_params["export_format"], final_params["export_quality"], final_params["export_max_width"] + ] + + tb_message_mgr.add_message(f"Loaded workflow preset: '{preset_name}'") + # Also return the preset name to the input box, and the updated messages + return preset_name, *ordered_values, tb_update_messages() + +def tb_handle_delete_workflow_preset(preset_name): + global tb_workflow_presets_data + tb_message_mgr.clear() + + if not preset_name or not preset_name.strip(): + tb_message_mgr.add_warning("No workflow preset name provided to delete.") + # The number of outputs for an event handler MUST be consistent. + # There are 28 outputs: dropdown, namebox, chkbox, 24 params, message. + # The star-expansion covers the chkbox (1) + params (24) = 25 components. 
+ return gr.update(), gr.update(), *([gr.update()] * 25), tb_update_messages() + + + clean_name = preset_name.strip() + if clean_name == "None": + tb_message_mgr.add_warning("'None' preset cannot be deleted.") + return gr.update(value="None"), gr.update(), *([gr.update()] * 25), tb_update_messages() + + if clean_name not in tb_workflow_presets_data: + tb_message_mgr.add_warning(f"Workflow preset '{clean_name}' not found.") + return gr.update(), gr.update(), *([gr.update()] * 25), tb_update_messages() + + del tb_workflow_presets_data[clean_name] + + try: + with open(TB_WORKFLOW_PRESETS_FILE, 'w') as f: + json.dump(tb_workflow_presets_data, f, indent=4) + tb_message_mgr.add_success(f"Workflow preset '{clean_name}' deleted.") + + updated_choices = sorted([k for k in tb_workflow_presets_data.keys() if k != "None"]) + updated_choices.insert(0, "None") + + # After deleting, load the "None" state to get the reset values + none_state_outputs = tb_handle_load_workflow_preset("None") + + # The rest of the values come from the 'load' function, but we skip its first value (which was also for the textbox) + return gr.update(choices=updated_choices, value="None"), "", *none_state_outputs[1:] + + except Exception as e: + tb_message_mgr.add_error(f"Error deleting workflow preset '{clean_name}': {e}") + _initialize_workflow_presets() # Revert + # On error, we don't know the state, so just update the messages + return gr.update(), gr.update(value=clean_name), *([gr.update()] * 25), tb_update_messages() + +def tb_handle_reset_workflow_to_defaults(): + # This function loads the 'None' preset to get the reset values for most components... + load_outputs = tb_handle_load_workflow_preset("None") + # ...then it PREPENDS an update specifically for the dropdown menu. + # The first value is for the dropdown, the rest are for the components in _WORKFLOW_LOAD_OUTPUTS_ + return gr.update(value="None"), *load_outputs + +# --- END: New Workflow Preset Handlers --- + + +def tb_handle_start_pipeline( + # 1. Active Tab Index + active_tab_index, + # 2. 
Selected Operations + selected_ops, + # Inputs + single_video_path, batch_video_paths, + # Upscale + model_key, output_scale_factor, tile_size, enhance_face, denoise_strength, upscale_use_streaming, + # Frame Adjust + fps_mode, speed_factor, frames_use_streaming, + # Loop + loop_type, num_loops, + # Filters + brightness, contrast, saturation, temperature, sharpen, blur, denoise, vignette, s_curve_contrast, film_grain_strength, + # Export + export_format, export_quality, export_max_width, + progress=gr.Progress() +): + tb_message_mgr.clear() + input_paths_to_process = [] + + if active_tab_index == 1 and batch_video_paths and len(batch_video_paths) > 0: + # Process batch only if the batch tab is active and it has files + input_paths_to_process = batch_video_paths + tb_message_mgr.add_message(f"Starting BATCH pipeline for {len(input_paths_to_process)} videos (from active Batch tab).") + elif active_tab_index == 0 and single_video_path: + # Process single video only if the single tab is active and it has a video + input_paths_to_process = [single_video_path] + tb_message_mgr.add_message(f"Starting SINGLE video pipeline for {os.path.basename(single_video_path)} (from active Single tab).") + else: + # Handle cases where the active tab is empty + if active_tab_index == 1: + tb_message_mgr.add_warning("Batch Input tab is active, but no files were provided.") + else: # active_tab_index == 0 or default + tb_message_mgr.add_warning("Single Video Input tab is active, but no video was provided.") + return None, tb_update_messages() + + if not selected_ops: + tb_message_mgr.add_warning("No operations selected for the pipeline. Please check at least one box in 'Pipeline Steps'.") + return None, tb_update_messages() + + # Map checkbox labels to operation keys + op_map = { + "upscale": "upscale", + "frames": "frame_adjust", + "loop": "loop", + "filters": "filters", + "export": "export" + } + + # Define the execution order + execution_order = ["upscale", "frame_adjust", "filters", "loop", "export"] + + pipeline_config = {"operations": []} + + # Build the pipeline configuration based on user selections in the correct order + for op_key in execution_order: + # Find the display name from the op_map that corresponds to the current key + display_name = next((d_name for d_name, k_name in op_map.items() if k_name == op_key), None) + + if display_name and display_name in selected_ops: + if op_key == "upscale": + pipeline_config["operations"].append({ + "name": "upscale", + "params": { + "model_key": model_key, + "output_scale_factor_ui": float(output_scale_factor), + "tile_size": int(tile_size), + "enhance_face": enhance_face, + "denoise_strength_ui": denoise_strength, + "use_streaming": upscale_use_streaming + } + }) + elif op_key == "frame_adjust": + pipeline_config["operations"].append({ + "name": "frame_adjust", + "params": { + "target_fps_mode": fps_mode, + "speed_factor": speed_factor, + "use_streaming": frames_use_streaming + } + }) + elif op_key == "loop": + pipeline_config["operations"].append({ + "name": "loop", + "params": { "loop_type": loop_type, "num_loops": num_loops } + }) + elif op_key == "filters": + pipeline_config["operations"].append({ + "name": "filters", + "params": { + "brightness": brightness, "contrast": contrast, "saturation": saturation, "temperature": temperature, + "sharpen": sharpen, "blur": blur, "denoise": denoise, "vignette": vignette, + "s_curve_contrast": s_curve_contrast, "film_grain_strength": film_grain_strength + } + }) + elif op_key == "export": + 
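+                # Export runs last in execution_order, so it packages the fully processed clip.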
+                pipeline_config["operations"].append({
+                    "name": "export",
+                    "params": {
+                        "export_format": export_format,
+                        "quality_slider": int(export_quality),
+                        "max_width": int(export_max_width)
+                    }
+                })
+
+    # Call the batch processor, which now handles both single and batch jobs
+    final_video_path = tb_processor.tb_process_video_batch(input_paths_to_process, pipeline_config, progress)
+
+    # Return the final video path to the player (will be None for batch, path for single)
+    return final_video_path, tb_update_messages()
+
+def tb_update_active_tab_index(evt: gr.SelectData):
+    # The tab-select event only feeds the hidden index storage, so it returns a single value.
+    if not evt:
+        return 0  # Default to the first tab (Single Video) if event data is missing
+    return evt.index
+
+def tb_handle_upscale_video(video_path, model_key_selected, output_scale_factor_from_slider, tile_size, enhance_face_ui, denoise_strength_from_slider, use_streaming, progress=gr.Progress()):
+    tb_message_mgr.clear()
+    if video_path is None:
+        tb_message_mgr.add_warning("No input video selected for upscaling.")
+        return None, tb_update_messages()
+    if not model_key_selected:
+        tb_message_mgr.add_warning("No upscale model selected.")
+        return None, tb_update_messages()
+
+    try:
+        tile_size_int = int(tile_size)
+    except ValueError:
+        tb_message_mgr.add_error(f"Invalid tile size value: {tile_size}. Using None (0).")
+        tile_size_int = 0
+
+    try:
+        output_scale_factor_float = float(output_scale_factor_from_slider)
+        if not (output_scale_factor_float >= 0.25):
+            tb_message_mgr.add_error(f"Invalid output scale factor: {output_scale_factor_float:.2f}. Must be >= 0.25.")
+            return None, tb_update_messages()
+    except ValueError:
+        tb_message_mgr.add_error(f"Invalid output scale factor: {output_scale_factor_from_slider}. Not a valid number.")
+        return None, tb_update_messages()
+
+    # Pass the validated values so a bad slider/radio value cannot slip past the checks above.
+    output_video = tb_processor.tb_upscale_video(
+        video_path,
+        model_key_selected,
+        output_scale_factor_float,
+        tile_size_int,
+        enhance_face_ui,
+        denoise_strength_from_slider,
+        use_streaming,
+        progress=progress
+    )
+    return output_video, tb_update_messages()
+
+def tb_get_model_info_and_update_scale_slider(model_key_selected: str):
+    native_scale = 2.0
+    slider_min = 1.0
+    slider_max = 2.0
+    slider_step = 0.05
+    slider_default_value = 2.0
+    model_info_text = "Info: Select a model."
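+    # Fallback defaults; they are overridden below when a supported model is selected.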
+ slider_label = "Target Upscale Factor" + + denoise_slider_visible = False + denoise_slider_value = 0.5 + + if model_key_selected and model_key_selected in tb_processor.esrgan_upscaler.supported_models: + model_details = tb_processor.esrgan_upscaler.supported_models[model_key_selected] + fetched_native_scale = model_details.get('scale') + description = model_details.get('description', 'No description available.') + + if isinstance(fetched_native_scale, (int, float)) and fetched_native_scale > 0: + native_scale = float(fetched_native_scale) + slider_max = native_scale + slider_default_value = native_scale + slider_min = 1.0 + + if native_scale >= 4.0: slider_step = 0.1 + elif native_scale >= 2.0: slider_step = 0.05 + + model_info_text = f"{description}" + slider_label = f"Target Upscale Factor (Native {native_scale}x)" + + if model_key_selected == "RealESR-general-x4v3": + denoise_slider_visible = True + + model_info_update = gr.update(value=model_info_text) + outscale_slider_update = gr.update( + minimum=slider_min, maximum=slider_max, step=slider_step, + value=slider_default_value, label=slider_label + ) + denoise_slider_update = gr.update( + visible=denoise_slider_visible, value=denoise_slider_value + ) + + return model_info_update, outscale_slider_update, denoise_slider_update + +def tb_get_selected_model_scale_info(model_key_selected): + if model_key_selected and model_key_selected in tb_processor.esrgan_upscaler.supported_models: + model_details = tb_processor.esrgan_upscaler.supported_models[model_key_selected] + scale = model_details.get('N/A') + description = model_details.get('description', 'No description available.') + return f"{description}" + return "Info: Select a model." + +def tb_handle_delete_studio_transformer(): + tb_message_mgr.clear() + tb_message_mgr.add_message("Attempting to directly access and delete Studio transformer...") + print("Attempting to directly access and delete Studio transformer...") + log_messages_from_action = [] + + studio_module_instance = None + if '__main__' in sys.modules and hasattr(sys.modules['__main__'], 'current_generator'): + studio_module_instance = sys.modules['__main__'] + print("Found studio context in __main__.") + elif 'studio' in sys.modules and hasattr(sys.modules['studio'], 'current_generator'): + studio_module_instance = sys.modules['studio'] + print("Found studio context in sys.modules['studio'].") + + if studio_module_instance is None: + print("ERROR: Could not find the 'studio' module's active context.") + tb_message_mgr.add_message("ERROR: Could not find the 'studio' module's active context in sys.modules.") + tb_message_mgr.add_error("Deletion Failed: Studio module context not found.") + return tb_update_messages() + + job_queue_instance = getattr(studio_module_instance, 'job_queue', None) + JobStatus_enum = getattr(studio_module_instance, 'JobStatus', None) + + if job_queue_instance and JobStatus_enum: + current_job_in_queue = getattr(job_queue_instance, 'current_job', None) + if current_job_in_queue and hasattr(current_job_in_queue, 'status') and current_job_in_queue.status == JobStatus_enum.RUNNING: + tb_message_mgr.add_warning("Cannot unload model: A video generation job is currently running.") + tb_message_mgr.add_message("Please wait for the current job to complete or cancel it first using the main interface.") + print("Cannot unload model: A job is currently running in the queue.") + return tb_update_messages() + + generator_object_to_delete = getattr(studio_module_instance, 'current_generator', None) + 
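+    # Hold a local reference to the active generator so it can be torn down and freed below.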
print(f"Direct access: generator_object_to_delete is {type(generator_object_to_delete)}, id: {id(generator_object_to_delete)}") + + if generator_object_to_delete is not None: + model_name_str = "Unknown Model" + try: + if hasattr(generator_object_to_delete, 'get_model_name') and callable(generator_object_to_delete.get_model_name): + model_name_str = generator_object_to_delete.get_model_name() + elif hasattr(generator_object_to_delete, 'transformer') and generator_object_to_delete.transformer is not None: + model_name_str = generator_object_to_delete.transformer.__class__.__name__ + else: + model_name_str = generator_object_to_delete.__class__.__name__ + except Exception: + pass + + tb_message_mgr.add_message(f" Deletion of '{model_name_str}' initiated.") + log_messages_from_action.append(f" Found active generator: {model_name_str}. Preparing for deletion.") + print(f"Found active generator: {model_name_str}. Preparing for deletion.") + + try: + if hasattr(generator_object_to_delete, 'unload_loras') and callable(generator_object_to_delete.unload_loras): + print(" - LoRAs: Unloading from transformer...") + generator_object_to_delete.unload_loras() + else: + log_messages_from_action.append(" - LoRAs: No unload method found or not applicable.") + + if hasattr(generator_object_to_delete, 'transformer') and generator_object_to_delete.transformer is not None: + transformer_object_ref = generator_object_to_delete.transformer + transformer_name_for_log = transformer_object_ref.__class__.__name__ + print(f" - Transformer ({transformer_name_for_log}): Preparing for memory operations.") + + if hasattr(transformer_object_ref, 'device') and transformer_object_ref.device != cpu: + if hasattr(transformer_object_ref, 'to') and callable(transformer_object_ref.to): + try: + print(f" - Transformer ({transformer_name_for_log}): Moving to CPU...") + transformer_object_ref.to(cpu) + log_messages_from_action.append(" - Transformer moved to CPU.") + print(f" - Transformer ({transformer_name_for_log}): Moved to CPU.") + except Exception as e_cpu: + error_msg_cpu = f" - Transformer ({transformer_name_for_log}): Move to CPU FAILED: {e_cpu}" + log_messages_from_action.append(error_msg_cpu) + print(error_msg_cpu) + else: + log_messages_from_action.append(f" - Transformer ({transformer_name_for_log}): Cannot move to CPU, 'to' method not found.") + print(f" - Transformer ({transformer_name_for_log}): Cannot move to CPU, 'to' method not found.") + elif hasattr(transformer_object_ref, 'device') and transformer_object_ref.device == cpu: + log_messages_from_action.append(" - Transformer already on CPU.") + print(f" - Transformer ({transformer_name_for_log}): Already on CPU.") + else: + log_messages_from_action.append(" - Transformer: Could not determine device or move to CPU.") + print(f" - Transformer ({transformer_name_for_log}): Could not determine device or move to CPU.") + + print(f" - Transformer ({transformer_name_for_log}): Removing attribute from generator...") + generator_object_to_delete.transformer = None + print(f" - Transformer ({transformer_name_for_log}): Deleting Python reference...") + del transformer_object_ref + log_messages_from_action.append(" - Transformer reference deleted.") + print(f" - Transformer ({transformer_name_for_log}): Reference deleted.") + else: + log_messages_from_action.append(" - Transformer: Not found or already unloaded.") + print(" - Transformer: Not found or already unloaded.") + + generator_class_name_for_log = generator_object_to_delete.__class__.__name__ + print(f" - Model 
Generator ({generator_class_name_for_log}): Setting global reference to None...") + setattr(studio_module_instance, 'current_generator', None) + log_messages_from_action.append(" - 'current_generator' in studio module set to None.") + print(" - Global 'current_generator' in studio module successfully set to None.") + + print(f" - Model Generator ({generator_class_name_for_log}): Deleting local Python reference...") + del generator_object_to_delete + print(f" - Model Generator ({generator_class_name_for_log}): Python reference deleted.") + + print(" - System: Performing garbage collection and CUDA cache clearing.") + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + log_messages_from_action.append(" - GC and CUDA cache cleared.") + print(" - System: GC and CUDA cache clear completed.") + + log_messages_from_action.append(f"✅ Deletion of '{model_name_str}' completed successfully from toolbox.") + tb_message_mgr.add_success(f"Deletion of '{model_name_str}' initiated from toolbox.") + + except Exception as e_del: + error_msg_del = f"Error during deletion process: {e_del}" + log_messages_from_action.append(f" - {error_msg_del}") + print(f" - {error_msg_del}") + traceback.print_exc() + tb_message_mgr.add_error(f"Deletion Error: {e_del}") + else: + tb_message_mgr.add_message("ℹ️ No active generator found. Nothing to delete.") + print("No active generator found via direct access. Nothing to delete.") + + for msg_item in log_messages_from_action: + tb_message_mgr.add_message(msg_item) + + return tb_update_messages() + +def tb_handle_manually_save_video(temp_video_path_from_component): + tb_message_mgr.clear() + if not temp_video_path_from_component: + tb_message_mgr.add_warning("No video in the output player to save.") + return temp_video_path_from_component, tb_update_messages() + + copied_path = tb_processor.tb_copy_video_to_permanent_storage(temp_video_path_from_component) + + if copied_path and os.path.abspath(copied_path) != os.path.abspath(temp_video_path_from_component): + tb_message_mgr.add_success(f"Video successfully copied to permanent storage.") + + return temp_video_path_from_component, tb_update_messages() + +def tb_handle_clear_temp_files(): + tb_message_mgr.clear() + + # The processor function now returns a complete summary string. + cleanup_summary = tb_processor.tb_clear_temporary_files() + + # Add the summary to the message manager to be displayed in the console. + tb_message_mgr.add_message(cleanup_summary) + + # Return None to clear the video player and the updated messages. + return None, tb_update_messages() + +def tb_handle_use_processed_as_input(processed_video_path): + if not processed_video_path: + tb_message_mgr.add_warning("No processed video available to use as input.") + # Return updates for all 3 outputs, changing nothing. + return gr.update(), tb_update_messages(), gr.update() + else: + tb_message_mgr.add_message("Moved processed video to input.") + # Return new value for input, messages, and None to clear analysis. 
+ return processed_video_path, tb_update_messages(), None + +def tb_handle_join_videos(video_files_list, custom_output_name, progress=gr.Progress()): # Add new parameter + tb_message_mgr.clear() + + if not video_files_list: + tb_message_mgr.add_warning("No video files were uploaded to join.") + return None, tb_update_messages() + + video_paths = [file.name for file in video_files_list] + + # Pass the custom name to the processor + output_video = tb_processor.tb_join_videos(video_paths, custom_output_name, progress) + return output_video, tb_update_messages() + +def tb_handle_export_video(video_path, export_format, quality, max_width, custom_name, progress=gr.Progress()): + tb_message_mgr.clear() + if not video_path: + tb_message_mgr.add_warning("No input video in the top-left player to export.") + # Return None for the video player and the message update + return None, tb_update_messages() + + # The input video for this operation is ALWAYS the one in the main input player. + output_file = tb_processor.tb_export_video( + video_path, + export_format, + quality, + max_width, + custom_name, + progress + ) + + # Return the path to the new video and the updated messages. + return output_file, tb_update_messages() + +def tb_get_formatted_toolbar_stats(): + vram_full_str = "VRAM: N/A" + gpu_full_str = "GPU: N/A" + ram_full_str = "RAM: N/A" + + vram_component_visible = False + gpu_component_visible = False + + try: + ram_info_psutil = psutil.virtual_memory() + ram_used_gb = ram_info_psutil.used / (1024**3) + ram_total_gb = ram_info_psutil.total / (1024**3) + ram_full_str = f"RAM: {ram_used_gb:.1f}/{round(ram_total_gb)}GB ({round(ram_info_psutil.percent)}%)" + + if torch.cuda.is_available(): + _, nvidia_metrics, _ = SystemMonitor.get_nvidia_gpu_info() + if nvidia_metrics: + vram_used = nvidia_metrics.get('memory_used_gb', 0.0) + vram_total = nvidia_metrics.get('memory_total_gb', 0.0) + vram_full_str = f"VRAM: {vram_used:.1f}/{round(vram_total)}GB" + vram_component_visible = True + + temp = nvidia_metrics.get('temperature', 0.0) + load = nvidia_metrics.get('utilization', 0.0) + gpu_full_str = f"GPU: {temp:.0f}°C {load:.0f}%" + gpu_component_visible = True + + except Exception as e: + print(f"Error getting system stats values for toolbar (from toolbox_app.py): {e}") + ram_full_str = "RAM: Error" + is_nvidia_expected = torch.cuda.is_available() + if is_nvidia_expected: + vram_full_str = "VRAM: Error" + gpu_full_str = "GPU: Error" + vram_component_visible = True + gpu_component_visible = True + else: + vram_full_str = "VRAM: N/A" + gpu_full_str = "GPU: N/A" + vram_component_visible = False + gpu_component_visible = False + + return ( + gr.update(value=ram_full_str), + gr.update(value=vram_full_str, visible=vram_component_visible), + gr.update(value=gpu_full_str, visible=gpu_component_visible) + ) + +# --- Gradio Interface --- + +def tb_create_video_toolbox_ui(): + initial_autosave_state = settings_instance.get("toolbox_autosave_enabled", True) + tb_processor.set_autosave_mode(initial_autosave_state) + + with gr.Column() as tb_toolbox_ui_main_container: + with gr.Row(): + with gr.Column(scale=1): + # Replace gr.State with a hidden gr.Number for robust state passing + tb_active_tab_index_storage = gr.Number(value=0, visible=False) + + with gr.Tabs(elem_id="toolbox_input_tabs") as tb_input_tabs: + with gr.TabItem("Single Video Input", id=0): + tb_input_video_component = gr.Video( + label="Upload Video for processing", + autoplay=True, + elem_classes="video-size", + elem_id="toolbox-video-player" + ) + with 
gr.TabItem("Batch Video Input", id=1): + tb_batch_input_files = gr.File( + label="Upload Multiple Videos for Batch Processing", + file_count="multiple", + type="filepath" + ) + tb_start_pipeline_btn = gr.Button("🚀 Start Pipeline Processing", variant="primary", size="sm", elem_id="toolbox-start-pipeline-btn") + + with gr.Column(scale=1): + with gr.Tabs(elem_id="toolbox_output_tabs"): + with gr.TabItem("Video Output"): + tb_processed_video_output = gr.Video( + label="Processed Video", + autoplay=True, + interactive=False, + elem_classes="video-size" + ) + with gr.Row(): + tb_use_processed_as_input_btn = gr.Button("Use as Input", size="sm", scale=4) + tb_manual_save_btn = gr.Button("Manual Save", variant="secondary", size="sm", scale=4, visible=not initial_autosave_state) + + with gr.Row(): + with gr.Column(scale=1): + with gr.Accordion("Processing Pipeline", open=True): + gr.Markdown("Required for batch processing and recommended for single video. ", elem_classes="small-text-info") + with gr.Row(equal_height=False): + with gr.Group(): + tb_pipeline_steps_chkbox = gr.CheckboxGroup( + label="Pipeline Steps:", + choices=["upscale", "frames", "filters", "loop", "export"], + value=[], + info="Select which pre-configured operations to run. Executed in order." + ) + + # --- Right Column: Workflow Presets --- + with gr.Column(scale=1): + with gr.Accordion("Workflow Presets", open=True): + gr.Markdown("Save/load all operation settings and active steps.", elem_classes="small-text-info") + with gr.Row(): + workflow_choices = sorted([k for k in tb_workflow_presets_data.keys() if k != "None"]) + workflow_choices.insert(0, "None") + with gr.Column(scale=1): + tb_workflow_preset_select = gr.Dropdown( + choices=workflow_choices, value="None", label="Load Workflow" + ) + with gr.Column(scale=1): + tb_workflow_preset_name_input = gr.Textbox( + label="Preset Name (for saving)", placeholder="e.g., My Favorite Upscale" + ) + with gr.Group(): + with gr.Row(): + tb_workflow_save_btn = gr.Button("💾 Save/Update", size="sm", variant="primary") + tb_workflow_delete_btn = gr.Button("🗑️ Delete", size="sm", variant="stop") + tb_workflow_reset_btn = gr.Button("🔄 Reset All to Defaults", size="sm") + + with gr.Row(): + with gr.Column(): + with gr.Group(): + tb_analyze_button = gr.Button("Click to Analyze Input Video", size="sm", variant="huggingface") + with gr.Accordion("Video Analysis Results", open=False) as tb_analysis_accordion: + tb_video_analysis_output = gr.Textbox( + container=False, lines=10, show_label=False, + interactive=False, elem_classes="analysis-box", + ) + + with gr.Column(): + with gr.Group(): + with gr.Row(): + tb_monitor_toggle_checkbox = gr.Checkbox(label="Live System Monitoring", value=False) + tb_autosave_checkbox = gr.Checkbox(label="Autosave", value=initial_autosave_state) + tb_resource_monitor_output = gr.Textbox( + show_label=False, container=False, max_lines=8, + interactive=False, visible=False, + ) + with gr.Row(): + tb_delete_studio_transformer_btn = gr.Button("Click to Unload Studio Model", size="sm", scale=3, variant="stop") + + with gr.Accordion("Operations", open=True): + with gr.Tabs(): + with gr.TabItem("📈 Upscale Video (ESRGAN)"): + with gr.Row(): + gr.Markdown("Upscale video resolution using Real-ESRGAN.") + with gr.Row(): + with gr.Column(scale=2): + tb_upscale_model_select = gr.Dropdown( + choices=list(tb_processor.esrgan_upscaler.supported_models.keys()), + value=list(tb_processor.esrgan_upscaler.supported_models.keys())[0] if tb_processor.esrgan_upscaler.supported_models else 
None, + label="ESRGAN Model", + info="Select the Real-ESRGAN model." + ) + default_model_key_init = list(tb_processor.esrgan_upscaler.supported_models.keys())[0] if tb_processor.esrgan_upscaler.supported_models else None + initial_model_info_gr_val, initial_slider_gr_val, initial_denoise_gr_val = tb_get_model_info_and_update_scale_slider(default_model_key_init) + + tb_selected_model_scale_display = gr.Textbox( + label="Selected Model Info", + value=initial_model_info_gr_val.get('value', "Info: Select a model."), + interactive=False, + lines=2 + ) + + tb_upscale_factor_slider = gr.Slider( + minimum=initial_slider_gr_val.get('minimum', 1.0), + maximum=initial_slider_gr_val.get('maximum', 2.0), + step=initial_slider_gr_val.get('step', 0.05), + value=initial_slider_gr_val.get('value', 2.0), + label=initial_slider_gr_val.get('label', "Target Upscale Factor"), + info="Desired output scale (e.g., 2.0 for 2x). Video is upscaled by the model, then resized if this differs from native scale." + ) + with gr.Column(scale=2): + tb_upscale_tile_size_radio = gr.Radio( + choices=[("None (Recommended)", 0), ("512px", 512), ("256px", 256)], + value=0, label="Tile Size for Upscaling", + info="Splits video frames into tiles for processing. 'None' disables tiling. Smaller values (e.g., 512, 256) use less VRAM but are slower and can potentially show seams on some videos. Use if 'None' causes Out-Of-Memory." + ) + with gr.Row(): + tb_upscale_enhance_face_checkbox = gr.Checkbox( + label="Enhance Faces (GFPGAN)", value=False, + info="Uses GFPGAN to restore (human-like) faces. Increases processing time." + ) + with gr.Row(): + tb_upscale_use_streaming_checkbox = gr.Checkbox( + label="Use Streaming (Low Memory Mode)", value=False, + info="Enable for stable, low-memory processing of long or high-res videos. This avoids loading the entire clip into RAM, making it ideal for 4K footage or very large files." + ) + with gr.Row(): + tb_denoise_strength_slider = gr.Slider( + label="Denoise Strength (for RealESR-general-x4v3)", + minimum=0.0, maximum=1.0, step=0.01, + value=initial_denoise_gr_val.get('value', 0.5), + info="Adjusts denoising for RealESR-general-x4v3. 0.0=Max WDN, <1.0=Blend, 1.0=No WDN.", + visible=initial_denoise_gr_val.get('visible', False), + interactive=True + ) + with gr.Row(): + tb_upscale_video_btn = gr.Button("🚀 Upscale Video", variant="primary") + + with gr.TabItem("🎞️ Frame Adjust (Speed & Interpolation)"): + with gr.Row(): + gr.Markdown("Adjust video speed and interpolate frames using RIFE AI.") + with gr.Row(): + tb_process_fps_mode = gr.Radio( + choices=["No Interpolation", "2x Frames", "4x Frames"], value="No Interpolation", label="RIFE Frame Interpolation", + info="Select '2x' or '4x' RIFE Interpolation to double or quadruple the frame rate, creating smoother motion. 4x is more intensive and runs the 2x process twice." + ) + tb_frames_use_streaming_checkbox = gr.Checkbox( + label="Use Streaming (Low Memory Mode)", value=False, + info="Enable for stable, low-memory RIFE on long videos. This avoids loading all frames into RAM. Note: 'Adjust Video Speed' is ignored in this mode." + ) + with gr.Row(): + tb_process_speed_factor = gr.Slider( + minimum=0.25, maximum=4.0, step=0.05, value=1.0, label="Adjust Video Speed Factor", + info="Values < 1.0 slow down the video, values > 1.0 speed it up. Affects video and audio." 
+ ) + + tb_process_frames_btn = gr.Button("🚀 Process Frames", variant="primary") + + with gr.TabItem("🎨 Video Filters (FFmpeg)"): + with gr.Row(): + gr.Markdown("Apply visual enhancements using FFmpeg filters.") + with gr.Row(): + tb_filter_brightness = gr.Slider(-100, 100, value=TB_DEFAULT_FILTER_SETTINGS["brightness"], step=1, label="Brightness (%)", info="Adjusts overall image brightness.") + tb_filter_contrast = gr.Slider(0, 3, value=TB_DEFAULT_FILTER_SETTINGS["contrast"], step=0.05, label="Contrast (Linear)", info="Increases/decreases difference between light/dark areas.") + with gr.Row(): + tb_filter_saturation = gr.Slider(0, 3, value=TB_DEFAULT_FILTER_SETTINGS["saturation"], step=0.05, label="Saturation", info="Adjusts color intensity. 0=grayscale, 1=original.") + tb_filter_temperature = gr.Slider(-100, 100, value=TB_DEFAULT_FILTER_SETTINGS["temperature"], step=1, label="Color Temperature Adjust", info="Shifts colors towards orange (warm) or blue (cool).") + with gr.Row(): + tb_filter_sharpen = gr.Slider(0, 5, value=TB_DEFAULT_FILTER_SETTINGS["sharpen"], step=0.1, label="Sharpen Strength", info="Enhances edge details. Use sparingly.") + tb_filter_blur = gr.Slider(0, 5, value=TB_DEFAULT_FILTER_SETTINGS["blur"], step=0.1, label="Blur Strength", info="Softens the image.") + with gr.Row(): + tb_filter_denoise = gr.Slider(0, 10, value=TB_DEFAULT_FILTER_SETTINGS["denoise"], step=0.1, label="Denoise Strength", info="Reduces video noise/grain.") + tb_filter_vignette = gr.Slider(0, 100, value=TB_DEFAULT_FILTER_SETTINGS["vignette"], step=1, label="Vignette Strength (%)", info="Darkens corners, drawing focus to center.") + with gr.Row(): + tb_filter_s_curve_contrast = gr.Slider(0, 100, value=TB_DEFAULT_FILTER_SETTINGS["s_curve_contrast"], step=1, label="S-Curve Contrast", info="Non-linear contrast, boosting highlights/shadows subtly.") + tb_filter_film_grain_strength = gr.Slider(0, 50, value=TB_DEFAULT_FILTER_SETTINGS["film_grain_strength"], step=1, label="Film Grain Strength", info="Adds artificial film grain.") + + tb_apply_filters_btn = gr.Button("✨ Apply Filters to Video", variant="primary") + + with gr.Row(equal_height=False): + with gr.Column(scale=2): + with gr.Row(): + preset_choices = list(tb_filter_presets_data.keys()) if tb_filter_presets_data else ["none"] + if "none" not in preset_choices and preset_choices: + preset_choices.insert(0,"none") + elif not preset_choices: + preset_choices = ["none"] + + tb_filter_preset_select = gr.Dropdown(choices=preset_choices, value="none", label="Load Preset", scale=2) + tb_new_preset_name_input = gr.Textbox(label="Preset Name (for saving/editing)", placeholder="Select preset or type new name...", scale=2) + with gr.Column(scale=1): + with gr.Row(): + tb_save_preset_btn = gr.Button("💾 Save/Update", variant="primary", scale=1) + tb_delete_preset_btn = gr.Button("🗑️ Delete", variant="stop", scale=1) + with gr.Row(): + tb_reset_filters_btn = gr.Button("🔄 Reset All Sliders to 'None' Preset") + + with gr.TabItem("🔄 Video Loop"): + with gr.Row(): + gr.Markdown("Create looped or ping-pong versions of the video.") + + tb_loop_type_select = gr.Radio(choices=["loop", "ping-pong"], value="loop", label="Loop Type") + tb_num_loops_slider = gr.Slider( + minimum=1, maximum=10, step=1, value=1, label="Number of Loops/Repeats", + info="The video will play its original content, then repeat this many additional times. E.g., 1 loop = 2 total plays of the segment." 
+ ) + tb_create_loop_btn = gr.Button("🔁 Create Loop", variant="primary") + + with gr.TabItem("🖼️ Frames Studio"): + with gr.Column(): + gr.Markdown("### 1. Extract Frames from Video") + gr.Markdown( + "⚠️ **Warning:** Extracting frames from high-resolution (e.g., 4K+) or long videos can consume a significant amount of disk space (many gigabytes) and may cause the Frames Studio gallery to load slowly. Proceed with caution." + ) + gr.Markdown("Extract frames from the **uploaded video (top-left)** as images. These folders can then be loaded into the Frames Studio below.") + with gr.Row(): + tb_extract_rate_slider = gr.Number( + label="Extract Every Nth Frame", value=1, minimum=1, step=1, + info="1 = all frames. N = 1st, (N+1)th... (i.e., frame 0, frame N, frame 2N, etc.)", + scale=1 + ) + tb_extract_frames_btn = gr.Button("🔨 Extract Frames", variant="primary", scale=2) + + gr.Markdown("---") + + with gr.Column(): + gr.Markdown("### 2. Frames Studio") + gr.Markdown("Load an extracted frames folder to view, delete, and manage individual frames before reassembling.") + with gr.Row(): + tb_extracted_folders_dropdown = gr.Dropdown( + label="Select Extracted Folder to Load", + info="Select a folder from your 'extracted_frames' directory.", + scale=3 + ) + tb_refresh_extracted_folders_btn = gr.Button("🔄 Refresh List", scale=1) + tb_clear_selected_folder_btn = gr.Button( + "🗑️ Delete ENTIRE Folder", variant="stop", interactive=False, scale=1 + ) + tb_load_frames_to_studio_btn = gr.Button("🖼️ Load Frames to Studio", variant="secondary") + + # Redesigned Studio Area + with gr.Column(variant="panel"): + with gr.Column(elem_id="gallery-scroll-wrapper"): + tb_frames_gallery = gr.Gallery( + label="Extracted Frames", show_label=False, elem_id="toolbox_frames_gallery", + columns=8, # height is now controlled by the wrapper's CSS + object_fit="contain", preview=False + ) + with gr.Row(): + with gr.Column(scale=1, min_width=220): + tb_save_selected_frame_btn = gr.Button("💾 Save Selected Frame", size="sm", interactive=False) + tb_delete_selected_frame_btn = gr.Button("🗑️ Delete Selected Frame", size="sm", variant="stop", interactive=False) + with gr.Column(scale=3): + # This row now contains the info box and the new clear button + with gr.Row(): + tb_frame_info_box = gr.Textbox( + # label="Selected Frame Info", + interactive=False, + placeholder="Click a frame in the gallery above to select it.", + container=False, + lines=2, + scale=4 + ) + tb_clear_gallery_btn = gr.Button("🧹 Clear Gallery", size="sm", scale=1) + + gr.Markdown("---") + + with gr.Column(): + gr.Markdown("### 3. Reassemble Frames to Video") + gr.Markdown("After you are satisfied with the frames in the studio, reassemble them into a new video.") + with gr.Row(): + tb_reassemble_output_fps = gr.Number(label="Output Video FPS", value=30, minimum=1, step=1) + tb_reassemble_video_name_input = gr.Textbox(label="Output Video Name (optional, .mp4 added)", placeholder="e.g., my_edited_video") + tb_reassemble_frames_btn = gr.Button("🧩 Reassemble From Studio", variant="primary") + + with gr.TabItem("🧩 Join Videos (Concatenate)"): + with gr.Accordion("Select two or more videos to join them together into a single file", open=True): + gr.Markdown( + """ + * **Input:** The Input accepts multiple videos dragged in or ctrl+clicked via `Click to Upload`**. + * **Output:** The result will appear in the **'Processed Video' player (top-right)** for you to review. + * **Saving:** The output is saved to your 'saved_videos' folder if 'Autosave' is enabled. 
Otherwise, you must click the 'Manual Save' button. + """ + ) + tb_join_videos_input = gr.File( + label="Upload Videos to Join", + file_count="multiple", + file_types=["video", "file"] + ) + + tb_join_video_name_input = gr.Textbox( + label="Output Video Name (optional, .mp4 and timestamp added)", + placeholder="e.g., my_awesome_compilation" + ) + + tb_join_videos_btn = gr.Button("🤝 Join Videos", variant="primary") + + with gr.TabItem("📦 Export & Compress"): + with gr.Accordion("Compress your final video and/or convert it into a shareable format", open=True): + gr.Markdown( + """ + * **Input:** This operation always uses the video in the **'Upload Video' player (top-left)**. + * **Output:** The result will appear in the **'Processed Video' player (top-right)** for you to review. + * **Saving:** The output is saved to your 'saved_videos' folder if 'Autosave' is enabled. Otherwise, you must click the 'Manual Save' button. Note: GIFs will _always_ be saved! + * **Note:** WebM and GIF encoding can be slow for long or high-resolution videos. Please be patient! + """ + ) + with gr.Row(): + with gr.Column(scale=2): + tb_export_format_radio = gr.Radio( + ["MP4", "WebM", "GIF"], value="MP4", label="Output Format", + info="MP4 is best for general use. WebM is great for web/Discord (smaller size). GIF is a widely-supported format for short, silent, looping clips. GIF output will always be saved." + ) + tb_export_quality_slider = gr.Slider( + 0, 100, value=85, step=1, label="Quality", + info="Higher quality means a larger file size. 80-90 is a good balance for MP4/WebM." + ) + with gr.Column(scale=2): + tb_export_resize_slider = gr.Slider( + 256, 2048, value=1024, step=64, label="Max Width (pixels)", + info="Resizes the video to this maximum width while maintaining aspect ratio. A powerful way to reduce file size." + ) + tb_export_name_input = gr.Textbox( + label="Output Filename (optional)", + placeholder="e.g., my_final_video_for_discord" + ) + + tb_export_video_btn = gr.Button("🚀 Export Video", variant="primary") + + with gr.Accordion("💡 Post-processing Guide & Tips", open=False): + with gr.Tabs(): + with gr.TabItem("🚀 Getting Started"): + gr.Markdown(""" + ### Welcome to the Toolbox! + + **1. Input & Output** + * Most tools use the video in the **Upload Video player ⬅️ (top-left)** as their input. + * Your results will appear in the **Processed Video player ➡️ (top-right)**. + + **2. Chaining Operations (Applying multiple effects)** + * To use a result as the input for your next step: + 1. Run an operation (like Upscale). + 2. When the result appears, click the **'Use as Input'** button. + 3. Your result is now in the input player, ready for the next operation! + + **3. Saving Your Work** + * **Autosave:** When the `Autosave` checkbox is on, all results are automatically saved to the `saved_videos` folder. + * **Manual Save:** If `Autosave` is off, results go to a temporary folder. Use the **'💾 Manual Save'** button to save the video from the output player permanently. + + **4. Analyze First!** + * It's a good idea to click the **'Analyze Video'** button after uploading. It gives you helpful info like resolution and frame rate. + """) + + with gr.TabItem("⛓️ The Processing Pipeline"): + gr.Markdown(""" + ### Run Multiple Operations at Once + The pipeline lets you set up a series of operations and run them with a single click. This is the main way to process videos. + + **How to Use the Pipeline:** + 1. **Configure:** Go to the operation tabs (📈 Upscale, 🎨 Filters, etc.) 
and set the sliders and options exactly how you want them. + 2. **Select:** In the **'Processing Pipeline'** section, check the boxes for the steps you want to run. + 3. **Input:** Make sure your video is in the 'Single Video' tab, or your files are in the 'Batch Video' tab. + 4. **Execute:** Click the **'🚀 Start Pipeline Processing'** button. + + **Execution Order:** + The pipeline always runs in this fixed order, no matter when you check the boxes: + `Upscale` ➡️ `Frame Adjust` ➡️ `Filters` ➡️ `Loop` ➡️ `Export` + + **Single vs. Batch Video:** + * **Single Video:** Processes one video. The final result will appear in the output player. + * **Batch Video:** Processes multiple videos. Each video will go through the entire pipeline. Outputs are saved directly to a new, timestamped folder inside `saved_videos`. The output player will only show the very last video processed. + + **Workflow Presets:** + * Use presets to **save and load your entire pipeline setup**, including all slider values and selected steps. + """) + + with gr.TabItem("🖼️ Frames Studio Workflow"): + gr.Markdown(""" + ### Edit Your Video Frame-by-Frame + The Frames Studio lets you break a video into images, edit them, and put them back together. + + **Step 1: Extract Frames** + * Upload a video and use the **'🔨 Extract Frames'** button. + * This creates a new folder of images in `postprocessed_output/frames/extracted_frames/`. + * ⚠️ **Warning:** Extracting from long or high-res videos can use a lot of disk space! + + **Step 2: Edit in the Studio** + * Click **'🔄 Refresh List'** to find your new folder, then click **'🖼️ Load Frames to Studio'**. + * The frames will appear in the gallery. Click a frame to select it. + * **Delete Frames:** Use the **'🗑️ Delete Selected Frame'** button to remove bad frames or glitches. + * **Save Frames:** Use the **'💾 Save Selected Frame'** button to save a high-quality copy of a single frame. Perfect for use as an image prompt! + + **Step 3: Reassemble Video** + * Once you're done editing, use the **'🧩 Reassemble From Studio'** button. + * This creates a new video using only the frames that are left in the folder. + """) + + with gr.TabItem("⚙️ Other Tools & Tips"): + gr.Markdown(""" + ### Individual Operations + * **🧩 Join Videos:** Combine multiple video clips into a single video file. The tool will automatically handle different resolutions and audio. + * **📦 Export & Compress:** A powerful tool to make your final video smaller. You can lower the quality, resize the video, or convert it to `MP4`, `WebM`, or a silent `GIF`. + + ### Memory Management + * The **'📤 Unload Studio Model'** button can free up VRAM by removing the main video generation model from memory. + * This is useful before running a heavy task here, like a 4K video upscale. The main app will reload the model automatically when you need it again. + + ### Streaming Mode for Upscale & RIFE + * On the **'Upscale'** and **'Frame Adjust'** tabs, you'll find a checkbox: **"Use Streaming (Low Memory Mode)"**. + + **What It Does for You:** + Normally, your entire video is loaded into RAM to process it as fast as possible. For very long or high-resolution videos (like 4K), this can potentially cause it to exceed your RAM and spill over to disk (pagefile) or possibly even cause a system crash! + Streaming Mode processes your video one frame at a time to keep memory usage low and stable. + * **Check this box if you are working with a large video file.** + + **How it Works:** + * **Default Mode:** Loads the entire video into RAM. 
It's the fastest option but uses the most memory. + * **Streaming Mode (Upscaling & 2x RIFE):** A "true" stream that reads and writes one frame at a time. Memory usage is very low and constant. + * **Streaming Mode (4x RIFE):** A "hybrid" mode. **Be aware: the first 2x pass will still use a large amount of RAM to build the intermediate video (similar to the Default Mode).** However, its key benefit is that the second 2x pass becomes completely stable, preventing the final, largest memory spike that often causes crashes in the default mode. + * **Note:** The **Adjust Video Speed Factor** is ignored when Streaming mode is activated. In Low Memory Mode, this must be done as a separate operation. + + **⭐ Tip for Maximum Memory Savings on 4x RIFE:** + For the absolute lowest memory usage on a 4x interpolation, you can run the **2x Streaming** operation twice back-to-back. + 1. Run a **2x RIFE** with **Streaming Mode enabled**. + 2. Click **"Use as Input"** to move the result back to the input player. + 3. Run a **2x RIFE** on that new video, again with **Streaming Mode enabled**. + This manual two-pass method ensures memory usage never exceeds the "true" streaming level, at the cost of being slower due to writing an intermediate file to disk. + + ### 👇 Check Console Messages! + * The text box at the very bottom of the page shows important status updates, warnings, and error messages. If something isn't working, the answer is probably there! + """) + + with gr.Row(): + tb_message_output = gr.Textbox(label="Console Messages", lines=10, interactive=False, elem_classes="message-box", value=tb_update_messages) + with gr.Row(): + tb_open_folder_button = gr.Button("📁 Open Output Folder", scale=4) + tb_clear_temp_button = gr.Button("🗑️ Clear Temporary Files", variant="stop", scale=1) + + # --- Event Handlers --- + + _ORDERED_FILTER_SLIDERS_ = [ + tb_filter_brightness, tb_filter_contrast, tb_filter_saturation, tb_filter_temperature, + tb_filter_sharpen, tb_filter_blur, tb_filter_denoise, tb_filter_vignette, + tb_filter_s_curve_contrast, tb_filter_film_grain_strength + ] + + # A list of all operation parameter components in the correct order for workflow presets + _ALL_PIPELINE_PARAMS_COMPONENTS_ = [ + # Upscale + tb_upscale_model_select, tb_upscale_factor_slider, tb_upscale_tile_size_radio, + tb_upscale_enhance_face_checkbox, tb_denoise_strength_slider, + tb_upscale_use_streaming_checkbox, # <-- ADDED UPSCALE STREAMING CHECKBOX + # Frame Adjust + tb_process_fps_mode, tb_process_speed_factor, + tb_frames_use_streaming_checkbox, # <-- ADDED FRAME ADJUST STREAMING CHECKBOX + # Loop + tb_loop_type_select, tb_num_loops_slider, + # Filters + *_ORDERED_FILTER_SLIDERS_, + # Export + tb_export_format_radio, tb_export_quality_slider, tb_export_resize_slider + ] + + # The list of all inputs for the main pipeline execution + _ALL_PIPELINE_INPUTS_ = [ + tb_active_tab_index_storage, + tb_pipeline_steps_chkbox, + # Inputs + tb_input_video_component, tb_batch_input_files, + # Parameters + *_ALL_PIPELINE_PARAMS_COMPONENTS_ + ] + + # --- NEW: Workflow Preset Event Handlers --- + tb_workflow_save_btn.click( + fn=tb_handle_save_workflow_preset, + inputs=[tb_workflow_preset_name_input, tb_pipeline_steps_chkbox, *_ALL_PIPELINE_PARAMS_COMPONENTS_], + outputs=[tb_workflow_preset_select, tb_workflow_preset_name_input, tb_message_output] + ) + + # The list of outputs for loading must include the name box, then ALL controls in the correct order + _WORKFLOW_LOAD_OUTPUTS_ = [ + tb_workflow_preset_name_input, # First output is the 
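As a rough illustration of the streaming mode described in the tips above (only one frame resident in memory at a time), here is a minimal sketch using imageio, which this patch already uses elsewhere; the `process_frame` callback and the file paths are assumptions, not code from the toolbox.

```python
import imageio

def process_streaming(src_path, dst_path, process_frame, default_fps=30):
    """Low-memory loop: read, transform, and write one frame at a time,
    so peak RAM stays roughly constant regardless of video length."""
    reader = imageio.get_reader(src_path)
    fps = reader.get_meta_data().get("fps", default_fps)
    writer = imageio.get_writer(dst_path, fps=fps)
    try:
        for frame in reader:                      # frames are yielded lazily
            writer.append_data(process_frame(frame))
    finally:
        reader.close()
        writer.close()
```

The default (non-streaming) path would instead collect every frame before processing, which is faster but scales memory usage with video length.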
name box + tb_pipeline_steps_chkbox, # Second is the checkbox group + *_ALL_PIPELINE_PARAMS_COMPONENTS_, # Then all the parameter controls + tb_message_output # Finally, the message box + ] + tb_workflow_preset_select.change( + fn=tb_handle_load_workflow_preset, + inputs=[tb_workflow_preset_select], + outputs=_WORKFLOW_LOAD_OUTPUTS_ + ) + tb_workflow_delete_btn.click( + fn=tb_handle_delete_workflow_preset, + inputs=[tb_workflow_preset_name_input], + # This list now also needs the dropdown prepended + outputs=[tb_workflow_preset_select, *_WORKFLOW_LOAD_OUTPUTS_] + ) + tb_workflow_reset_btn.click( + fn=tb_handle_reset_workflow_to_defaults, + inputs=None, + # The outputs list now starts with the dropdown, followed by the standard load outputs + outputs=[tb_workflow_preset_select, *_WORKFLOW_LOAD_OUTPUTS_] + ) + # --- End Workflow Preset Handlers --- + + tb_start_pipeline_btn.click( + fn=tb_handle_start_pipeline, + inputs=_ALL_PIPELINE_INPUTS_, + outputs=[tb_processed_video_output, tb_message_output] + ) + # Listen for tab changes and update the state component + tb_input_tabs.select( + fn=tb_update_active_tab_index, + inputs=None, # evt is passed implicitly + outputs=[tb_active_tab_index_storage] # Only update the state, not the message box + ) + + # --- SINGLE VIDEO HANDLERS --- + tb_input_video_component.upload(fn=lambda: (tb_message_mgr.clear() or tb_update_messages(), None), outputs=[tb_message_output, tb_video_analysis_output]) + tb_input_video_component.clear(fn=lambda: (tb_message_mgr.clear() or tb_update_messages(), None, None), outputs=[tb_message_output, tb_video_analysis_output, tb_processed_video_output]) + + tb_analyze_button.click( + fn=tb_handle_analyze_video, + inputs=[tb_input_video_component], + outputs=[tb_message_output, tb_video_analysis_output, tb_analysis_accordion] + ) + tb_process_frames_btn.click( + fn=tb_handle_process_frames, + inputs=[tb_input_video_component, tb_process_fps_mode, tb_process_speed_factor, tb_frames_use_streaming_checkbox], # <-- ADDED HERE + outputs=[tb_processed_video_output, tb_message_output] + ) + + tb_create_loop_btn.click(fn=tb_handle_create_loop, inputs=[tb_input_video_component, tb_loop_type_select, tb_num_loops_slider], outputs=[tb_processed_video_output, tb_message_output]) + + tb_filter_preset_select.change( + fn=lambda preset_name_from_dropdown: (preset_name_from_dropdown, *tb_update_filter_sliders_from_preset(preset_name_from_dropdown)), + inputs=[tb_filter_preset_select], outputs=[tb_new_preset_name_input] + _ORDERED_FILTER_SLIDERS_ + ) + tb_apply_filters_btn.click(fn=tb_handle_apply_filters, inputs=[tb_input_video_component] + _ORDERED_FILTER_SLIDERS_, outputs=[tb_processed_video_output, tb_message_output]) + tb_save_preset_btn.click(fn=tb_handle_save_user_preset, inputs=[tb_new_preset_name_input] + _ORDERED_FILTER_SLIDERS_, outputs=[tb_filter_preset_select, tb_message_output, tb_new_preset_name_input]) + tb_delete_preset_btn.click(fn=tb_handle_delete_user_preset, inputs=[tb_new_preset_name_input], outputs=[tb_filter_preset_select, tb_message_output, tb_new_preset_name_input] + _ORDERED_FILTER_SLIDERS_) + tb_reset_filters_btn.click(fn=tb_handle_reset_all_filters, inputs=None, outputs=[tb_filter_preset_select, tb_new_preset_name_input, *_ORDERED_FILTER_SLIDERS_, tb_message_output]) + + tb_use_processed_as_input_btn.click( + fn=tb_handle_use_processed_as_input, + inputs=[tb_processed_video_output], + outputs=[tb_input_video_component, tb_message_output, tb_video_analysis_output] + ) + + tb_upscale_video_btn.click( + 
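The comments above stress that the workflow-preset load outputs must appear in the same order as the values the handler returns. That is because Gradio maps a handler's returned tuple onto its `outputs` list purely by position; a tiny sketch with hypothetical components:

```python
import gradio as gr

# Hypothetical handler with two outputs: the first returned value goes to the
# first component in `outputs`, the second to the second, and so on.
def load_demo(name):
    return f"Loaded {name}", gr.update(value=42)

with gr.Blocks() as demo:
    preset = gr.Dropdown(["a", "b"], label="Preset")
    status = gr.Textbox(label="Status")
    slider = gr.Slider(0, 100, label="Value")
    preset.change(fn=load_demo, inputs=[preset], outputs=[status, slider])
```

Reordering `outputs` without reordering the returned values silently writes values to the wrong components, which is why `_WORKFLOW_LOAD_OUTPUTS_` is built once and reused by the load, delete, and reset handlers.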
fn=tb_handle_upscale_video, + inputs=[tb_input_video_component, tb_upscale_model_select, tb_upscale_factor_slider, tb_upscale_tile_size_radio, tb_upscale_enhance_face_checkbox, tb_denoise_strength_slider, tb_upscale_use_streaming_checkbox], + outputs=[tb_processed_video_output, tb_message_output] + ) + tb_upscale_model_select.change( + fn=tb_get_model_info_and_update_scale_slider, inputs=[tb_upscale_model_select], + outputs=[tb_selected_model_scale_display, tb_upscale_factor_slider, tb_denoise_strength_slider] + ) + + # --- Frames Studio Event Handlers --- + tb_extract_frames_btn.click( + fn=tb_handle_extract_frames, inputs=[tb_input_video_component, tb_extract_rate_slider], outputs=[tb_message_output] + ).then( + fn=tb_handle_refresh_extracted_folders, inputs=None, + outputs=[tb_extracted_folders_dropdown, tb_message_output, tb_clear_selected_folder_btn, tb_frames_gallery, tb_frame_info_box] + ) + tb_refresh_extracted_folders_btn.click( + fn=tb_handle_refresh_extracted_folders, inputs=None, + outputs=[tb_extracted_folders_dropdown, tb_message_output, tb_clear_selected_folder_btn, tb_frames_gallery, tb_frame_info_box] + ) + tb_extracted_folders_dropdown.change( + fn=lambda selection: gr.update(interactive=bool(selection)), + inputs=[tb_extracted_folders_dropdown], outputs=[tb_clear_selected_folder_btn] + ) + tb_clear_selected_folder_btn.click( + fn=tb_handle_clear_selected_folder, + inputs=[tb_extracted_folders_dropdown], + outputs=[ + tb_message_output, + tb_extracted_folders_dropdown, + tb_frames_gallery, + tb_frame_info_box, + tb_save_selected_frame_btn, + tb_delete_selected_frame_btn + ] + ).then( + fn=lambda selection: gr.update(interactive=bool(selection)), + inputs=[tb_extracted_folders_dropdown], + outputs=[tb_clear_selected_folder_btn] + ) + + tb_load_frames_to_studio_btn.click( + fn=tb_handle_load_frames_to_studio, + inputs=[tb_extracted_folders_dropdown], + outputs=[tb_message_output, tb_frames_gallery, tb_frame_info_box] + ) + tb_frames_gallery.select( + fn=tb_handle_frame_select, + inputs=None, # evt_data is passed implicitly + outputs=[tb_frame_info_box, tb_save_selected_frame_btn, tb_delete_selected_frame_btn] + ) + tb_delete_selected_frame_btn.click( + fn=tb_handle_delete_and_refresh_gallery, + inputs=[tb_extracted_folders_dropdown, tb_frame_info_box], + outputs=[ + tb_frames_gallery, + tb_frame_info_box, + tb_save_selected_frame_btn, + tb_delete_selected_frame_btn, + tb_message_output + ] + ) + tb_save_selected_frame_btn.click( + fn=tb_handle_save_selected_frame, + inputs=[tb_extracted_folders_dropdown, tb_frame_info_box], + outputs=[tb_message_output, tb_frame_info_box] + ) + tb_clear_gallery_btn.click( + fn=lambda: (None, "Click a frame in the gallery above to select it.", gr.update(interactive=False), gr.update(interactive=False)), + inputs=None, + outputs=[ + tb_frames_gallery, + tb_frame_info_box, + tb_save_selected_frame_btn, + tb_delete_selected_frame_btn + ] + ) + tb_reassemble_frames_btn.click( + fn=tb_handle_reassemble_frames, + inputs=[tb_extracted_folders_dropdown, tb_reassemble_output_fps, tb_reassemble_video_name_input], + outputs=[tb_processed_video_output, tb_message_output] + ) + + tb_join_videos_btn.click( + fn=tb_handle_join_videos, + # Add the new textbox to the inputs list + inputs=[tb_join_videos_input, tb_join_video_name_input], + outputs=[tb_processed_video_output, tb_message_output] + ) + + tb_export_video_btn.click( + fn=tb_handle_export_video, + inputs=[ + tb_input_video_component, # The video to process + tb_export_format_radio, + 
tb_export_quality_slider, + tb_export_resize_slider, + tb_export_name_input + ], + # The outputs now include the video player. + outputs=[tb_processed_video_output, tb_message_output] + ) + + # --- Other System Handlers --- + tb_open_folder_button.click(fn=lambda: tb_processor.tb_open_output_folder() or tb_update_messages(), outputs=[tb_message_output]) + tb_monitor_toggle_checkbox.change(fn=lambda is_enabled: gr.update(visible=is_enabled), inputs=[tb_monitor_toggle_checkbox], outputs=[tb_resource_monitor_output]) + tb_monitor_timer = gr.Timer(2, active=True) + tb_monitor_timer.tick(fn=tb_handle_update_monitor, inputs=[tb_monitor_toggle_checkbox], outputs=[tb_resource_monitor_output]) + tb_delete_studio_transformer_btn.click(fn=tb_handle_delete_studio_transformer, inputs=[], outputs=[tb_message_output]) + tb_manual_save_btn.click(fn=tb_handle_manually_save_video, inputs=[tb_processed_video_output], outputs=[tb_processed_video_output, tb_message_output]) + + def tb_handle_autosave_toggle(autosave_is_on_ui_value): + settings_instance.set("toolbox_autosave_enabled", autosave_is_on_ui_value) + tb_processor.set_autosave_mode(autosave_is_on_ui_value) + return { + tb_manual_save_btn: gr.update(visible=not autosave_is_on_ui_value), + tb_message_output: gr.update(value=tb_update_messages()) + } + tb_autosave_checkbox.change(fn=tb_handle_autosave_toggle, inputs=[tb_autosave_checkbox], outputs=[tb_manual_save_btn, tb_message_output]) + tb_clear_temp_button.click(fn=tb_handle_clear_temp_files, inputs=None, outputs=[tb_processed_video_output, tb_message_output]) + + return tb_toolbox_ui_main_container, tb_input_video_component + + +# --- Main execution block for standalone mode --- + +if __name__ == "__main__": + import argparse + + def launch_standalone(): + """Creates and launches the Gradio interface for the toolbox when run as a script.""" + + # 1. Setup and parse command-line arguments, similar to studio.py + parser = argparse.ArgumentParser(description="Run FramePack Toolbox in Standalone Mode") + parser.add_argument('--share', action='store_true', help="Enable Gradio sharing link") + parser.add_argument("--server", type=str, default='127.0.0.1', help="Server name to launch on (default: 127.0.0.1)") + parser.add_argument("--port", type=int, required=False, help="Server port to launch on (default: 7860)") + parser.add_argument("--inbrowser", action='store_true', help="Automatically open in browser") + args = parser.parse_args() + + # 2. Define custom CSS + css = """ + /* hide the gr.Video source selection bar for tb_input_video_component */ + #toolbox-video-player .source-selection { + display: none !important; + } + /* control sizing for gr.Video components */ + .video-size video { + max-height: 60vh; + min-height: 300px !important; + object-fit: contain; + } + /* NEW: Closes the gap between input tabs and the pipeline accordion below them */ + #pipeline-controls-wrapper { + margin-top: -15px !important; /* Adjust this value to get the perfect "snug" fit */ + } + /* --- NEW CSS RULE FOR GALLERY SCROLLING --- */ + #gallery-scroll-wrapper { + max-height: 600px; /* Set your desired fixed height */ + overflow-y: auto; /* Add a scrollbar only when needed */ + } + /* --- --- */ + #toolbox-start-pipeline-btn { + margin-top: -14px !important; /* Adjust this value to get the perfect alignment */ + } + .small-text-info { + font-size: 0.6rem !important; /* Start with a more reasonable size */ + } + """ + + # 3. 
Get the output directory path from the existing settings instance + output_dir_from_settings = settings_instance.get("output_dir") + allowed_paths = [output_dir_from_settings] + print(f"Gradio server will be allowed to access path: {output_dir_from_settings}") + + # 4. Create the Gradio interface + with gr.Blocks(title="FramePack Toolbox (Standalone)", css=css) as block: + gr.Markdown("# FramePack Post-processing Toolbox (Standalone Mode)") + gr.Markdown( + "This is the standalone version of the toolbox. " + "Upload a video to the 'Upload Video' component to begin. " + "The 'Unload Studio Model' button will have no effect in this mode." + ) + + tb_create_video_toolbox_ui() + print(f"Launching Toolbox server. Access it at http://{args.server}:{args.port if args.port else 7860}") + block.launch( + server_name=args.server, + + # 5. Launch the Gradio app with all the configured arguments + server_port=args.port, + share=args.share, + inbrowser=args.inbrowser, + allowed_paths=allowed_paths + ) + + # Call the launch function + launch_standalone() \ No newline at end of file diff --git a/modules/version.py b/modules/version.py new file mode 100644 index 0000000000000000000000000000000000000000..7939d2fc1a0c1fcb75a62ae8abffbf7fb24e0f80 --- /dev/null +++ b/modules/version.py @@ -0,0 +1,8 @@ +""" +Version information for FramePack Studio. +This module provides a central location for version information. +""" + +# Version information +APP_VERSION = "0.5" # Numeric version for metadata +APP_VERSION_DISPLAY = f"v{APP_VERSION}" # Display version for toolbar diff --git a/modules/video_queue.py b/modules/video_queue.py new file mode 100644 index 0000000000000000000000000000000000000000..52b09d1a0934bc9b5d07963b3efad54721a4a7bd --- /dev/null +++ b/modules/video_queue.py @@ -0,0 +1,1655 @@ +import threading +import time +import uuid +import json +import os +import zipfile +import shutil +from dataclasses import dataclass, field +from enum import Enum +from typing import Dict, Any, Optional, List +import queue as queue_module # Renamed to avoid conflicts +import io +import base64 +from PIL import Image +import numpy as np + +from diffusers_helper.thread_utils import AsyncStream +from modules.pipelines.metadata_utils import create_metadata +from modules.settings import Settings +from diffusers_helper.gradio.progress_bar import make_progress_bar_html + + +# Simple LIFO queue implementation to avoid dependency on queue.LifoQueue +class SimpleLifoQueue: + def __init__(self): + self._queue = [] + self._mutex = threading.Lock() + self._not_empty = threading.Condition(self._mutex) + + def put(self, item): + with self._mutex: + self._queue.append(item) + self._not_empty.notify() + + def get(self): + with self._not_empty: + while not self._queue: + self._not_empty.wait() + return self._queue.pop() + + def task_done(self): + pass # For compatibility with queue.Queue + + +class JobStatus(Enum): + PENDING = "pending" + RUNNING = "running" + COMPLETED = "completed" + FAILED = "failed" + CANCELLED = "cancelled" + + +class JobType(Enum): + SINGLE = "single" + GRID = "grid" + + +@dataclass +class Job: + id: str + params: Dict[str, Any] + status: JobStatus = JobStatus.PENDING + job_type: JobType = JobType.SINGLE + child_job_ids: List[str] = field(default_factory=list) + parent_job_id: Optional[str] = None + created_at: float = field(default_factory=time.time) + started_at: Optional[float] = None + completed_at: Optional[float] = None + error: Optional[str] = None + result: Optional[str] = None + progress_data: Optional[Dict] 
= None + queue_position: Optional[int] = None + stream: Optional[Any] = None + input_image: Optional[np.ndarray] = None + latent_type: Optional[str] = None + thumbnail: Optional[str] = None + generation_type: Optional[str] = None # Added generation_type + input_image_saved: bool = False # Flag to track if input image has been saved + end_frame_image_saved: bool = False # Flag to track if end frame image has been saved + + def __post_init__(self): + # Store generation type + self.generation_type = self.params.get('model_type', 'Original') # Initialize generation_type + + # Store input image or latent type + if 'input_image' in self.params and self.params['input_image'] is not None: + self.input_image = self.params['input_image'] + # Create thumbnail + if isinstance(self.input_image, np.ndarray): + # Handle numpy array (image) + img = Image.fromarray(self.input_image) + img.thumbnail((100, 100)) + buffered = io.BytesIO() + img.save(buffered, format="PNG") + self.thumbnail = f"data:image/png;base64,{base64.b64encode(buffered.getvalue()).decode()}" + elif isinstance(self.input_image, str): + # Handle string (video path) + try: + print(f"Attempting to extract thumbnail from video: {self.input_image}") + # Try to extract frames from the video using imageio + import imageio + + # Check if the file exists + if not os.path.exists(self.input_image): + print(f"Video file not found: {self.input_image}") + raise FileNotFoundError(f"Video file not found: {self.input_image}") + + # Create outputs directory if it doesn't exist + os.makedirs("outputs", exist_ok=True) + + # Try to open the video file + try: + reader = imageio.get_reader(self.input_image) + print(f"Successfully opened video file with imageio") + except Exception as e: + print(f"Failed to open video with imageio: {e}") + raise + + # Get the total number of frames + num_frames = None + try: + # Try to get the number of frames from metadata + meta_data = reader.get_meta_data() + print(f"Video metadata: {meta_data}") + num_frames = meta_data.get('nframes') + if num_frames is None or num_frames == float('inf'): + print("Number of frames not available in metadata") + # If not available, try to count frames + if hasattr(reader, 'count_frames'): + print("Trying to count frames...") + num_frames = reader.count_frames() + print(f"Counted {num_frames} frames") + except Exception as e: + print(f"Error getting frame count: {e}") + num_frames = None + + # If we couldn't determine the number of frames, read the last frame by iterating + if num_frames is None or num_frames == float('inf'): + print("Reading frames by iteration to find the last one") + # Read frames until we reach the end + frame_count = 0 + first_frame = None + last_frame = None + try: + for frame in reader: + if frame_count == 0: + first_frame = frame + last_frame = frame + frame_count += 1 + # Print progress every 100 frames + if frame_count % 100 == 0: + print(f"Read {frame_count} frames...") + print(f"Finished reading {frame_count} frames") + + # Save the first frame if available + if first_frame is not None: + print(f"Found first frame with shape: {first_frame.shape}") + # DEBUG IMAGE SAVING REMOVED + except Exception as e: + print(f"Error reading frames: {e}") + + if last_frame is not None: + print(f"Found last frame with shape: {last_frame.shape}") + + # DEBUG IMAGE SAVING REMOVED + # Use the last frame for the thumbnail + img = Image.fromarray(last_frame) + img.thumbnail((100, 100)) + buffered = io.BytesIO() + img.save(buffered, format="PNG") + self.thumbnail = 
f"data:image/png;base64,{base64.b64encode(buffered.getvalue()).decode()}" + print("Successfully created thumbnail from last frame") + else: + print("No frames were read, using red thumbnail") + # Fallback to red thumbnail if no frames were read - more visible for debugging + img = Image.new('RGB', (100, 100), (255, 0, 0)) # Red for video + buffered = io.BytesIO() + img.save(buffered, format="PNG") + self.thumbnail = f"data:image/png;base64,{base64.b64encode(buffered.getvalue()).decode()}" + else: + # If we know the number of frames, try to get multiple frames for debugging + try: + # Try to get the first frame + first_frame = None + try: + first_frame = reader.get_data(0) + print(f"Got first frame with shape: {first_frame.shape}") + + # DEBUG IMAGE SAVING REMOVED + except Exception as e: + print(f"Error getting first frame: {e}") + + # Try to get a middle frame + middle_frame = None + try: + middle_frame_idx = int(num_frames / 2) + middle_frame = reader.get_data(middle_frame_idx) + print(f"Got middle frame (frame {middle_frame_idx}) with shape: {middle_frame.shape}") + + # DEBUG IMAGE SAVING REMOVED + except Exception as e: + print(f"Error getting middle frame: {e}") + + # Try to get the last frame + last_frame = None + try: + last_frame_idx = int(num_frames) - 1 + last_frame = reader.get_data(last_frame_idx) + print(f"Got last frame (frame {last_frame_idx}) with shape: {last_frame.shape}") + + # DEBUG IMAGE SAVING REMOVED + except Exception as e: + print(f"Error getting last frame: {e}") + + # If we couldn't get the last frame directly, try to get it by iterating + if last_frame is None: + print("Trying to get last frame by iterating through all frames") + try: + for frame in reader: + last_frame = frame + + if last_frame is not None: + print(f"Got last frame by iteration with shape: {last_frame.shape}") + + # DEBUG IMAGE SAVING REMOVED + except Exception as e: + print(f"Error getting last frame by iteration: {e}") + + # Use the last frame for the thumbnail if available, otherwise use the middle or first frame + frame_for_thumbnail = last_frame if last_frame is not None else (middle_frame if middle_frame is not None else first_frame) + + if frame_for_thumbnail is not None: + # Convert to PIL Image and create a thumbnail + img = Image.fromarray(frame_for_thumbnail) + img.thumbnail((100, 100)) + buffered = io.BytesIO() + img.save(buffered, format="PNG") + self.thumbnail = f"data:image/png;base64,{base64.b64encode(buffered.getvalue()).decode()}" + print("Successfully created thumbnail from frame") + else: + print("No frames were extracted, using blue thumbnail") + # Fallback to blue thumbnail if no frames were extracted + img = Image.new('RGB', (100, 100), (0, 0, 255)) # Blue for video + buffered = io.BytesIO() + img.save(buffered, format="PNG") + self.thumbnail = f"data:image/png;base64,{base64.b64encode(buffered.getvalue()).decode()}" + except Exception as e: + # Fallback to blue thumbnail on error + img = Image.new('RGB', (100, 100), (0, 0, 255)) # Blue for video + buffered = io.BytesIO() + img.save(buffered, format="PNG") + self.thumbnail = f"data:image/png;base64,{base64.b64encode(buffered.getvalue()).decode()}" + + # Close the reader + try: + reader.close() + print("Successfully closed video reader") + except Exception as e: + print(f"Error closing reader: {e}") + + except Exception as e: + print(f"Error extracting thumbnail from video: {e}") + import traceback + traceback.print_exc() + # Fallback to bright green thumbnail on error to make it more visible + img = Image.new('RGB', (100, 
100), (0, 255, 0)) # Bright green for error + buffered = io.BytesIO() + img.save(buffered, format="PNG") + self.thumbnail = f"data:image/png;base64,{base64.b64encode(buffered.getvalue()).decode()}" + print("Created bright green fallback thumbnail") + else: + # Handle other types + self.thumbnail = None + elif 'latent_type' in self.params: + self.latent_type = self.params['latent_type'] + # Create a colored square based on latent type + color_map = { + "Black": (0, 0, 0), + "White": (255, 255, 255), + "Noise": (128, 128, 128), + "Green Screen": (0, 177, 64) + } + color = color_map.get(self.latent_type, (0, 0, 0)) + img = Image.new('RGB', (100, 100), color) + buffered = io.BytesIO() + img.save(buffered, format="PNG") + self.thumbnail = f"data:image/png;base64,{base64.b64encode(buffered.getvalue()).decode()}" + + +class VideoJobQueue: + def __init__(self): + self.queue = queue_module.Queue() # Using standard Queue instead of LifoQueue + self.jobs = {} + self.current_job = None + self.lock = threading.Lock() + self.worker_thread = threading.Thread(target=self._worker_loop, daemon=True) + self.worker_thread.start() + self.worker_function = None # Will be set from outside + self.is_processing = False # Flag to track if we're currently processing a job + + def set_worker_function(self, worker_function): + """Set the worker function to use for processing jobs""" + self.worker_function = worker_function + + def serialize_job(self, job): + """Serialize a job to a JSON-compatible format""" + try: + # Create a simplified representation of the job + serialized = { + "id": job.id, + "status": job.status.value, + "created_at": job.created_at, + "started_at": job.started_at, + "completed_at": job.completed_at, + "error": job.error, + "result": job.result, + "queue_position": job.queue_position, + "generation_type": job.generation_type, + } + + # Add simplified params (excluding complex objects) + serialized_params = {} + for k, v in job.params.items(): + if k not in ["input_image", "end_frame_image", "stream"]: + # Try to include only JSON-serializable values + try: + # Test if value is JSON serializable + json.dumps({k: v}) + serialized_params[k] = v + except (TypeError, OverflowError): + # Skip non-serializable values + pass + + # Handle LoRA information specifically + # Only include selected LoRAs for the generation + if "selected_loras" in job.params and job.params["selected_loras"]: + selected_loras = job.params["selected_loras"] + # Ensure it's a list + if not isinstance(selected_loras, list): + selected_loras = [selected_loras] if selected_loras is not None else [] + + # Get LoRA values if available + lora_values = job.params.get("lora_values", []) + if not isinstance(lora_values, list): + lora_values = [lora_values] if lora_values is not None else [] + + # Get loaded LoRA names + lora_loaded_names = job.params.get("lora_loaded_names", []) + if not isinstance(lora_loaded_names, list): + lora_loaded_names = [lora_loaded_names] if lora_loaded_names is not None else [] + + # Create LoRA data dictionary + lora_data = {} + for lora_name in selected_loras: + try: + # Find the index of the LoRA in loaded names + idx = lora_loaded_names.index(lora_name) if lora_loaded_names else -1 + # Get the weight value + weight = lora_values[idx] if lora_values and idx >= 0 and idx < len(lora_values) else 1.0 + # Handle weight as list + if isinstance(weight, list): + weight_value = weight[0] if weight and len(weight) > 0 else 1.0 + else: + weight_value = weight + # Store as float + lora_data[lora_name] = 
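The thumbnails built in `Job.__post_init__` above are stored as `data:image/png;base64,...` strings. For reference, a small sketch (not part of the patch) of decoding such a string back into a PIL image:

```python
import base64
import io

from PIL import Image

def decode_thumbnail(data_uri: str) -> Image.Image:
    """Reverse the encoding used for Job.thumbnail: drop the 'data:image/png;base64'
    header, then base64-decode the PNG payload."""
    _header, payload = data_uri.split(",", 1)
    return Image.open(io.BytesIO(base64.b64decode(payload)))
```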
float(weight_value) + except (ValueError, IndexError): + # Default weight if not found + lora_data[lora_name] = 1.0 + except Exception as e: + print(f"Error processing LoRA {lora_name}: {e}") + lora_data[lora_name] = 1.0 + + # Add to serialized params + serialized_params["loras"] = lora_data + + serialized["params"] = serialized_params + + # Don't include the thumbnail as it can be very large and cause issues + # if job.thumbnail: + # serialized["thumbnail"] = job.thumbnail + + return serialized + except Exception as e: + print(f"Error serializing job {job.id}: {e}") + # Return minimal information that should always be serializable + return { + "id": job.id, + "status": job.status.value, + "error": f"Error serializing: {str(e)}" + } + + def save_queue_to_json(self): + """Save the current queue to queue.json using the central metadata utility""" + try: + # Make a copy of job IDs to avoid holding the lock while serializing + with self.lock: + job_ids = list(self.jobs.keys()) + + # Create a settings instance + settings = Settings() + + # Create a directory to store queue images if it doesn't exist + queue_images_dir = "queue_images" + os.makedirs(queue_images_dir, exist_ok=True) + + # First, ensure all images are saved + for job_id in job_ids: + job = self.get_job(job_id) + if job: + # Save input image to disk if it exists and hasn't been saved yet + if 'input_image' in job.params and isinstance(job.params['input_image'], np.ndarray) and not job.input_image_saved: + input_image_path = os.path.join(queue_images_dir, f"{job_id}_input.png") + try: + Image.fromarray(job.params['input_image']).save(input_image_path) + print(f"Saved input image for job {job_id} to {input_image_path}") + # Mark the image as saved + job.input_image_saved = True + except Exception as e: + print(f"Error saving input image for job {job_id}: {e}") + + # Save end frame image to disk if it exists and hasn't been saved yet + if 'end_frame_image' in job.params and isinstance(job.params['end_frame_image'], np.ndarray) and not job.end_frame_image_saved: + end_frame_image_path = os.path.join(queue_images_dir, f"{job_id}_end_frame.png") + try: + Image.fromarray(job.params['end_frame_image']).save(end_frame_image_path) + print(f"Saved end frame image for job {job_id} to {end_frame_image_path}") + # Mark the end frame image as saved + job.end_frame_image_saved = True + except Exception as e: + print(f"Error saving end frame image for job {job_id}: {e}") + + # Now serialize jobs with the updated image saved flags + serialized_jobs = {} + for job_id in job_ids: + job = self.get_job(job_id) + if job: + # Try to use metadata_utils.create_metadata if possible + try: + # Create metadata using the central utility + metadata = create_metadata(job.params, job.id, settings.settings) + + # Add job status and other fields not included in metadata + metadata.update({ + "id": job.id, + "status": job.status.value, + "created_at": job.created_at, + "started_at": job.started_at, + "completed_at": job.completed_at, + "error": job.error, + "result": job.result, + "queue_position": job.queue_position, + }) + + # Add image paths to metadata if they've been saved + if job.input_image_saved: + input_image_path = os.path.join(queue_images_dir, f"{job_id}_input.png") + if os.path.exists(input_image_path): + metadata["saved_input_image_path"] = input_image_path + + if job.end_frame_image_saved: + end_frame_image_path = os.path.join(queue_images_dir, f"{job_id}_end_frame.png") + if os.path.exists(end_frame_image_path): + 
metadata["saved_end_frame_image_path"] = end_frame_image_path + + serialized_jobs[job_id] = metadata + except Exception as e: + print(f"Error using metadata_utils for job {job_id}: {e}") + # Fall back to the old serialization method + serialized_jobs[job_id] = self.serialize_job(job) + + # Save to file + with open("queue.json", "w") as f: + json.dump(serialized_jobs, f, indent=2) + + # Clean up images for jobs that no longer exist + self.cleanup_orphaned_images(job_ids) + self.cleanup_orphaned_videos(job_ids) + + print(f"Saved {len(serialized_jobs)} jobs to queue.json") + except Exception as e: + print(f"Error saving queue to JSON: {e}") + + def cleanup_orphaned_videos(self, current_job_ids_uuids): # Renamed arg for clarity + """ + Remove video files from input_files_dir for jobs that no longer exist + or whose input_image_path does not point to them. + + Args: + current_job_ids_uuids: List of job UUIDs currently in self.jobs + """ + try: + # Get the input_files_dir from settings to be robust + settings = Settings() + input_files_dir = settings.get("input_files_dir", "input_files") + if not os.path.exists(input_files_dir): + return + + # Normalize the managed input_files_dir path once + norm_input_files_dir = os.path.normpath(input_files_dir) + referenced_video_paths = set() + + with self.lock: # Access self.jobs safely + for job_id_uuid in current_job_ids_uuids: # Iterate using the provided UUIDs + job = self.jobs.get(job_id_uuid) + if not (job and job.params): + continue + + # Collect all potential video paths from the job parameters + # Check for strings to avoid TypeError + paths_to_consider = [] + p1 = job.params.get("input_image") # Primary path used by worker + if isinstance(p1, str): + paths_to_consider.append(p1) + + p2 = job.params.get("input_image_path") # Secondary/metadata path + if isinstance(p2, str) and p2 not in paths_to_consider: + paths_to_consider.append(p2) + + p3 = job.params.get("input_video") # Explicitly set during import + if isinstance(p3, str) and p3 not in paths_to_consider: + paths_to_consider.append(p3) + + for rel_or_abs_path in paths_to_consider: + # Resolve to absolute path. If already absolute, abspath does nothing. + # If relative, it's resolved against CWD (current working directory). + abs_path = os.path.abspath(rel_or_abs_path) + norm_abs_path = os.path.normpath(abs_path) + # Check if this path is within the managed input_files_dir + if norm_abs_path.startswith(norm_input_files_dir): + referenced_video_paths.add(norm_abs_path) + + removed_count = 0 + for filename in os.listdir(input_files_dir): + if filename.endswith(".mp4"): # Only process MP4 files + file_path_to_check = os.path.normpath(os.path.join(input_files_dir, filename)) + + if file_path_to_check not in referenced_video_paths: + try: + os.remove(file_path_to_check) + removed_count += 1 + print(f"Removed orphaned video: {filename} (path: {file_path_to_check})") + except Exception as e: + print(f"Error removing orphaned video {filename}: {e}") + if removed_count > 0: + print(f"Cleaned up {removed_count} orphaned videos from {input_files_dir}") + except Exception as e: + print(f"Error cleaning up orphaned videos: {e}") + import traceback + traceback.print_exc() + + def cleanup_orphaned_images(self, current_job_ids): + """ + Remove image files for jobs that no longer exist in the queue. 
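`cleanup_orphaned_videos` above decides whether a referenced path lives inside the managed `input_files` directory by comparing normalized path prefixes with `startswith`. As a side note (a sketch, not part of the patch), an `os.path.commonpath` based test avoids matching a sibling directory that merely shares the prefix:

```python
import os

def is_within(path: str, directory: str) -> bool:
    """True if `path` is inside `directory`. Unlike a plain startswith()
    prefix check, this cannot match e.g. 'input_files_backup' when the
    managed directory is 'input_files'."""
    path = os.path.abspath(path)
    directory = os.path.abspath(directory)
    try:
        return os.path.commonpath([path, directory]) == directory
    except ValueError:
        # Raised on Windows when the two paths are on different drives.
        return False
```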
+ + Args: + current_job_ids: List of job IDs currently in the queue + """ + try: + queue_images_dir = "queue_images" + if not os.path.exists(queue_images_dir): + return + + # Convert to set for faster lookups + current_job_ids = set(current_job_ids) + + # Check all files in the queue_images directory + removed_count = 0 + for filename in os.listdir(queue_images_dir): + # Only process PNG files with our naming pattern + if filename.endswith(".png") and ("_input.png" in filename or "_end_frame.png" in filename): + # Extract job ID from filename + parts = filename.split("_") + if len(parts) >= 2: + job_id = parts[0] + + # If job ID is not in current jobs, remove the file + if job_id not in current_job_ids: + file_path = os.path.join(queue_images_dir, filename) + try: + os.remove(file_path) + removed_count += 1 + print(f"Removed orphaned image: {filename}") + except Exception as e: + print(f"Error removing orphaned image {filename}: {e}") + + if removed_count > 0: + print(f"Cleaned up {removed_count} orphaned images") + except Exception as e: + print(f"Error cleaning up orphaned images: {e}") + + + def synchronize_queue_images(self): + """ + Synchronize the queue_images directory with the current jobs in the queue. + This ensures all necessary images are saved and only images for removed jobs are deleted. + """ + try: + queue_images_dir = "queue_images" + os.makedirs(queue_images_dir, exist_ok=True) + + # Get all current job IDs + with self.lock: + current_job_ids = set(self.jobs.keys()) + + # Get all image files in the queue_images directory + existing_image_files = set() + if os.path.exists(queue_images_dir): + for filename in os.listdir(queue_images_dir): + if filename.endswith(".png") and ("_input.png" in filename or "_end_frame.png" in filename): + existing_image_files.add(filename) + + # Extract job IDs from filenames + file_job_ids = set() + for filename in existing_image_files: + # Extract job ID from filename (format: "{job_id}_input.png" or "{job_id}_end_frame.png") + parts = filename.split("_") + if len(parts) >= 2: + job_id = parts[0] + file_job_ids.add(job_id) + + # Find job IDs in files that are no longer in the queue + removed_job_ids = file_job_ids - current_job_ids + + # Delete images for jobs that have been removed from the queue + removed_count = 0 + for job_id in removed_job_ids: + input_image_path = os.path.join(queue_images_dir, f"{job_id}_input.png") + end_frame_image_path = os.path.join(queue_images_dir, f"{job_id}_end_frame.png") + + if os.path.exists(input_image_path): + try: + os.remove(input_image_path) + removed_count += 1 + print(f"Removed image for deleted job: {input_image_path}") + except Exception as e: + print(f"Error removing image {input_image_path}: {e}") + + if os.path.exists(end_frame_image_path): + try: + os.remove(end_frame_image_path) + removed_count += 1 + print(f"Removed image for deleted job: {end_frame_image_path}") + except Exception as e: + print(f"Error removing image {end_frame_image_path}: {e}") + + # Now ensure all current jobs have their images saved + saved_count = 0 + with self.lock: + for job_id, job in self.jobs.items(): + # Only save images for running or completed jobs + if job.status in [JobStatus.RUNNING, JobStatus.COMPLETED]: + # Save input image if it exists and hasn't been saved yet + if 'input_image' in job.params and isinstance(job.params['input_image'], np.ndarray) and not job.input_image_saved: + input_image_path = os.path.join(queue_images_dir, f"{job_id}_input.png") + try: + 
Image.fromarray(job.params['input_image']).save(input_image_path) + job.input_image_saved = True + saved_count += 1 + print(f"Saved input image for job {job_id}") + except Exception as e: + print(f"Error saving input image for job {job_id}: {e}") + + # Save end frame image if it exists and hasn't been saved yet + if 'end_frame_image' in job.params and isinstance(job.params['end_frame_image'], np.ndarray) and not job.end_frame_image_saved: + end_frame_image_path = os.path.join(queue_images_dir, f"{job_id}_end_frame.png") + try: + Image.fromarray(job.params['end_frame_image']).save(end_frame_image_path) + job.end_frame_image_saved = True + saved_count += 1 + print(f"Saved end frame image for job {job_id}") + except Exception as e: + print(f"Error saving end frame image for job {job_id}: {e}") + + # Save the queue to ensure the image paths are properly referenced + self.save_queue_to_json() + + if removed_count > 0 or saved_count > 0: + print(f"Queue image synchronization: removed {removed_count} images, saved {saved_count} images") + + except Exception as e: + print(f"Error synchronizing queue images: {e}") + + + def add_job(self, params, job_type=JobType.SINGLE, child_job_params_list=None, parent_job_id=None): + """Add a job to the queue and return its ID""" + job_id = str(uuid.uuid4()) + + # For grid jobs, create child jobs first + child_job_ids = [] + if job_type == JobType.GRID and child_job_params_list: + with self.lock: + for child_params in child_job_params_list: + child_job_id = str(uuid.uuid4()) + child_job_ids.append(child_job_id) + child_job = Job( + id=child_job_id, + params=child_params, + status=JobStatus.PENDING, + job_type=JobType.SINGLE, # Children are single jobs + parent_job_id=job_id, + created_at=time.time(), + progress_data={}, + stream=AsyncStream(), + input_image_saved=False, + end_frame_image_saved=False + ) + self.jobs[child_job_id] = child_job + print(f" - Created child job {child_job_id} for grid job {job_id}") + + job = Job( + id=job_id, + params=params, + status=JobStatus.PENDING, + job_type=job_type, + child_job_ids=child_job_ids, + parent_job_id=parent_job_id, + created_at=time.time(), + progress_data={}, + stream=AsyncStream(), + input_image_saved=False, + end_frame_image_saved=False + ) + + with self.lock: + print(f"Adding job {job_id} (type: {job_type.value}) to queue.") + self.jobs[job_id] = job + self.queue.put(job_id) # Only the parent (or single) job is added to the queue initially + + # Save the queue to JSON after adding a new job (outside the lock) + try: + self.save_queue_to_json() + except Exception as e: + print(f"Error saving queue to JSON after adding job: {e}") + + return job_id + + def get_job(self, job_id): + """Get job by ID""" + with self.lock: + return self.jobs.get(job_id) + + def get_all_jobs(self): + """Get all jobs""" + with self.lock: + return list(self.jobs.values()) + + def cancel_job(self, job_id): + """Cancel a pending job""" + with self.lock: + job = self.jobs.get(job_id) + if not job: + return False + + if job.status == JobStatus.PENDING: + job.status = JobStatus.CANCELLED + job.completed_at = time.time() # Mark completion time + result = True + elif job.status == JobStatus.RUNNING: + # Send cancel signal to the job's stream + if hasattr(job, 'stream') and job.stream: + job.stream.input_queue.push('end') + + # Mark job as cancelled (this will be confirmed when the worker processes the end signal) + job.status = JobStatus.CANCELLED + job.completed_at = time.time() # Mark completion time + + # Let the worker loop handle the 
transition to the next job + # This ensures the current job is fully processed before switching + # DEBUG PRINT REMOVED + result = True + else: + result = False + + # Save the queue to JSON after cancelling a job (outside the lock) + if result: + try: + self.save_queue_to_json() + except Exception as e: + print(f"Error saving queue to JSON after cancelling job: {e}") + + return result + + def clear_queue(self): + """Cancel all pending jobs in the queue""" + cancelled_count = 0 + try: + # First, make a copy of all pending job IDs to avoid modifying the dictionary during iteration + with self.lock: + # Get all pending job IDs + pending_job_ids = [job_id for job_id, job in self.jobs.items() + if job.status == JobStatus.PENDING] + + # Cancel each pending job individually + for job_id in pending_job_ids: + try: + with self.lock: + job = self.jobs.get(job_id) + if job and job.status == JobStatus.PENDING: + job.status = JobStatus.CANCELLED + job.completed_at = time.time() + cancelled_count += 1 + except Exception as e: + print(f"Error cancelling job {job_id}: {e}") + + # Now clear the queue + with self.lock: + # Clear the queue (this doesn't affect running jobs) + queue_items_cleared = 0 + try: + while not self.queue.empty(): + try: + self.queue.get_nowait() + self.queue.task_done() + queue_items_cleared += 1 + except queue_module.Empty: + break + except Exception as e: + print(f"Error clearing queue: {e}") + + # Save the updated queue state + try: + self.save_queue_to_json() + except Exception as e: + print(f"Error saving queue state: {e}") + + # Synchronize queue images after clearing the queue + if cancelled_count > 0: + self.synchronize_queue_images() + + print(f"Cleared {cancelled_count} jobs from the queue") + return cancelled_count + except Exception as e: + import traceback + print(f"Error in clear_queue: {e}") + traceback.print_exc() + return 0 + + def clear_completed_jobs(self): + """Remove cancelled or completed jobs from the queue""" + removed_count = 0 + try: + # First, make a copy of all completed/cancelled job IDs to avoid modifying the dictionary during iteration + with self.lock: + # Get all completed or cancelled job IDs + completed_job_ids = [job_id for job_id, job in self.jobs.items() + if job.status in [JobStatus.COMPLETED, JobStatus.CANCELLED]] + + # Remove each completed/cancelled job individually + for job_id in completed_job_ids: + try: + with self.lock: + if job_id in self.jobs: + del self.jobs[job_id] + removed_count += 1 + except Exception as e: + print(f"Error removing job {job_id}: {e}") + + # Save the updated queue state + try: + self.save_queue_to_json() + except Exception as e: + print(f"Error saving queue state: {e}") + + # Synchronize queue images after removing completed jobs + if removed_count > 0: + self.synchronize_queue_images() + + print(f"Removed {removed_count} completed/cancelled jobs from the queue") + return removed_count + except Exception as e: + import traceback + print(f"Error in clear_completed_jobs: {e}") + traceback.print_exc() + return 0 + + def get_queue_position(self, job_id): + """Get position in queue (0 = currently running)""" + with self.lock: + job = self.jobs.get(job_id) + if not job: + return None + + if job.status == JobStatus.RUNNING: + return 0 + + if job.status != JobStatus.PENDING: + return None + + # Count pending jobs ahead in queue + position = 1 # Start at 1 because 0 means running + for j in self.jobs.values(): + if (j.status == JobStatus.PENDING and + j.created_at < job.created_at): + position += 1 + return position + + def 
update_job_progress(self, job_id, progress_data): + """Update job progress data""" + with self.lock: + job = self.jobs.get(job_id) + if job: + job.progress_data = progress_data + + def export_queue_to_zip(self, output_path=None): + """Export the current queue to a zip file containing queue.json and queue_images directory + + Args: + output_path: Path to save the zip file. If None, uses 'queue_export.zip' in the configured output directory. + + Returns: + str: Path to the created zip file + """ + try: + # Get the output directory from settings + settings = Settings() + output_dir = settings.get("output_dir", "outputs") + os.makedirs(output_dir, exist_ok=True) + + # Use default path if none provided + if output_path is None: + output_path = os.path.join(output_dir, "queue_export.zip") + + # Make sure queue.json is up to date + self.save_queue_to_json() + + # Create a zip file + with zipfile.ZipFile(output_path, 'w', zipfile.ZIP_DEFLATED) as zipf: + # Add queue.json to the zip file + if os.path.exists("queue.json"): + zipf.write("queue.json") + print(f"Added queue.json to {output_path}") + else: + print("Warning: queue.json not found, creating an empty one") + with open("queue.json", "w") as f: + json.dump({}, f) + zipf.write("queue.json") + + # Add queue_images directory to the zip file if it exists + queue_images_dir = "queue_images" + if os.path.exists(queue_images_dir) and os.path.isdir(queue_images_dir): + for root, _, files in os.walk(queue_images_dir): + for file in files: + file_path = os.path.join(root, file) + # Add file to zip with path relative to queue_images_dir + arcname = os.path.join(os.path.basename(queue_images_dir), file) + zipf.write(file_path, arcname) + print(f"Added {file_path} to {output_path}") + else: + print(f"Warning: {queue_images_dir} directory not found or empty") + # Create the directory if it doesn't exist + os.makedirs(queue_images_dir, exist_ok=True) + + # Add input_files directory to the zip file if it exists + input_files_dir = "input_files" + if os.path.exists(input_files_dir) and os.path.isdir(input_files_dir): + for root, _, files in os.walk(input_files_dir): + for file in files: + file_path = os.path.join(root, file) + # Add file to zip with path relative to input_files_dir + arcname = os.path.join(os.path.basename(input_files_dir), file) + zipf.write(file_path, arcname) + print(f"Added {file_path} to {output_path}") + else: + print(f"Warning: {input_files_dir} directory not found or empty") + # Create the directory if it doesn't exist + os.makedirs(input_files_dir, exist_ok=True) + + print(f"Queue exported to {output_path}") + return output_path + + except Exception as e: + import traceback + print(f"Error exporting queue to zip: {e}") + traceback.print_exc() + return None + + def load_queue_from_json(self, file_path=None): + """Load queue from a JSON file or zip file + + Args: + file_path: Path to the JSON or ZIP file. If None, uses 'queue.json' in the current directory. 
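Taken together, `export_queue_to_zip` and `load_queue_from_json` form a round trip. A hedged usage sketch (the paths and printouts are illustrative; both methods handle their own errors and return `None`/`0` on failure):

```python
# Usage sketch: persist the current queue, then restore it in a later session.
queue = VideoJobQueue()

zip_path = queue.export_queue_to_zip()            # defaults to <output_dir>/queue_export.zip
print(f"Queue exported to {zip_path}")

# load_queue_from_json accepts either a queue.json path or a .zip export;
# completed/failed/cancelled jobs are skipped and 'running' jobs are re-queued as pending.
restored = queue.load_queue_from_json(zip_path)
print(f"Re-queued {restored} pending jobs")
```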
+ + Returns: + int: Number of jobs loaded + """ + try: + # Import required modules + import os + import json + from pathlib import PurePath + + # Use default path if none provided + if file_path is None: + file_path = "queue.json" + + # Check if file exists + if not os.path.exists(file_path): + print(f"Queue file not found: {file_path}") + return 0 + + # Check if it's a zip file + if file_path.lower().endswith('.zip'): + return self._load_queue_from_zip(file_path) + + # Load the JSON data + with open(file_path, 'r') as f: + serialized_jobs = json.load(f) + + # Count of jobs loaded + loaded_count = 0 + + # Process each job + with self.lock: + for job_id, job_data in serialized_jobs.items(): + # Skip if job already exists + if job_id in self.jobs: + print(f"Job {job_id} already exists, skipping") + continue + + # Skip completed, failed, or cancelled jobs + status = job_data.get('status') + if status in ['completed', 'failed', 'cancelled']: + print(f"Skipping job {job_id} with status {status}") + continue + + # If the job was running when saved, we'll need to set it as the current job + was_running = (status == 'running') + + # Extract relevant fields to construct params + params = { + # Basic parameters + 'model_type': job_data.get('model_type', 'Original'), + 'prompt_text': job_data.get('prompt', ''), + 'n_prompt': job_data.get('negative_prompt', ''), + 'seed': job_data.get('seed', 0), + 'steps': job_data.get('steps', 25), + 'cfg': job_data.get('cfg', 1.0), + 'gs': job_data.get('gs', 10.0), + 'rs': job_data.get('rs', 0.0), + 'latent_type': job_data.get('latent_type', 'Black'), + 'total_second_length': job_data.get('total_second_length', 6), + 'blend_sections': job_data.get('blend_sections', 4), + 'latent_window_size': job_data.get('latent_window_size', 9), + 'resolutionW': job_data.get('resolutionW', 640), + 'resolutionH': job_data.get('resolutionH', 640), + 'use_magcache': job_data.get('use_magcache', False), + 'magcache_threshold': job_data.get('magcache_threshold', 0.1), + 'magcache_max_consecutive_skips': job_data.get('magcache_max_consecutive_skips', 2), + 'magcache_retention_ratio': job_data.get('magcache_retention_ratio', 0.25), + + # Initialize image parameters + 'input_image': None, + 'end_frame_image': None, + 'end_frame_strength': job_data.get('end_frame_strength', 1.0), + 'use_teacache': job_data.get('use_teacache', True), + 'teacache_num_steps': job_data.get('teacache_num_steps', 25), + 'teacache_rel_l1_thresh': job_data.get('teacache_rel_l1_thresh', 0.15), + 'has_input_image': job_data.get('has_input_image', True), + 'combine_with_source': job_data.get('combine_with_source', False), + } + + # Load input image from disk if saved path exists + if "saved_input_image_path" in job_data and os.path.exists(job_data["saved_input_image_path"]): + try: + input_image_path = job_data["saved_input_image_path"] + print(f"Loading input image from {input_image_path}") + input_image = np.array(Image.open(input_image_path)) + params['input_image'] = input_image + params['input_image_path'] = input_image_path # Store the path for reference + params['has_input_image'] = True + except Exception as e: + print(f"Error loading input image for job {job_id}: {e}") + + # Load video from disk if saved path exists + input_video_val = job_data.get("input_video") # Get value safely + if isinstance(input_video_val, str): # Check if it's a string path + if os.path.exists(input_video_val): # Now it's safe to call os.path.exists + try: + video_path = input_video_val # Use the validated string path + 
print(f"Loading video from {video_path}") + params['input_image'] = video_path + params['input_image_path'] = video_path + params['has_input_image'] = True + except Exception as e: + print(f"Error loading video for job {job_id}: {e}") + + # Load end frame image from disk if saved path exists + if "saved_end_frame_image_path" in job_data and os.path.exists(job_data["saved_end_frame_image_path"]): + try: + end_frame_image_path = job_data["saved_end_frame_image_path"] + print(f"Loading end frame image from {end_frame_image_path}") + end_frame_image = np.array(Image.open(end_frame_image_path)) + params['end_frame_image'] = end_frame_image + params['end_frame_image_path'] = end_frame_image_path # Store the path for reference + # Make sure end_frame_strength is set if this is an endframe model + if params['model_type'] == "Original with Endframe" or params['model_type'] == "F1 with Endframe": + if 'end_frame_strength' not in params or params['end_frame_strength'] is None: + params['end_frame_strength'] = job_data.get('end_frame_strength', 1.0) + print(f"Set end_frame_strength to {params['end_frame_strength']} for job {job_id}") + except Exception as e: + print(f"Error loading end frame image for job {job_id}: {e}") + + # Add LoRA information if present + if 'loras' in job_data: + lora_data = job_data.get('loras', {}) + selected_loras = list(lora_data.keys()) + lora_values = list(lora_data.values()) + params['selected_loras'] = selected_loras + params['lora_values'] = lora_values + + # Ensure the selected LoRAs are also in lora_loaded_names + # This is critical for metadata_utils.create_metadata to find the LoRAs + from modules.settings import Settings + settings = Settings() + lora_dir = settings.get("lora_dir", "loras") + + # Get the current lora_loaded_names from the system + import os + from pathlib import PurePath + current_lora_names = [] + if os.path.isdir(lora_dir): + for root, _, files in os.walk(lora_dir): + for file in files: + if file.endswith('.safetensors') or file.endswith('.pt'): + lora_relative_path = os.path.relpath(os.path.join(root, file), lora_dir) + lora_name = str(PurePath(lora_relative_path).with_suffix('')) + current_lora_names.append(lora_name) + + # Combine the selected LoRAs with the current lora_loaded_names + # This ensures that all selected LoRAs are in lora_loaded_names + combined_lora_names = list(set(current_lora_names + selected_loras)) + params['lora_loaded_names'] = combined_lora_names + + print(f"Loaded LoRA data for job {job_id}: {lora_data}") + print(f"Combined lora_loaded_names: {combined_lora_names}") + + # Get settings for output_dir and metadata_dir + settings = Settings() + output_dir = settings.get("output_dir") + metadata_dir = settings.get("metadata_dir") + input_files_dir = settings.get("input_files_dir") + + # Add these directories to the params + params['output_dir'] = output_dir + params['metadata_dir'] = metadata_dir + params['input_files_dir'] = input_files_dir + + # Create a dummy preview image for the job + dummy_preview = np.zeros((64, 64, 3), dtype=np.uint8) + + # Create progress data with the dummy preview + from diffusers_helper.gradio.progress_bar import make_progress_bar_html + initial_progress_data = { + 'preview': dummy_preview, + 'desc': 'Imported job...', + 'html': make_progress_bar_html(0, 'Imported job...') + } + + # Create a dummy preview image for the job + dummy_preview = np.zeros((64, 64, 3), dtype=np.uint8) + + # Create progress data with the dummy preview + from diffusers_helper.gradio.progress_bar import 
make_progress_bar_html + initial_progress_data = { + 'preview': dummy_preview, + 'desc': 'Imported job...', + 'html': make_progress_bar_html(0, 'Imported job...') + } + + # Create a new job + job = Job( + id=job_id, + params=params, + status=JobStatus(job_data.get('status', 'pending')), + created_at=job_data.get('created_at', time.time()), + progress_data={}, + stream=AsyncStream(), + # Mark images as saved if their paths exist in the job data + input_image_saved="saved_input_image_path" in job_data and os.path.exists(job_data["saved_input_image_path"]), + end_frame_image_saved="saved_end_frame_image_path" in job_data and os.path.exists(job_data["saved_end_frame_image_path"]) + ) + + # Add job to the internal jobs dictionary + self.jobs[job_id] = job + + # If a job was marked "running" in the JSON, reset it to "pending" + # and add it to the processing queue. + if was_running: + print(f"Job {job_id} was 'running', resetting to 'pending' and adding to queue.") + job.status = JobStatus.PENDING + job.started_at = None # Clear started_at for re-queued job + job.progress_data = {} # Reset progress + + # Add all non-completed/failed/cancelled jobs (now including reset 'running' ones) to the processing queue + if job.status == JobStatus.PENDING: + self.queue.put(job_id) + loaded_count += 1 + + # Synchronize queue images after loading the queue + self.synchronize_queue_images() + + print(f"Loaded {loaded_count} pending jobs from {file_path}") + return loaded_count + + except Exception as e: + import traceback + print(f"Error loading queue from JSON: {e}") + traceback.print_exc() + return 0 + + def _load_queue_from_zip(self, zip_path): + """Load queue from a zip file + + Args: + zip_path: Path to the zip file + + Returns: + int: Number of jobs loaded + """ + try: + # Create a temporary directory to extract the zip file + temp_dir = "temp_queue_import" + if os.path.exists(temp_dir): + shutil.rmtree(temp_dir) + os.makedirs(temp_dir, exist_ok=True) + + # Extract the zip file + with zipfile.ZipFile(zip_path, 'r') as zipf: + zipf.extractall(temp_dir) + + # Check if queue.json exists in the extracted files + queue_json_path = os.path.join(temp_dir, "queue.json") + if not os.path.exists(queue_json_path): + print(f"queue.json not found in {zip_path}") + shutil.rmtree(temp_dir) + return 0 + + # Define target_queue_images_dir and ensure it exists + # This needs to be defined regardless of whether queue_images exists in the zip, + # as it's used later for path updates. + target_queue_images_dir = "queue_images" + os.makedirs(target_queue_images_dir, exist_ok=True) + + # Check if queue_images directory exists in the extracted files + queue_images_dir = os.path.join(temp_dir, "queue_images") + if os.path.exists(queue_images_dir) and os.path.isdir(queue_images_dir): + # Copy all files from the extracted queue_images directory to the target directory + for file in os.listdir(queue_images_dir): + src_path = os.path.join(queue_images_dir, file) + dst_path = os.path.join(target_queue_images_dir, file) + if os.path.isfile(src_path): + shutil.copy2(src_path, dst_path) + print(f"Copied {src_path} to {dst_path}") + + # Check if input_files directory exists in the extracted files + input_files_dir = os.path.join(temp_dir, "input_files") + print(f"DEBUG: Checking for input_files directory in zip: {input_files_dir}") # DEBUG + if os.path.exists(input_files_dir) and os.path.isdir(input_files_dir): + print(f"DEBUG: Found input_files directory in zip. 
Contents: {os.listdir(input_files_dir)}") # DEBUG + # Copy the input_files directory to the current directory + target_input_files_dir = "input_files" + os.makedirs(target_input_files_dir, exist_ok=True) + + # Copy all files from the extracted input_files directory to the target directory + for file in os.listdir(input_files_dir): + print(f"DEBUG: Processing file from zip's input_files: {file}") # DEBUG + src_path = os.path.join(input_files_dir, file) + dst_path = os.path.join(target_input_files_dir, file) + if os.path.isfile(src_path): + print(f"DEBUG: Attempting to copy video file: {src_path} to {dst_path}") # DEBUG + shutil.copy2(src_path, dst_path) + print(f"Copied {src_path} to {dst_path}") + else: # DEBUG + print(f"DEBUG: Skipped copy, {src_path} is not a file.") # DEBUG + else: # DEBUG + print(f"DEBUG: Directory {input_files_dir} does not exist or is not a directory.") # DEBUG + + # Update paths in the queue.json file to reflect the new location of the images + try: + with open(queue_json_path, 'r') as f: + queue_data = json.load(f) + + # Update paths for each job + for job_id, job_data in queue_data.items(): + # Check for files with job_id in the name to identify input and end frame images + input_image_filename = f"{job_id}_input.png" + end_frame_image_filename = f"{job_id}_end_frame.png" + + # Check if these files exist in the target directory + input_image_path = os.path.join(target_queue_images_dir, input_image_filename) + end_frame_image_path = os.path.join(target_queue_images_dir, end_frame_image_filename) + + # Update paths in job_data + if os.path.exists(input_image_path): + job_data["saved_input_image_path"] = input_image_path + print(f"Updated input image path for job {job_id}: {input_image_path}") + elif "saved_input_image_path" in job_data: + # Fallback to updating the existing path + job_data["saved_input_image_path"] = os.path.join(target_queue_images_dir, os.path.basename(job_data["saved_input_image_path"])) + print(f"Updated existing input image path for job {job_id}") + + if os.path.exists(end_frame_image_path): + job_data["saved_end_frame_image_path"] = end_frame_image_path + print(f"Updated end frame image path for job {job_id}: {end_frame_image_path}") + elif "saved_end_frame_image_path" in job_data: + # Fallback to updating the existing path + job_data["saved_end_frame_image_path"] = os.path.join(target_queue_images_dir, os.path.basename(job_data["saved_end_frame_image_path"])) + print(f"Updated existing end frame image path for job {job_id}") + + # Handle video path update for job_data["input_video"] + current_input_video = job_data.get("input_video") + current_input_image_path = job_data.get("input_image_path") + model_type_for_job = job_data.get("model_type") + video_extensions = ('.mp4', '.mov', '.avi', '.mkv', '.webm', '.flv', '.gif') # Add more if needed + + # Prioritize input_video if it's already a string path + if isinstance(current_input_video, str): + job_data["input_video"] = os.path.join("input_files", os.path.basename(current_input_video)) + print(f"Updated video path for job {job_id} from 'input_video': {job_data['input_video']}") + # If input_video is None, but input_image_path is a video path (for Video/Video F1 models) + elif current_input_video is None and \ + isinstance(current_input_image_path, str) and \ + model_type_for_job in ("Video", "Video F1") and \ + current_input_image_path.lower().endswith(video_extensions): + + video_basename = os.path.basename(current_input_image_path) + job_data["input_video"] = os.path.join("input_files", 
video_basename) + print(f"Updated video path for job {job_id} from 'input_image_path' ('{current_input_image_path}') to '{job_data['input_video']}'") + elif current_input_video is None: + # If input_video is None and input_image_path is not a usable video path, keep input_video as None + print(f"Video path for job {job_id} is None and 'input_image_path' ('{current_input_image_path}') not used for 'input_video'. 'input_video' remains None.") + # Write the updated queue.json back to the file + with open(queue_json_path, 'w') as f: + json.dump(queue_data, f, indent=2) + + print(f"Updated image paths in queue.json to reflect new location") + except Exception as e: + print(f"Error updating paths in queue.json: {e}") + + # Load the queue from the extracted queue.json + loaded_count = self.load_queue_from_json(queue_json_path) + + # Clean up the temporary directory + shutil.rmtree(temp_dir) + + return loaded_count + + except Exception as e: + import traceback + print(f"Error loading queue from zip: {e}") + traceback.print_exc() + # Clean up the temporary directory if it exists + if os.path.exists(temp_dir): + shutil.rmtree(temp_dir) + return 0 + + def _worker_loop(self): + """Worker thread that processes jobs from the queue""" + while True: + try: + # Get the next job ID from the queue + try: + job_id = self.queue.get(block=True, timeout=1.0) + except queue_module.Empty: + self._check_and_process_completed_grids() + continue + + with self.lock: + job = self.jobs.get(job_id) + if not job: + self.queue.task_done() + continue + + # Skip cancelled jobs + if job.status == JobStatus.CANCELLED: + self.queue.task_done() + continue + + # If it's a grid job, queue its children and mark it as running + if job.job_type == JobType.GRID: + print(f"Processing grid job {job.id}, adding {len(job.child_job_ids)} child jobs to queue.") + job.status = JobStatus.RUNNING # Mark the grid job as running + job.started_at = time.time() + # Add child jobs to the front of the queue + temp_queue = [] + while not self.queue.empty(): + temp_queue.append(self.queue.get()) + for child_id in reversed(job.child_job_ids): # Add in reverse to maintain order + self.queue.put(child_id) + for item in temp_queue: + self.queue.put(item) + + self.queue.task_done() + continue # Continue to the next iteration to process the first child job + + # If we're already processing a job, wait for it to complete + if self.is_processing: + # Check if this is the job that's already marked as running + # This can happen if the job was marked as running but not yet processed + if job.status == JobStatus.RUNNING and self.current_job and self.current_job.id == job_id: + print(f"Job {job_id} is already marked as running, processing it now") + # We'll process this job now + pass + else: + # Put the job back in the queue + self.queue.put(job_id) + self.queue.task_done() + time.sleep(0.1) # Small delay to prevent busy waiting + continue + + # Check if there's a previously running job that was interrupted + previously_running_job = None + for j in self.jobs.values(): + if j.status == JobStatus.RUNNING and j.id != job_id: + previously_running_job = j + break + + # If there's a previously running job, process it first + if previously_running_job: + print(f"Found previously running job {previously_running_job.id}, processing it first") + # Put the current job back in the queue + self.queue.put(job_id) + self.queue.task_done() + # Process the previously running job + job = previously_running_job + job_id = previously_running_job.id + + # Create a new stream for 
the resumed job and initialize progress_data + job.stream = AsyncStream() + job.progress_data = {} + + # Push an initial progress update to the stream + from diffusers_helper.gradio.progress_bar import make_progress_bar_html + job.stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Resuming job...')))) + + print(f"Starting job {job_id}, current job was {self.current_job.id if self.current_job else 'None'}") + job.status = JobStatus.RUNNING + job.started_at = time.time() + self.current_job = job + self.is_processing = True + + job_completed = False + + try: + if self.worker_function is None: + raise ValueError("Worker function not set. Call set_worker_function() first.") + + # Start the worker function with the job parameters + from diffusers_helper.thread_utils import async_run + print(f"Starting worker function for job {job_id}") + + # Clean up params for the worker function + worker_params = job.params.copy() + if 'end_frame_image_original' in worker_params: + del worker_params['end_frame_image_original'] + if 'end_frame_strength_original' in worker_params: + del worker_params['end_frame_strength_original'] + + async_run( + self.worker_function, + **worker_params, + job_stream=job.stream + ) + print(f"Worker function started for job {job_id}") + + # Process the results from the stream + output_filename = None + + # Track activity time for logging purposes + last_activity_time = time.time() + + while True: + # Check if job has been cancelled before processing next output + with self.lock: + if job.status == JobStatus.CANCELLED: + print(f"Job {job_id} was cancelled, breaking out of processing loop") + job_completed = True + break + + # Get current time for activity checks + current_time = time.time() + + # Check for inactivity (no output for a while) + if current_time - last_activity_time > 60: # 1 minute of inactivity + print(f"Checking if job {job_id} is still active...") + # Just a periodic check, don't break yet + + try: + # Try to get data from the queue with a non-blocking approach + flag, data = job.stream.output_queue.next() + + # Update activity time since we got some data + last_activity_time = time.time() + + if flag == 'file': + output_filename = data + with self.lock: + job.result = output_filename + + elif flag == 'progress': + preview, desc, html = data + with self.lock: + job.progress_data = { + 'preview': preview, + 'desc': desc, + 'html': html + } + + elif flag == 'end': + print(f"Received end signal for job {job_id}") + job_completed = True + break + + except IndexError: + # Queue is empty, wait a bit and try again + time.sleep(0.1) + continue + except Exception as e: + print(f"Error processing job output: {e}") + # Wait a bit before trying again + time.sleep(0.1) + continue + except Exception as e: + import traceback + traceback.print_exc() + print(f"Error processing job {job_id}: {e}") + with self.lock: + job.status = JobStatus.FAILED + job.error = str(e) + job.completed_at = time.time() + job_completed = True + + finally: + with self.lock: + # Make sure we properly clean up the job state + if job.status == JobStatus.RUNNING: + if job_completed: + job.status = JobStatus.COMPLETED + else: + # Something went wrong but we didn't mark it as completed + job.status = JobStatus.FAILED + job.error = "Job processing was interrupted" + + job.completed_at = time.time() + + print(f"Finishing job {job_id} with status {job.status}") + self.is_processing = False + + # Check if there's another job in the queue before setting current_job to None + # This helps 
prevent UI flashing when a job is cancelled + next_job_id = None + try: + # Peek at the next job without removing it from the queue + if not self.queue.empty(): + # We can't peek with the standard Queue, so we'll have to get creative + # Store the queue items temporarily + temp_queue = [] + while not self.queue.empty(): + item = self.queue.get() + temp_queue.append(item) + if next_job_id is None: + next_job_id = item + + # Put everything back + for item in temp_queue: + self.queue.put(item) + except Exception as e: + print(f"Error checking for next job: {e}") + + # After a job completes or is cancelled, always set current_job to None + self.current_job = None + + # The main loop's self.queue.get() will pick up the next available job. + # No need to explicitly find and start the next job here. + + self.queue.task_done() + + # Save the queue to JSON after job completion (outside the lock) + try: + self.save_queue_to_json() + except Exception as e: + print(f"Error saving queue to JSON after job completion: {e}") + + except Exception as e: + import traceback + traceback.print_exc() + print(f"Error in worker loop: {e}") + + # Make sure we reset processing state if there was an error + with self.lock: + self.is_processing = False + if self.current_job: + self.current_job.status = JobStatus.FAILED + self.current_job.error = f"Worker loop error: {str(e)}" + self.current_job.completed_at = time.time() + self.current_job = None + + time.sleep(0.5) # Prevent tight loop on error + + def _check_and_process_completed_grids(self): + """Check for completed grid jobs and process them.""" + with self.lock: + # Find all running grid jobs + running_grid_jobs = [job for job in self.jobs.values() if job.job_type == JobType.GRID and job.status == JobStatus.RUNNING] + + for grid_job in running_grid_jobs: + # Check if all child jobs are completed + child_jobs = [self.jobs.get(child_id) for child_id in grid_job.child_job_ids] + + if not all(child_jobs): + print(f"Warning: Some child jobs for grid {grid_job.id} not found.") + continue + + all_children_done = all(job.status in [JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED] for job in child_jobs) + + if all_children_done: + print(f"All child jobs for grid {grid_job.id} are done. Assembling grid.") + # Logic to assemble the grid + # This is a placeholder for the actual grid assembly logic + # For now, we'll just mark the grid job as completed. + + # Collect results from child jobs + child_results = [child.result for child in child_jobs if child.status == JobStatus.COMPLETED and child.result] + + if not child_results: + print(f"Grid job {grid_job.id} failed because no child jobs completed successfully.") + grid_job.status = JobStatus.FAILED + grid_job.error = "No child jobs completed successfully." + grid_job.completed_at = time.time() + continue + + # Placeholder for grid assembly. + # In a real implementation, you would use a tool like FFmpeg or MoviePy to stitch the videos. + # For this example, we'll just create a text file with the paths of the child videos. + try: + output_dir = grid_job.params.get("output_dir", "outputs") + grid_filename = os.path.join(output_dir, f"grid_{grid_job.id}.txt") + with open(grid_filename, "w") as f: + f.write(f"Grid for job: {grid_job.id}\n") + f.write("Child video paths:\n") + for result_path in child_results: + f.write(f"{result_path}\n") + + grid_job.result = grid_filename + grid_job.status = JobStatus.COMPLETED + print(f"Grid assembly for job {grid_job.id} complete. 
Result saved to {grid_filename}") + + except Exception as e: + print(f"Error during grid assembly for job {grid_job.id}: {e}") + grid_job.status = JobStatus.FAILED + grid_job.error = f"Grid assembly failed: {e}" + + grid_job.completed_at = time.time() + self.save_queue_to_json() diff --git a/modules/xy_plot_ui.py b/modules/xy_plot_ui.py new file mode 100644 index 0000000000000000000000000000000000000000..068d59dbcd5b678d3c5bdea34d683b6c0464eb19 --- /dev/null +++ b/modules/xy_plot_ui.py @@ -0,0 +1,433 @@ +import gradio as gr +import numpy as np +import re +import itertools +import os +import imageio +import imageio.plugins.ffmpeg +import ffmpeg +from PIL import Image, ImageDraw, ImageFont + +from diffusers_helper.utils import generate_timestamp +from modules.video_queue import JobType + +# --- Helper Dictionaries & Functions --- + +xy_plot_axis_options = { + # "type": [ + # "dropdown(checkboxGroup), textbox or number", + # "empty if textbox, dtype if number, [] if dropdown", + # "standard values", + # "True if multi axis - like prompt replace, False is only on one axis - like steps" + # ], + "Nothing": ["nothing", "", "", True], + "Model type": ["dropdown", ["Original", "F1"], ["Original", "F1"], False], + "End frame influence": ["number", "float", "0.05-0.95[3]", False], + "Latent type": ["dropdown", ["Black", "White", "Noise", "Green Screen"], ["Black", "Noise"], False], + "Prompt add": ["textbox", "", "", True], + "Prompt replace": ["textbox", "", "", True], + "Blend sections": ["number", "int", "3-7 [3]", False], + "Steps": ["number", "int", "15-30 [3]", False], + "Seed": ["number", "int", "1000-10000 [3]", False], + "Use teacache": ["dropdown", [True, False], [True, False], False], + "TeaCache steps": ["number", "int", "5-25 [3]", False], + "TeaCache rel_l1_thresh": ["number", "float", "0.01-0.3 [3]", False], + "Use MagCache": ["dropdown", [True, False], [True, False], False], + "MagCache Threshold": ["number", "float", "0.01-1.0 [3]", False], + "MagCache Max Consecutive Skips": ["number", "int", "1-5 [3]", False], + "MagCache Retention Ratio": ["number", "float", "0.0-1.0 [3]", False], + # "CFG": ["number", "float", "", False], + "Distilled CFG Scale": ["number", "float", "5-15 [3]", False], + # "RS": ["number", "float", "", False], + # "Use weighted embeddings": ["dropdown", [True, False], [True, False], False], +} + +text_to_base_keys = { + "Model type": "model_type", + "End frame influence": "end_frame_strength_original", + "Latent type": "latent_type", + "Prompt add": "prompt", + "Prompt replace": "prompt", + "Blend sections": "blend_sections", + "Steps": "steps", + "Seed": "seed", + "Use teacache": "use_teacache", + "TeaCache steps":"teacache_num_steps", + "TeaCache rel_l1_thresh":"teacache_rel_l1_thresh", + "Use MagCache": "use_magcache", + "MagCache Threshold": "magcache_threshold", + "MagCache Max Consecutive Skips": "magcache_max_consecutive_skips", + "MagCache Retention Ratio": "magcache_retention_ratio", + "Latent window size": "latent_window_size", + # "CFG": "", + "Distilled CFG Scale": "gs", + # "RS": "", + # "Use weighted embeddings": "", +} + +def xy_plot_parse_input(text): + text = text.strip() + if ',' in text: + return [x.strip() for x in text.split(",")] + match = re.match(r'^\s*(-?\d*\.?\d*)\s*-\s*(-?\d*\.?\d*)\s*\[\s*(\d+)\s*\]$', text) + if match: + start, end, count = map(float, match.groups()) + result = np.linspace(start, end, int(count)) + if np.allclose(result, np.round(result)): + result = np.round(result).astype(int) + return result.tolist() + return [] + 
+def xy_plot_process( + job_queue, settings, # Added explicit dependencies + model_type, input_image, end_frame_image_original, + end_frame_strength_original, latent_type, + prompt, blend_sections, steps, total_second_length, + resolutionW, resolutionH, seed, randomize_seed, use_teacache, + teacache_num_steps, teacache_rel_l1_thresh, + use_magcache, magcache_threshold, magcache_max_consecutive_skips, magcache_retention_ratio, + latent_window_size, + cfg, gs, rs, gpu_memory_preservation, mp4_crf, + axis_x_switch, axis_x_value_text, axis_x_value_dropdown, + axis_y_switch, axis_y_value_text, axis_y_value_dropdown, + axis_z_switch, axis_z_value_text, axis_z_value_dropdown, + selected_loras, + *lora_slider_values + ): + # print(model_type, input_image, latent_type, + # prompt, blend_sections, steps, total_second_length, + # resolutionW, resolutionH, seed, randomize_seed, use_teacache, + # latent_window_size, cfg, gs, rs, gpu_memory_preservation, + # mp4_crf, + # axis_x_switch, axis_x_value_text, axis_x_value_dropdown, + # axis_y_switch, axis_y_value_text, axis_y_value_dropdown, + # axis_z_switch, axis_z_value_text, axis_z_value_dropdown, sep=", ") + if axis_x_switch == "Nothing" and axis_y_switch == "Nothing" and axis_z_switch == "Nothing": + return "Not selected any axis for plot", gr.update() + if (axis_x_switch == "Nothing" or axis_y_switch == "Nothing") and axis_z_switch != "Nothing": + return "For using Z axis, first use X and Y axis", gr.update() + if axis_x_switch == "Nothing" and axis_y_switch != "Nothing": + return "For using Y axis, first use X axis", gr.update() + if xy_plot_axis_options[axis_x_switch][0] == "dropdown" and len(axis_x_value_dropdown) < 1: + return "No values for axis X", gr.update() + if xy_plot_axis_options[axis_y_switch][0] == "dropdown" and len(axis_y_value_dropdown) < 1: + return "No values for axis Y", gr.update() + if xy_plot_axis_options[axis_z_switch][0] == "dropdown" and len(axis_z_value_dropdown) < 1: + return "No values for axis Z", gr.update() + if not xy_plot_axis_options[axis_x_switch][3]: + if axis_x_switch == axis_y_switch: + return "Axis type on X and Y axis are same, you can't do that generation.
Multi axis supported only for \"Prompt add\" and \"Prompt replace\".", gr.update() + if axis_x_switch == axis_z_switch: + return "Axis type on X and Z axis is the same; you can't do that generation. 
Multi axis supported only for \"Prompt add\" and \"Prompt replace\".", gr.update() + if not xy_plot_axis_options[axis_y_switch][3]: + if axis_y_switch == axis_z_switch: + return "Axis type on Y and Z axis is the same; you can't do that generation. 
Multi axis supported only for \"Prompt add\" and \"Prompt replace\".", gr.update() + + base_generator_vars = { + "model_type": model_type, + "input_image": input_image, + "end_frame_image": None, + "end_frame_strength": 1.0, + "input_video": None, + "end_frame_image_original": end_frame_image_original, + "end_frame_strength_original": end_frame_strength_original, + "prompt_text": prompt, + "n_prompt": "", + "seed": seed, + "total_second_length": total_second_length, + "latent_window_size": latent_window_size, + "steps": steps, + "cfg": cfg, + "gs": gs, + "rs": rs, + "use_teacache": use_teacache, + "teacache_num_steps": teacache_num_steps, + "teacache_rel_l1_thresh": teacache_rel_l1_thresh, + "use_magcache": use_magcache, + "magcache_threshold": magcache_threshold, + "magcache_max_consecutive_skips": magcache_max_consecutive_skips, + "magcache_retention_ratio": magcache_retention_ratio, + "has_input_image": True if input_image is not None else False, + "save_metadata_checked": True, + "blend_sections": blend_sections, + "latent_type": latent_type, + "selected_loras": selected_loras, + "resolutionW": resolutionW, + "resolutionH": resolutionH, + "lora_loaded_names": lora_names, + "lora_values": lora_slider_values + } + + def xy_plot_convert_values(type, value_textbox, value_dropdown): + retVal = [] + if type[0] == "dropdown": + retVal = value_dropdown + elif type[0] == "textbox": + retVal = xy_plot_parse_input(value_textbox) + elif type[0] == "number": + if type[1] == "int": + retVal = [int(float(x)) for x in xy_plot_parse_input(value_textbox)] + else: + retVal = [float(x) for x in xy_plot_parse_input(value_textbox)] + return retVal + prompt_replace_initial_values = {} + all_axis_values = { + axis_x_switch+" -> X": xy_plot_convert_values(xy_plot_axis_options[axis_x_switch], axis_x_value_text, axis_x_value_dropdown) + } + if axis_x_switch == "Prompt replace": + prompt_replace_initial_values["X"] = all_axis_values[axis_x_switch+" -> X"][0] + if prompt_replace_initial_values["X"] not in base_generator_vars["prompt_text"]: + return "Prompt for replacing in X axis not present in generation prompt", gr.update() + if axis_y_switch != "Nothing": + all_axis_values[axis_y_switch+" -> Y"] = xy_plot_convert_values(xy_plot_axis_options[axis_y_switch], axis_y_value_text, axis_y_value_dropdown) + if axis_y_switch == "Prompt replace": + prompt_replace_initial_values["Y"] = all_axis_values[axis_y_switch+" -> Y"][0] + if prompt_replace_initial_values["Y"] not in base_generator_vars["prompt_text"]: + return "Prompt for replacing in Y axis not present in generation prompt", gr.update() + if axis_z_switch != "Nothing": + all_axis_values[axis_z_switch+" -> Z"] = xy_plot_convert_values(xy_plot_axis_options[axis_z_switch], axis_z_value_text, axis_z_value_dropdown) + if axis_z_switch == "Prompt replace": + prompt_replace_initial_values["Z"] = all_axis_values[axis_z_switch+" -> Z"][0] + if prompt_replace_initial_values["Z"] not in base_generator_vars["prompt_text"]: + return "Prompt for replacing in Z axis not present in generation prompt", gr.update() + + active_axes = list(all_axis_values.keys()) + value_lists = [all_axis_values[axis] for axis in active_axes] + output_generator_vars = [] + + combintion_plot = itertools.product(*value_lists) + for combo in combintion_plot: + vars_copy = base_generator_vars.copy() + for axis, value in zip(active_axes, combo): + splitted_axis_name = axis.split(" -> ") + if splitted_axis_name[0] == "Prompt add": + vars_copy["prompt_text"] = vars_copy["prompt_text"] + " " + str(value) + 
elif splitted_axis_name[0] == "Prompt replace": + orig_copy_prompt_text = vars_copy["prompt_text"] + vars_copy["prompt_text"] = orig_copy_prompt_text.replace(prompt_replace_initial_values[splitted_axis_name[1]], str(value)) + else: + vars_copy[text_to_base_keys[splitted_axis_name[0]]] = value + vars_copy[splitted_axis_name[1]+"_axis_on_plot"] = str(value) + + worker_params = {k: v for k, v in vars_copy.items() if k not in ["X_axis_on_plot", "Y_axis_on_plot", "Z_axis_on_plot"]} + output_generator_vars.append(worker_params) + # print("----- BEFORE GENERATED VIDS VARS START -----") + # for v in output_generator_vars: + # print(v) + # print("------ BEFORE GENERATED VIDS VARS END ------") + + job_queue.add_job( + params=base_generator_vars, + job_type=JobType.GRID, + child_job_params_list=output_generator_vars + ) + return "Grid job added to the queue.", gr.update(visible=False) + # print("----- GENERATED VIDS VARS START -----") + # for v in output_generator_vars: + # print(v) + # print("------ GENERATED VIDS VARS END ------") + + # -------------------------- connect with settings -------------------------- + # Ensure settings is available in this scope or passed in. + # Assuming 'settings' object is available from create_interface's scope. + output_dir_setting = settings.get("output_dir", "outputs") + mp4_crf_setting = settings.get("mp4_crf", 16) # Default CRF if not in settings + # -------------------------- connect with settings -------------------------- + +def create_xy_plot_ui(lora_names, default_prompt, DUMMY_LORA_NAME): + """ + Creates the Gradio UI for the XY Plot functionality. + Returns a dictionary of key components to be used by the main interface. + """ + with gr.Group(visible=False) as xy_group: # The original was visible=False + with gr.Row(): + xy_plot_model_type = gr.Radio( + ["Original", "F1"], + label="Model Type", + value="F1", + info="Select which model to use for generation" + ) + with gr.Group(): + with gr.Row(): + with gr.Column(scale=1): + xy_plot_input_image = gr.Image( + sources='upload', + type="numpy", + label="Image (optional)", + height=420, + image_mode="RGB", + elem_classes="contain-image" + ) + with gr.Column(scale=1): + xy_plot_end_frame_image_original = gr.Image( + sources='upload', + type="numpy", + label="End Frame (Optional)", + height=420, + elem_classes="contain-image", + image_mode="RGB", + show_download_button=False, + show_label=True, + container=True + ) + with gr.Group(): + xy_plot_end_frame_strength_original = gr.Slider( + label="End Frame Influence", + minimum=0.05, + maximum=1.0, + value=1.0, + step=0.05, + info="Controls how strongly the end frame guides the generation. 1.0 is full influence." 
+ ) + with gr.Accordion("Latent Image Options", open=False): + xy_plot_latent_type = gr.Dropdown( + ["Black", "White", "Noise", "Green Screen"], + label="Latent Image", + value="Black", + info="Used as a starting point if no image is provided" + ) + xy_plot_prompt = gr.Textbox(label="Prompt", value=default_prompt) + with gr.Accordion("Prompt Parameters", open=False): + xy_plot_blend_sections = gr.Slider( + minimum=0, maximum=10, value=4, step=1, + label="Number of sections to blend between prompts" + ) + with gr.Accordion("Generation Parameters", open=True): + with gr.Row(): + xy_plot_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=5, step=1) + xy_plot_total_second_length = gr.Slider(label="Video Length (Seconds)", minimum=0.1, maximum=120, value=1, step=0.1) + with gr.Row(): + xy_plot_seed = gr.Number(label="Seed", value=31337, precision=0) + xy_plot_randomize_seed = gr.Checkbox(label="Randomize", value=False, info="Generate a new random seed for each job") + with gr.Row("LoRAs"): + xy_plot_lora_selector = gr.Dropdown( + choices=lora_names, + label="Select LoRAs to Load", + multiselect=True, + value=[], + info="Select one or more LoRAs to use for this job" + ) + xy_plot_lora_sliders = {} + for lora in lora_names: + xy_plot_lora_sliders[lora] = gr.Slider( + minimum=0.0, maximum=2.0, value=1.0, step=0.01, + label=f"{lora} Weight", visible=False, interactive=True + ) + with gr.Accordion("Advanced Parameters", open=False): + with gr.Row("TeaCache"): + xy_plot_use_teacache = gr.Checkbox(label='Use TeaCache', value=True, info='Faster speed, but often makes hands and fingers slightly worse.') + xy_plot_teacache_num_steps = gr.Slider(label="TeaCache steps", minimum=1, maximum=50, step=1, value=25, visible=True, info='How many intermediate sections to keep in the cache') + xy_plot_teacache_rel_l1_thresh = gr.Slider(label="TeaCache rel_l1_thresh", minimum=0.01, maximum=1.0, step=0.01, value=0.15, visible=True, info='Relative L1 Threshold') + with gr.Row("MagCache"): + xy_plot_use_magcache = gr.Checkbox(label='Use MagCache', value=False, info='Faster speed, but may introduce artifacts. Uses pre-calibrated ratios.') + xy_plot_magcache_threshold = gr.Slider(label="MagCache Threshold", minimum=0.01, maximum=1.0, step=0.01, value=0.1, visible=False, info='Error tolerance for skipping steps. 
Lower = more skips, higher = fewer skips.') + xy_plot_magcache_max_consecutive_skips = gr.Slider(label="MagCache Max Consecutive Skips", minimum=1, maximum=10, step=1, value=2, visible=False, info='Maximum number of consecutive steps that can be skipped.') + xy_plot_magcache_retention_ratio = gr.Slider(label="MagCache Retention Ratio", minimum=0.0, maximum=1.0, step=0.01, value=0.25, visible=False, info='Ratio of initial steps to always calculate (not skip).') + + # Mutual exclusivity logic for TeaCache and MagCache in XY Plot UI + xy_plot_use_teacache.change(lambda enabled: (gr.update(visible=enabled), gr.update(visible=enabled), gr.update(value=not enabled)), inputs=xy_plot_use_teacache, outputs=[xy_plot_teacache_num_steps, xy_plot_teacache_rel_l1_thresh, xy_plot_use_magcache]) + xy_plot_use_magcache.change(lambda enabled: (gr.update(visible=enabled), gr.update(visible=enabled), gr.update(visible=enabled), gr.update(value=not enabled)), inputs=xy_plot_use_magcache, outputs=[xy_plot_magcache_threshold, xy_plot_magcache_max_consecutive_skips, xy_plot_magcache_retention_ratio, xy_plot_use_teacache]) + + xy_plot_latent_window_size = gr.Slider(label="Latent Window Size", minimum=1, maximum=33, value=9, step=1, visible=True, info='Change at your own risk, very experimental') + xy_plot_cfg = gr.Slider(label="CFG Scale", minimum=1.0, maximum=32.0, value=1.0, step=0.01, visible=False) + xy_plot_gs = gr.Slider(label="Distilled CFG Scale", minimum=1.0, maximum=32.0, value=10.0, step=0.01) + xy_plot_rs = gr.Slider(label="CFG Re-Scale", minimum=0.0, maximum=1.0, value=0.0, step=0.01, visible=False) + xy_plot_gpu_memory_preservation = gr.Slider(label="GPU Inference Preserved Memory (GB) (larger means slower)", minimum=1, maximum=128, value=6, step=0.1, info="Set this number to a larger value if you encounter OOM. Larger value causes slower speed.") + with gr.Accordion("Output Parameters", open=False): + xy_plot_mp4_crf = gr.Slider(label="MP4 Compression", minimum=0, maximum=100, value=16, step=1, info="Lower means better quality. 0 is uncompressed. Change to 16 if you get black outputs. 
") + with gr.Accordion("Plot Parameters", open=True): + def xy_plot_axis_change(updated_value_type): + if xy_plot_axis_options[updated_value_type][0] == "textbox" or xy_plot_axis_options[updated_value_type][0] == "number": + return gr.update(visible=True, value=xy_plot_axis_options[updated_value_type][2]), gr.update(visible=False, value=[], choices=[]) + elif xy_plot_axis_options[updated_value_type][0] == "dropdown": + return gr.update(visible=False), gr.update(visible=True, value=xy_plot_axis_options[updated_value_type][2], choices=xy_plot_axis_options[updated_value_type][1]) + else: + return gr.update(visible=False), gr.update(visible=False, value=[], choices=[]) + with gr.Row(): + xy_plot_axis_x_switch = gr.Dropdown(label="X axis type for plotting", choices=list(xy_plot_axis_options.keys())) + xy_plot_axis_x_value_text = gr.Textbox(label="X axis comma separated text", visible=False) + xy_plot_axis_x_value_dropdown = gr.CheckboxGroup(label="X axis values", visible=False) #, multiselect=True) + with gr.Row(): + xy_plot_axis_y_switch = gr.Dropdown(label="Y axis type for plotting", choices=list(xy_plot_axis_options.keys())) + xy_plot_axis_y_value_text = gr.Textbox(label="Y axis comma separated text", visible=False) + xy_plot_axis_y_value_dropdown = gr.CheckboxGroup(label="Y axis values", visible=False) #, multiselect=True) + with gr.Row(visible=False): # not implemented Z axis + xy_plot_axis_z_switch = gr.Dropdown(label="Z axis type for plotting", choices=list(xy_plot_axis_options.keys())) + xy_plot_axis_z_value_text = gr.Textbox(label="Z axis comma separated text", visible=False) + xy_plot_axis_z_value_dropdown = gr.CheckboxGroup(label="Z axis values", visible=False) #, multiselect=True) + + xy_plot_status = gr.HTML("") + xy_plot_output = gr.Video(autoplay=True, loop=True, sources=[], height=256, visible=False) + # --- ADD THE PROCESS BUTTON HERE --- + # This button is logically part of the XY plot group but will be controlled + # from interface.py. We place it here so it's encapsulated. 
+ xy_plot_process_btn = gr.Button("Submit", visible=False) + + # --- Internal Event Handlers --- + xy_plot_use_teacache.change(lambda enabled: (gr.update(visible=enabled), gr.update(visible=enabled)), inputs=xy_plot_use_teacache, outputs=[xy_plot_teacache_num_steps, xy_plot_teacache_rel_l1_thresh]) + xy_plot_axis_x_switch.change(fn=xy_plot_axis_change, inputs=[xy_plot_axis_x_switch], outputs=[xy_plot_axis_x_value_text, xy_plot_axis_x_value_dropdown]) + xy_plot_axis_y_switch.change(fn=xy_plot_axis_change, inputs=[xy_plot_axis_y_switch], outputs=[xy_plot_axis_y_value_text, xy_plot_axis_y_value_dropdown]) + xy_plot_axis_z_switch.change(fn=xy_plot_axis_change, inputs=[xy_plot_axis_z_switch], outputs=[xy_plot_axis_z_value_text, xy_plot_axis_z_value_dropdown]) + + def xy_plot_update_lora_sliders(selected_loras): + updates = [] + actual_selected_loras_for_display = [lora for lora in selected_loras if lora != DUMMY_LORA_NAME] + updates.append(gr.update(value=actual_selected_loras_for_display)) + + for lora_name_key in lora_names: + if lora_name_key == DUMMY_LORA_NAME: + updates.append(gr.update(visible=False)) + else: + updates.append(gr.update(visible=(lora_name_key in actual_selected_loras_for_display))) + return updates + + xy_plot_lora_selector.change( + fn=xy_plot_update_lora_sliders, + inputs=[xy_plot_lora_selector], + outputs=[xy_plot_lora_selector] + [xy_plot_lora_sliders[lora] for lora in lora_names if lora in xy_plot_lora_sliders] + ) + + # --- Component Dictionary for Export --- + components = { + "group": xy_group, + "status": xy_plot_status, + "output": xy_plot_output, + "process_btn": xy_plot_process_btn, + # --- Inputs for the process button --- + "model_type": xy_plot_model_type, + "input_image": xy_plot_input_image, + "end_frame_image_original": xy_plot_end_frame_image_original, + "end_frame_strength_original": xy_plot_end_frame_strength_original, + "latent_type": xy_plot_latent_type, + "prompt": xy_plot_prompt, + "blend_sections": xy_plot_blend_sections, + "steps": xy_plot_steps, + "total_second_length": xy_plot_total_second_length, + "seed": xy_plot_seed, + "randomize_seed": xy_plot_randomize_seed, + "use_teacache": xy_plot_use_teacache, + "teacache_num_steps": xy_plot_teacache_num_steps, + "teacache_rel_l1_thresh": xy_plot_teacache_rel_l1_thresh, + "use_magcache": xy_plot_use_magcache, + "magcache_threshold": xy_plot_magcache_threshold, + "magcache_max_consecutive_skips": xy_plot_magcache_max_consecutive_skips, + "magcache_retention_ratio": xy_plot_magcache_retention_ratio, + "latent_window_size": xy_plot_latent_window_size, + "cfg": xy_plot_cfg, + "gs": xy_plot_gs, + "rs": xy_plot_rs, + "gpu_memory_preservation": xy_plot_gpu_memory_preservation, + "mp4_crf": xy_plot_mp4_crf, + "axis_x_switch": xy_plot_axis_x_switch, + "axis_x_value_text": xy_plot_axis_x_value_text, + "axis_x_value_dropdown": xy_plot_axis_x_value_dropdown, + "axis_y_switch": xy_plot_axis_y_switch, + "axis_y_value_text": xy_plot_axis_y_value_text, + "axis_y_value_dropdown": xy_plot_axis_y_value_dropdown, + "axis_z_switch": xy_plot_axis_z_switch, + "axis_z_value_text": xy_plot_axis_z_value_text, + "axis_z_value_dropdown": xy_plot_axis_z_value_dropdown, + "lora_selector": xy_plot_lora_selector, + "lora_sliders": xy_plot_lora_sliders, + } + return components \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..4a74de834f72ac76caee564ee5bae0af974eb005 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,31 @@ +accelerate==1.6.0 
+av==12.1.0 +decord +diffusers==0.33.1 +einops +ffmpeg-python==0.2.0 +gradio==5.25.2 +imageio-ffmpeg==0.4.8 +imageio==2.31.1 +jinja2>=3.1.2 +numpy==1.26.2 +opencv-contrib-python +peft +pillow==11.1.0 +requests==2.31.0 +safetensors +scipy==1.12.0 +sentencepiece==0.2.0 +torchsde==0.2.6 +tqdm +timm +transformers==4.46.2 + +# for toolbox +basicsr +# basicsr-fixed +devicetorch +facexlib>=0.2.5 +gfpgan>=1.3.5 +psutil +realesrgan \ No newline at end of file diff --git a/run.bat b/run.bat new file mode 100644 index 0000000000000000000000000000000000000000..e4c2b5e8fee4717ad2c9309d15498740361e848e --- /dev/null +++ b/run.bat @@ -0,0 +1,22 @@ +@echo off +echo Starting FramePack-Studio... + +REM Check if Python is installed (basic check) +where python >nul 2>&1 +if %errorlevel% neq 0 ( + echo Error: Python is not installed or not in your PATH. Cannot run studio.py. + goto end +) + +if exist "%cd%/venv/Scripts/python.exe" ( + +"%cd%/venv/Scripts/python.exe" studio.py + +) else ( + +echo Error: Virtual Environment for Python not found. Did you install correctly? +goto end + +) + +:end \ No newline at end of file diff --git a/run.sh b/run.sh new file mode 100644 index 0000000000000000000000000000000000000000..25d0c870ec0e05db7053427a62df97e90c623b80 --- /dev/null +++ b/run.sh @@ -0,0 +1,15 @@ +echo Starting FramePack-Studio... + +if [ -z "$(command -v python)" ]; then + echo "Did not find a Python binary. Exiting." + exit 1 +fi + +if [ ! -f "./venv/bin/activate" ]; then + echo "Did not find a Python virtual environment. Exiting." + exit 1 +fi + +source venv/bin/activate + +python studio.py "$@" \ No newline at end of file diff --git a/studio.py b/studio.py new file mode 100644 index 0000000000000000000000000000000000000000..0289afff74808f6f69c67f3cadef08f0926ca400 --- /dev/null +++ b/studio.py @@ -0,0 +1,705 @@ +from diffusers_helper.hf_login import login + +import json +import os +import shutil +from pathlib import PurePath, Path +import time +import argparse +import traceback +import einops +import numpy as np +import torch +import datetime + +# Version information +from modules.version import APP_VERSION + +# Set environment variables +os.environ['HF_HOME'] = os.path.abspath(os.path.realpath(os.path.join(os.path.dirname(__file__), './hf_download'))) +os.environ['TOKENIZERS_PARALLELISM'] = 'false' # Prevent tokenizers parallelism warning + + + +import gradio as gr +from PIL import Image +from PIL.PngImagePlugin import PngInfo +from diffusers import AutoencoderKLHunyuanVideo +from transformers import LlamaModel, CLIPTextModel, LlamaTokenizerFast, CLIPTokenizer +from diffusers_helper.hunyuan import encode_prompt_conds, vae_decode, vae_encode, vae_decode_fake +from diffusers_helper.utils import save_bcthw_as_mp4, crop_or_pad_yield_mask, soft_append_bcthw, resize_and_center_crop, generate_timestamp +from diffusers_helper.models.hunyuan_video_packed import HunyuanVideoTransformer3DModelPacked +from diffusers_helper.pipelines.k_diffusion_hunyuan import sample_hunyuan +from diffusers_helper.memory import cpu, gpu, get_cuda_free_memory_gb, move_model_to_device_with_memory_preservation, offload_model_from_device_for_memory_preservation, fake_diffusers_current_device, DynamicSwapInstaller, unload_complete_models, load_model_as_complete +from diffusers_helper.thread_utils import AsyncStream +from diffusers_helper.gradio.progress_bar import make_progress_bar_html +from transformers import SiglipImageProcessor, SiglipVisionModel +from diffusers_helper.clip_vision import hf_clip_vision_encode +from 
diffusers_helper.bucket_tools import find_nearest_bucket +from diffusers_helper import lora_utils +from diffusers_helper.lora_utils import load_lora, unload_all_loras + +# Import model generators +from modules.generators import create_model_generator + +# Global cache for prompt embeddings +prompt_embedding_cache = {} +# Import from modules +from modules.video_queue import VideoJobQueue, JobStatus +from modules.prompt_handler import parse_timestamped_prompt +from modules.interface import create_interface, format_queue_status +from modules.settings import Settings +from modules import DUMMY_LORA_NAME # Import the constant +from modules.pipelines.metadata_utils import create_metadata +from modules.pipelines.worker import worker + +# Try to suppress annoyingly persistent Windows asyncio proactor errors +if os.name == 'nt': # Windows only + import asyncio + from functools import wraps + + # Replace the problematic proactor event loop with selector event loop + asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) + + # Patch the base transport's close method + def silence_event_loop_closed(func): + @wraps(func) + def wrapper(self, *args, **kwargs): + try: + return func(self, *args, **kwargs) + except RuntimeError as e: + if str(e) != 'Event loop is closed': + raise + return wrapper + + # Apply the patch + if hasattr(asyncio.proactor_events._ProactorBasePipeTransport, '_call_connection_lost'): + asyncio.proactor_events._ProactorBasePipeTransport._call_connection_lost = silence_event_loop_closed( + asyncio.proactor_events._ProactorBasePipeTransport._call_connection_lost) + +# ADDED: Debug function to verify LoRA state +def verify_lora_state(transformer, label=""): + """Debug function to verify the state of LoRAs in a transformer model""" + if transformer is None: + print(f"[{label}] Transformer is None, cannot verify LoRA state") + return + + has_loras = False + if hasattr(transformer, 'peft_config'): + adapter_names = list(transformer.peft_config.keys()) if transformer.peft_config else [] + if adapter_names: + has_loras = True + print(f"[{label}] Transformer has LoRAs: {', '.join(adapter_names)}") + else: + print(f"[{label}] Transformer has no LoRAs in peft_config") + else: + print(f"[{label}] Transformer has no peft_config attribute") + + # Check for any LoRA modules + for name, module in transformer.named_modules(): + if hasattr(module, 'lora_A') and module.lora_A: + has_loras = True + # print(f"[{label}] Found lora_A in module {name}") + if hasattr(module, 'lora_B') and module.lora_B: + has_loras = True + # print(f"[{label}] Found lora_B in module {name}") + + if not has_loras: + print(f"[{label}] No LoRA components found in transformer") + + +parser = argparse.ArgumentParser() +parser.add_argument('--share', action='store_true') +parser.add_argument("--server", type=str, default='0.0.0.0') +parser.add_argument("--port", type=int, required=False) +parser.add_argument("--inbrowser", action='store_true') +parser.add_argument("--lora", type=str, default=None, help="Lora path (comma separated for multiple)") +parser.add_argument("--offline", action='store_true', help="Run in offline mode") +args = parser.parse_args() + +print(args) + +if args.offline: + print("Offline mode enabled.") + os.environ['HF_HUB_OFFLINE'] = '1' +else: + if 'HF_HUB_OFFLINE' in os.environ: + del os.environ['HF_HUB_OFFLINE'] + +free_mem_gb = get_cuda_free_memory_gb(gpu) +high_vram = free_mem_gb > 60 + +print(f'Free VRAM {free_mem_gb} GB') +print(f'High-VRAM Mode: {high_vram}') + +# Load models 
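+# Note: the text encoders, VAE and image encoder below are instantiated on the CPU in float16.
+# Further down they are either moved fully onto the GPU (high-VRAM mode) or left on the CPU,
+# with DynamicSwapInstaller installed on the text encoder (low-VRAM mode).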
+text_encoder = LlamaModel.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='text_encoder', torch_dtype=torch.float16).cpu() +text_encoder_2 = CLIPTextModel.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='text_encoder_2', torch_dtype=torch.float16).cpu() +tokenizer = LlamaTokenizerFast.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='tokenizer') +tokenizer_2 = CLIPTokenizer.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='tokenizer_2') +vae = AutoencoderKLHunyuanVideo.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='vae', torch_dtype=torch.float16).cpu() + +feature_extractor = SiglipImageProcessor.from_pretrained("lllyasviel/flux_redux_bfl", subfolder='feature_extractor') +image_encoder = SiglipVisionModel.from_pretrained("lllyasviel/flux_redux_bfl", subfolder='image_encoder', torch_dtype=torch.float16).cpu() + +# Initialize model generator placeholder +current_generator = None # Will hold the currently active model generator + +# Load models based on VRAM availability later + +# Configure models +vae.eval() +text_encoder.eval() +text_encoder_2.eval() +image_encoder.eval() + +if not high_vram: + vae.enable_slicing() + vae.enable_tiling() + + +vae.to(dtype=torch.float16) +image_encoder.to(dtype=torch.float16) +text_encoder.to(dtype=torch.float16) +text_encoder_2.to(dtype=torch.float16) + +vae.requires_grad_(False) +text_encoder.requires_grad_(False) +text_encoder_2.requires_grad_(False) +image_encoder.requires_grad_(False) + +# Create lora directory if it doesn't exist +lora_dir = os.path.join(os.path.dirname(__file__), 'loras') +os.makedirs(lora_dir, exist_ok=True) + +# Initialize LoRA support - moved scanning after settings load +lora_names = [] +lora_values = [] # This seems unused for population, might be related to weights later + +script_dir = os.path.dirname(os.path.abspath(__file__)) + +# Define default LoRA folder path relative to the script directory (used if setting is missing) +default_lora_folder = os.path.join(script_dir, "loras") +os.makedirs(default_lora_folder, exist_ok=True) # Ensure default exists + +if not high_vram: + # DynamicSwapInstaller is same as huggingface's enable_sequential_offload but 3x faster + DynamicSwapInstaller.install_model(text_encoder, device=gpu) +else: + text_encoder.to(gpu) + text_encoder_2.to(gpu) + image_encoder.to(gpu) + vae.to(gpu) + +stream = AsyncStream() + +outputs_folder = './outputs/' +os.makedirs(outputs_folder, exist_ok=True) + +# Initialize settings +settings = Settings() + +# NEW: auto-cleanup on start-up option in Settings +if settings.get("auto_cleanup_on_startup", False): + print("--- Running Automatic Startup Cleanup ---") + + # Import the processor instance + from modules.toolbox_app import tb_processor + + # Call the single cleanup function and print its summary. 
+ cleanup_summary = tb_processor.tb_clear_temporary_files() + print(f"{cleanup_summary}") # This cleaner print handles the multiline string well + + print("--- Startup Cleanup Complete ---") + +# --- Populate LoRA names AFTER settings are loaded --- +lora_folder_from_settings: str = settings.get("lora_dir", default_lora_folder) # Use setting, fallback to default +print(f"Scanning for LoRAs in: {lora_folder_from_settings}") +if os.path.isdir(lora_folder_from_settings): + try: + for root, _, files in os.walk(lora_folder_from_settings): + for file in files: + if file.endswith('.safetensors') or file.endswith('.pt'): + lora_relative_path = os.path.relpath(os.path.join(root, file), lora_folder_from_settings) + lora_name = str(PurePath(lora_relative_path).with_suffix('')) + lora_names.append(lora_name) + print(f"Found LoRAs: {lora_names}") + # Temp solution for only 1 lora + if len(lora_names) == 1: + lora_names.append(DUMMY_LORA_NAME) + except Exception as e: + print(f"Error scanning LoRA directory '{lora_folder_from_settings}': {e}") +else: + print(f"LoRA directory not found: {lora_folder_from_settings}") +# --- End LoRA population --- + + +# Create job queue +job_queue = VideoJobQueue() + + + +# Function to load a LoRA file +def load_lora_file(lora_file: str | PurePath): + if not lora_file: + return None, "No file selected" + + try: + # Get the filename from the path + lora_path = PurePath(lora_file) + lora_name = lora_path.name + + # Copy the file to the lora directory + lora_dest = PurePath(lora_dir, lora_path) + import shutil + shutil.copy(lora_file, lora_dest) + + # Load the LoRA + global current_generator, lora_names + if current_generator is None: + return None, "Error: No model loaded to apply LoRA to. Generate something first." + + # Unload any existing LoRAs first + current_generator.unload_loras() + + # Load the single LoRA + selected_loras = [lora_path.stem] + current_generator.load_loras(selected_loras, lora_dir, selected_loras) + + # Add to lora_names if not already there + lora_base_name = lora_path.stem + if lora_base_name not in lora_names: + lora_names.append(lora_base_name) + + # Get the current device of the transformer + device = next(current_generator.transformer.parameters()).device + + # Move all LoRA adapters to the same device as the base model + current_generator.move_lora_adapters_to_device(device) + + print(f"Loaded LoRA: {lora_name} to {current_generator.get_model_name()} model") + + return gr.update(choices=lora_names), f"Successfully loaded LoRA: {lora_name}" + except Exception as e: + print(f"Error loading LoRA: {e}") + return None, f"Error loading LoRA: {e}" + +@torch.no_grad() +def get_cached_or_encode_prompt(prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2, target_device): + """ + Retrieves prompt embeddings from cache or encodes them if not found. + Stores encoded embeddings (on CPU) in the cache. + Returns embeddings moved to the target_device. 
+ """ + if prompt in prompt_embedding_cache: + print(f"Cache hit for prompt: {prompt[:60]}...") + llama_vec_cpu, llama_mask_cpu, clip_l_pooler_cpu = prompt_embedding_cache[prompt] + # Move cached embeddings (from CPU) to the target device + llama_vec = llama_vec_cpu.to(target_device) + llama_attention_mask = llama_mask_cpu.to(target_device) if llama_mask_cpu is not None else None + clip_l_pooler = clip_l_pooler_cpu.to(target_device) + return llama_vec, llama_attention_mask, clip_l_pooler + else: + print(f"Cache miss for prompt: {prompt[:60]}...") + llama_vec, clip_l_pooler = encode_prompt_conds( + prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2 + ) + llama_vec, llama_attention_mask = crop_or_pad_yield_mask(llama_vec, length=512) + # Store CPU copies in cache + prompt_embedding_cache[prompt] = (llama_vec.cpu(), llama_attention_mask.cpu() if llama_attention_mask is not None else None, clip_l_pooler.cpu()) + # Return embeddings already on the target device (as encode_prompt_conds uses the model's device) + return llama_vec, llama_attention_mask, clip_l_pooler + +# Set the worker function for the job queue - using the imported worker from modules/pipelines/worker.py +job_queue.set_worker_function(worker) + + +def process( + model_type, + input_image, + end_frame_image, # NEW + end_frame_strength, # NEW + prompt_text, + n_prompt, + seed, + total_second_length, + latent_window_size, + steps, + cfg, + gs, + rs, + use_teacache, + teacache_num_steps, + teacache_rel_l1_thresh, + use_magcache, + magcache_threshold, + magcache_max_consecutive_skips, + magcache_retention_ratio, + blend_sections, + latent_type, + clean_up_videos, + selected_loras, + resolutionW, + resolutionH, + input_image_path, + combine_with_source, + num_cleaned_frames, + *lora_args, + save_metadata_checked=True, # NEW: Parameter to control metadata saving + ): + + # Create a blank black image if no + # Create a default image based on the selected latent_type + has_input_image = True + if input_image is None: + has_input_image = False + default_height, default_width = resolutionH, resolutionW + if latent_type == "White": + # Create a white image + input_image = np.ones((default_height, default_width, 3), dtype=np.uint8) * 255 + print("No input image provided. Using a blank white image.") + + elif latent_type == "Noise": + # Create a noise image + input_image = np.random.randint(0, 256, (default_height, default_width, 3), dtype=np.uint8) + print("No input image provided. Using a random noise image.") + + elif latent_type == "Green Screen": + # Create a green screen image with standard chroma key green (0, 177, 64) + input_image = np.zeros((default_height, default_width, 3), dtype=np.uint8) + input_image[:, :, 1] = 177 # Green channel + input_image[:, :, 2] = 64 # Blue channel + # Red channel remains 0 + print("No input image provided. Using a standard chroma key green screen.") + + else: # Default to "Black" or any other value + # Create a black image + input_image = np.zeros((default_height, default_width, 3), dtype=np.uint8) + print(f"No input image provided. 
Using a blank black image (latent_type: {latent_type}).") + + + # Handle input files - copy to input_files_dir to prevent them from being deleted by temp cleanup + input_files_dir = settings.get("input_files_dir") + os.makedirs(input_files_dir, exist_ok=True) + + # Process input image (if it's a file path) + input_image_path = None + if isinstance(input_image, str) and os.path.exists(input_image): + # It's a file path, copy it to input_files_dir + filename = os.path.basename(input_image) + input_image_path = os.path.join(input_files_dir, f"{generate_timestamp()}_{filename}") + try: + shutil.copy2(input_image, input_image_path) + print(f"Copied input image to {input_image_path}") + # For Video model, we'll use the path + if model_type == "Video": + input_image = input_image_path + except Exception as e: + print(f"Error copying input image: {e}") + + # Process end frame image (if it's a file path) + end_frame_image_path = None + if isinstance(end_frame_image, str) and os.path.exists(end_frame_image): + # It's a file path, copy it to input_files_dir + filename = os.path.basename(end_frame_image) + end_frame_image_path = os.path.join(input_files_dir, f"{generate_timestamp()}_{filename}") + try: + shutil.copy2(end_frame_image, end_frame_image_path) + print(f"Copied end frame image to {end_frame_image_path}") + except Exception as e: + print(f"Error copying end frame image: {e}") + + # Extract lora_loaded_names from lora_args + lora_loaded_names = lora_args[0] if lora_args and len(lora_args) > 0 else [] + lora_values = lora_args[1:] if lora_args and len(lora_args) > 1 else [] + + # Create job parameters + job_params = { + 'model_type': model_type, + 'input_image': input_image.copy() if hasattr(input_image, 'copy') else input_image, # Handle both image arrays and video paths + 'end_frame_image': end_frame_image.copy() if end_frame_image is not None else None, + 'end_frame_strength': end_frame_strength, + 'prompt_text': prompt_text, + 'n_prompt': n_prompt, + 'seed': seed, + 'total_second_length': total_second_length, + 'latent_window_size': latent_window_size, + 'latent_type': latent_type, + 'steps': steps, + 'cfg': cfg, + 'gs': gs, + 'rs': rs, + 'blend_sections': blend_sections, + 'use_teacache': use_teacache, + 'teacache_num_steps': teacache_num_steps, + 'teacache_rel_l1_thresh': teacache_rel_l1_thresh, + 'use_magcache': use_magcache, + 'magcache_threshold': magcache_threshold, + 'magcache_max_consecutive_skips': magcache_max_consecutive_skips, + 'magcache_retention_ratio': magcache_retention_ratio, + 'selected_loras': selected_loras, + 'has_input_image': has_input_image, + 'output_dir': settings.get("output_dir"), + 'metadata_dir': settings.get("metadata_dir"), + 'input_files_dir': input_files_dir, # Add input_files_dir to job parameters + 'input_image_path': input_image_path, # Add the path to the copied input image + 'end_frame_image_path': end_frame_image_path, # Add the path to the copied end frame image + 'resolutionW': resolutionW, # Add resolution parameter + 'resolutionH': resolutionH, + 'lora_loaded_names': lora_loaded_names, + 'combine_with_source': combine_with_source, # Add combine_with_source parameter + 'num_cleaned_frames': num_cleaned_frames, + 'save_metadata_checked': save_metadata_checked, # NEW: Add save_metadata_checked parameter + } + + # Print teacache parameters for debugging + print(f"Teacache parameters: use_teacache={use_teacache}, teacache_num_steps={teacache_num_steps}, teacache_rel_l1_thresh={teacache_rel_l1_thresh}") + + # Add LoRA values if provided - extract them 
from the tuple + if lora_values: + # Convert tuple to list + lora_values_list = list(lora_values) + job_params['lora_values'] = lora_values_list + + # Add job to queue + job_id = job_queue.add_job(job_params) + + # Set the generation_type attribute on the job object directly + job = job_queue.get_job(job_id) + if job: + job.generation_type = model_type # Set generation_type to model_type for display in queue + print(f"Added job {job_id} to queue") + + queue_status = update_queue_status() + # Return immediately after adding to queue + # Return separate updates for start_button and end_button to prevent cross-contamination + return None, job_id, None, '', f'Job added to queue. Job ID: {job_id}', gr.update(value="🚀 Add to Queue", interactive=True), gr.update(value="❌ Cancel Current Job", interactive=True) + + + +def end_process(): + """Cancel the current running job and update the queue status""" + print("Cancelling current job") + with job_queue.lock: + if job_queue.current_job: + job_id = job_queue.current_job.id + print(f"Cancelling job {job_id}") + + # Send the end signal to the job's stream + if job_queue.current_job.stream: + job_queue.current_job.stream.input_queue.push('end') + + # Mark the job as cancelled + job_queue.current_job.status = JobStatus.CANCELLED + job_queue.current_job.completed_at = time.time() # Set completion time + + # Force an update to the queue status + return update_queue_status() + + +def update_queue_status(): + """Update queue status and refresh job positions""" + jobs = job_queue.get_all_jobs() + for job in jobs: + if job.status == JobStatus.PENDING: + job.queue_position = job_queue.get_queue_position(job.id) + + # Make sure to update current running job info + if job_queue.current_job: + # Make sure the running job is showing status = RUNNING + job_queue.current_job.status = JobStatus.RUNNING + + # Update the toolbar stats + pending_count = 0 + running_count = 0 + completed_count = 0 + + for job in jobs: + if hasattr(job, 'status'): + status = str(job.status) + if status == "JobStatus.PENDING": + pending_count += 1 + elif status == "JobStatus.RUNNING": + running_count += 1 + elif status == "JobStatus.COMPLETED": + completed_count += 1 + + return format_queue_status(jobs) + + +def monitor_job(job_id=None): + """ + Monitor a specific job and update the UI with the latest video segment as soon as it's available. + If no job_id is provided, check if there's a current job in the queue. + ALWAYS shows the current running job, regardless of the job_id provided. 
+ """ + last_video = None # Track the last video file shown + last_job_status = None # Track the previous job status to detect status changes + last_progress_update_time = time.time() # Track when we last updated the progress + last_preview = None # Track the last preview image shown + force_update = True # Force an update on first iteration + + # Flag to indicate we're waiting for a job transition + waiting_for_transition = False + transition_start_time = None + max_transition_wait = 5.0 # Maximum time to wait for transition in seconds + + def get_preview_updates(preview_value): + """Create preview updates that respect the latents_display_top setting""" + display_top = settings.get("latents_display_top", False) + if display_top: + # Top display enabled: update top preview with value, don't update right preview + return gr.update(), preview_value if preview_value is not None else gr.update() + else: + # Right column display: update right preview with value, don't update top preview + return preview_value if preview_value is not None else gr.update(), gr.update() + + while True: + # ALWAYS check if there's a current running job that's different from our tracked job_id + with job_queue.lock: + current_job = job_queue.current_job + if current_job and current_job.id != job_id and current_job.status == JobStatus.RUNNING: + # Always switch to the current running job + job_id = current_job.id + waiting_for_transition = False + force_update = True + # Yield a temporary update to show we're switching jobs + right_preview, top_preview = get_preview_updates(None) + yield last_video, right_preview, top_preview, '', 'Switching to current job...', gr.update(interactive=True), gr.update(value="❌ Cancel Current Job", visible=True) + continue + + # Check if we're waiting for a job transition + if waiting_for_transition: + current_time = time.time() + # If we've been waiting too long, stop waiting + if current_time - transition_start_time > max_transition_wait: + waiting_for_transition = False + + # Check one more time for a current job + with job_queue.lock: + current_job = job_queue.current_job + if current_job and current_job.status == JobStatus.RUNNING: + # Switch to whatever job is currently running + job_id = current_job.id + force_update = True + right_preview, top_preview = get_preview_updates(None) + yield last_video, right_preview, top_preview, '', 'Switching to current job...', gr.update(interactive=True), gr.update(value="❌ Cancel Current Job", visible=True) + continue + else: + # If still waiting, sleep briefly and continue + time.sleep(0.1) + continue + + job = job_queue.get_job(job_id) + if not job: + # Correctly yield 7 items for the startup/no-job case + # This ensures the status text goes to the right component and the buttons are set correctly. 
+ yield None, None, None, 'No job ID provided', '', gr.update(value="🚀 Add to Queue", interactive=True, visible=True), gr.update(interactive=False, visible=False) + return + + # If a new video file is available, yield it immediately + if job.result and job.result != last_video: + last_video = job.result + # You can also update preview/progress here if desired + right_preview, top_preview = get_preview_updates(None) + yield last_video, right_preview, top_preview, '', '', gr.update(interactive=True), gr.update(interactive=True) + + # Handle job status and progress + if job.status == JobStatus.PENDING: + position = job_queue.get_queue_position(job_id) + right_preview, top_preview = get_preview_updates(None) + yield last_video, right_preview, top_preview, '', f'Waiting in queue. Position: {position}', gr.update(interactive=True), gr.update(interactive=True) + + elif job.status == JobStatus.RUNNING: + # Only reset the cancel button when a job transitions from another state to RUNNING + # This ensures we don't reset the button text during cancellation + if last_job_status != JobStatus.RUNNING: + # Check if the button text is already "Cancelling..." - if so, don't change it + # This prevents the button from changing back to "Cancel Current Job" during cancellation + button_update = gr.update(interactive=True, value="❌ Cancel Current Job", visible=True) + else: + # Keep current text and state - important to not override "Cancelling..." text + button_update = gr.update(interactive=True, visible=True) + + # Check if we have progress data and if it's time to update + current_time = time.time() + update_needed = force_update or (current_time - last_progress_update_time > 0.05) # More frequent updates + + # Always check for progress data, even if we don't have a preview yet + if job.progress_data and update_needed: + preview = job.progress_data.get('preview') + desc = job.progress_data.get('desc', '') + html = job.progress_data.get('html', '') + + # Only update the preview if it has changed or we're forcing an update + # Ensure all components get an update + current_preview_value = job.progress_data.get('preview') if job.progress_data else None + current_desc_value = job.progress_data.get('desc', 'Processing...') if job.progress_data else 'Processing...' + current_html_value = job.progress_data.get('html', make_progress_bar_html(0, 'Processing...')) if job.progress_data else make_progress_bar_html(0, 'Processing...') + + if current_preview_value is not None and (current_preview_value is not last_preview or force_update): + last_preview = current_preview_value + # Always update if force_update is true, or if it's time for a periodic update + if force_update or update_needed: + last_progress_update_time = current_time + force_update = False + right_preview, top_preview = get_preview_updates(last_preview) + yield job.result, right_preview, top_preview, current_desc_value, current_html_value, gr.update(interactive=True), button_update + + # Fallback for periodic update if no new progress data but job is still running + elif current_time - last_progress_update_time > 0.5: # More frequent fallback update + last_progress_update_time = current_time + force_update = False # Reset force_update after a yield + current_desc_value = job.progress_data.get('desc', 'Processing...') if job.progress_data else 'Processing...' 
+ current_html_value = job.progress_data.get('html', make_progress_bar_html(0, 'Processing...')) if job.progress_data else make_progress_bar_html(0, 'Processing...') + right_preview, top_preview = get_preview_updates(last_preview) + yield job.result, right_preview, top_preview, current_desc_value, current_html_value, gr.update(interactive=True), button_update + + elif job.status == JobStatus.COMPLETED: + # Show the final video and reset the button text + right_preview, top_preview = get_preview_updates(last_preview) + yield job.result, right_preview, top_preview, 'Completed', make_progress_bar_html(100, 'Completed'), gr.update(value="🚀 Add to Queue"), gr.update(interactive=True, value="❌ Cancel Current Job", visible=False) + break + + elif job.status == JobStatus.FAILED: + # Show error and reset the button text + right_preview, top_preview = get_preview_updates(last_preview) + yield job.result, right_preview, top_preview, f'Error: {job.error}', make_progress_bar_html(0, 'Failed'), gr.update(value="🚀 Add to Queue"), gr.update(interactive=True, value="❌ Cancel Current Job", visible=False) + break + + elif job.status == JobStatus.CANCELLED: + # Show cancelled message and reset the button text + right_preview, top_preview = get_preview_updates(last_preview) + yield job.result, right_preview, top_preview, 'Job cancelled', make_progress_bar_html(0, 'Cancelled'), gr.update(interactive=True), gr.update(interactive=True, value="❌ Cancel Current Job", visible=False) + break + + # Update last_job_status for the next iteration + last_job_status = job.status + + # Wait a bit before checking again + time.sleep(0.05) # Reduced wait time for more responsive updates + + +# Set Gradio temporary directory from settings +os.environ["GRADIO_TEMP_DIR"] = settings.get("gradio_temp_dir") + +# Create the interface +interface = create_interface( + process_fn=process, + monitor_fn=monitor_job, + end_process_fn=end_process, + update_queue_status_fn=update_queue_status, + load_lora_file_fn=load_lora_file, + job_queue=job_queue, + settings=settings, + lora_names=lora_names # Explicitly pass the found LoRA names +) + +# Launch the interface +interface.launch( + server_name=args.server, + server_port=args.port, + share=args.share, + inbrowser=args.inbrowser, + allowed_paths=[settings.get("output_dir"), settings.get("metadata_dir")], +) diff --git a/update.bat b/update.bat new file mode 100644 index 0000000000000000000000000000000000000000..4c37cbf799e611bb70b1dcca9c23664da3ff965a --- /dev/null +++ b/update.bat @@ -0,0 +1,65 @@ +@echo off +echo FramePack-Studio Update Script + +REM Check if Git is installed (basic check) +where git >nul 2>&1 +if %errorlevel% neq 0 ( + echo Error: Git is not installed or not in your PATH. Unable to update. + goto end +) + +REM Check if Python is installed (basic check) +where python >nul 2>&1 +if %errorlevel% neq 0 ( + echo Error: Python is not installed or not in your PATH. Unable to update dependencies. + REM Continue with Git pull, but warn about dependencies + echo Warning: Python is not available, skipping dependency update. + goto git_pull +) + + +:git_pull +echo Pulling latest changes from Git... +git pull + +REM Check if git pull was successful +if %errorlevel% neq 0 ( + echo Error: Failed to pull latest changes from Git. Please resolve any conflicts manually. + goto end +) + +echo Git pull successful. 
+
+REM Attempt to update dependencies if the Python virtual environment is available
+if not exist "%cd%\venv\Scripts\python.exe" (
+    echo Error: Virtual Environment for Python not found. Did you install correctly?
+    goto end
+)
+
+echo Updating dependencies using pip...
+REM This assumes there's a requirements.txt file in the root
+REM Using --upgrade to update existing packages
+"%cd%\venv\Scripts\python.exe" -m pip install --upgrade -r requirements.txt
+
+REM Check if the pip update was successful. "if errorlevel 1" is used instead of
+REM %errorlevel%, which would be expanded before a parenthesized block executes.
+if errorlevel 1 (
+    echo Warning: Failed to update dependencies. You may need to update them manually.
+) else (
+    echo Dependency update successful.
+)
+
+echo Update complete.
+
+:end
+echo Exiting update script.
+pause
\ No newline at end of file
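
For readers tracing update_queue_status, the status matching relies on the string form of an enum member. Below is a minimal sketch of that assumption; the JobStatus members shown here are hypothetical, since the real enum lives elsewhere in the repository.

import enum

class JobStatus(enum.Enum):
    # Hypothetical members; the actual enum in the job queue module may differ.
    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"

# str() of a plain Enum member is "ClassName.MEMBER", which is the form the
# toolbar-stat counters in update_queue_status compare against.
assert str(JobStatus.PENDING) == "JobStatus.PENDING"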
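
monitor_job is written as a generator so that every yield streams a fresh tuple of component updates to the UI while a job runs. The following self-contained sketch illustrates that streaming pattern only; the component names and wiring here are placeholders and are not taken from create_interface.

import time
import gradio as gr

def count_up(n):
    # Generator handlers stream successive values to their bound outputs,
    # the same mechanism monitor_job uses for live previews and progress text.
    for i in range(int(n)):
        time.sleep(0.5)
        yield f"step {i + 1} of {int(n)}"

with gr.Blocks() as demo:
    steps = gr.Number(value=5, label="Steps")
    status = gr.Textbox(label="Status")
    gr.Button("Run").click(count_up, inputs=steps, outputs=status)

demo.launch()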
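
end_process cancels a running job by pushing an 'end' token onto the job's stream input queue rather than killing the worker. A minimal sketch of that cooperative-cancellation pattern follows; the InputQueue class is a stand-in with the push() spelling used by end_process, not the project's actual stream implementation.

import queue
import time

class InputQueue:
    """Stand-in for stream.input_queue; only illustrates the polling contract."""
    def __init__(self):
        self._q = queue.Queue()

    def push(self, item):
        self._q.put(item)

    def pop_nowait(self):
        try:
            return self._q.get_nowait()
        except queue.Empty:
            return None

def worker_loop(input_queue, total_steps=100):
    for step in range(total_steps):
        # Poll for the cancellation token between steps.
        if input_queue.pop_nowait() == 'end':
            print(f"Cancelled at step {step}")
            return
        time.sleep(0.05)  # stand-in for one generation step
    print("Finished all steps")

q = InputQueue()
q.push('end')    # what end_process does from the UI thread
worker_loop(q)   # the worker notices the token on its next poll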