rahul7star committed
Commit 021f101 · verified · 1 parent: 05fcd0f

Create app.py

Files changed (1)
  1. app.py +733 -0
app.py ADDED
@@ -0,0 +1,733 @@
from diffusers_helper.hf_login import login

import json
import os
import shutil
from pathlib import PurePath, Path
import time
import argparse
import traceback
import einops
import numpy as np
import torch
import datetime
import spaces
# Version information
from modules.version import APP_VERSION

# Set environment variables
os.environ['HF_HOME'] = os.path.abspath(os.path.realpath(os.path.join(os.path.dirname(__file__), './hf_download')))
os.environ['TOKENIZERS_PARALLELISM'] = 'false'  # Prevent tokenizers parallelism warning


import gradio as gr
from PIL import Image
from PIL.PngImagePlugin import PngInfo
from diffusers import AutoencoderKLHunyuanVideo
from transformers import LlamaModel, CLIPTextModel, LlamaTokenizerFast, CLIPTokenizer
from diffusers_helper.hunyuan import encode_prompt_conds, vae_decode, vae_encode, vae_decode_fake
from diffusers_helper.utils import save_bcthw_as_mp4, crop_or_pad_yield_mask, soft_append_bcthw, resize_and_center_crop, generate_timestamp
from diffusers_helper.models.hunyuan_video_packed import HunyuanVideoTransformer3DModelPacked
from diffusers_helper.pipelines.k_diffusion_hunyuan import sample_hunyuan
from diffusers_helper.memory import cpu, gpu, get_cuda_free_memory_gb, move_model_to_device_with_memory_preservation, offload_model_from_device_for_memory_preservation, fake_diffusers_current_device, DynamicSwapInstaller, unload_complete_models, load_model_as_complete
from diffusers_helper.thread_utils import AsyncStream
from diffusers_helper.gradio.progress_bar import make_progress_bar_html
from transformers import SiglipImageProcessor, SiglipVisionModel
from diffusers_helper.clip_vision import hf_clip_vision_encode
from diffusers_helper.bucket_tools import find_nearest_bucket
from diffusers_helper import lora_utils
from diffusers_helper.lora_utils import load_lora, unload_all_loras

# Import model generators
from modules.generators import create_model_generator

# Global cache for prompt embeddings
prompt_embedding_cache = {}
# Import from modules
from modules.video_queue import VideoJobQueue, JobStatus
from modules.prompt_handler import parse_timestamped_prompt
from modules.interface import create_interface, format_queue_status
from modules.settings import Settings
from modules import DUMMY_LORA_NAME  # Import the constant
from modules.pipelines.metadata_utils import create_metadata
from modules.pipelines.worker import worker

# Try to suppress annoyingly persistent Windows asyncio proactor errors
if os.name == 'nt':  # Windows only
    import asyncio
    from functools import wraps

    # Replace the problematic proactor event loop with selector event loop
    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())

    # Patch the base transport's close method
    def silence_event_loop_closed(func):
        @wraps(func)
        def wrapper(self, *args, **kwargs):
            try:
                return func(self, *args, **kwargs)
            except RuntimeError as e:
                if str(e) != 'Event loop is closed':
                    raise
        return wrapper

    # Apply the patch
    if hasattr(asyncio.proactor_events._ProactorBasePipeTransport, '_call_connection_lost'):
        asyncio.proactor_events._ProactorBasePipeTransport._call_connection_lost = silence_event_loop_closed(
            asyncio.proactor_events._ProactorBasePipeTransport._call_connection_lost)

# ADDED: Debug function to verify LoRA state
def verify_lora_state(transformer, label=""):
    """Debug function to verify the state of LoRAs in a transformer model"""
    if transformer is None:
        print(f"[{label}] Transformer is None, cannot verify LoRA state")
        return

    has_loras = False
    if hasattr(transformer, 'peft_config'):
        adapter_names = list(transformer.peft_config.keys()) if transformer.peft_config else []
        if adapter_names:
            has_loras = True
            print(f"[{label}] Transformer has LoRAs: {', '.join(adapter_names)}")
        else:
            print(f"[{label}] Transformer has no LoRAs in peft_config")
    else:
        print(f"[{label}] Transformer has no peft_config attribute")

    # Check for any LoRA modules
    for name, module in transformer.named_modules():
        if hasattr(module, 'lora_A') and module.lora_A:
            has_loras = True
            # print(f"[{label}] Found lora_A in module {name}")
        if hasattr(module, 'lora_B') and module.lora_B:
            has_loras = True
            # print(f"[{label}] Found lora_B in module {name}")

    if not has_loras:
        print(f"[{label}] No LoRA components found in transformer")

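# Example (illustrative, not called anywhere by default): after a generator has
# been created, the adapter state can be spot-checked with
#   verify_lora_state(current_generator.transformer, label="after LoRA load")
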
parser = argparse.ArgumentParser()
parser.add_argument('--share', action='store_true')
parser.add_argument("--server", type=str, default='0.0.0.0')
parser.add_argument("--port", type=int, required=False)
parser.add_argument("--inbrowser", action='store_true')
parser.add_argument("--lora", type=str, default=None, help="Lora path (comma separated for multiple)")
parser.add_argument("--offline", action='store_true', help="Run in offline mode")
args = parser.parse_args()

print(args)
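
# Example invocations (illustrative; any combination of the flags above works):
#   python app.py --share
#   python app.py --server 127.0.0.1 --port 7860 --inbrowser
#   python app.py --offline --lora "style_lora,detail_lora"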

if args.offline:
    print("Offline mode enabled.")
    os.environ['HF_HUB_OFFLINE'] = '1'
else:
    if 'HF_HUB_OFFLINE' in os.environ:
        del os.environ['HF_HUB_OFFLINE']

free_mem_gb = get_cuda_free_memory_gb(gpu)
high_vram = free_mem_gb > 60

print(f'Free VRAM {free_mem_gb} GB')
print(f'High-VRAM Mode: {high_vram}')

# Load models
text_encoder = LlamaModel.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='text_encoder', torch_dtype=torch.float16).cpu()
text_encoder_2 = CLIPTextModel.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='text_encoder_2', torch_dtype=torch.float16).cpu()
tokenizer = LlamaTokenizerFast.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='tokenizer')
tokenizer_2 = CLIPTokenizer.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='tokenizer_2')
vae = AutoencoderKLHunyuanVideo.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder='vae', torch_dtype=torch.float16).cpu()

feature_extractor = SiglipImageProcessor.from_pretrained("lllyasviel/flux_redux_bfl", subfolder='feature_extractor')
image_encoder = SiglipVisionModel.from_pretrained("lllyasviel/flux_redux_bfl", subfolder='image_encoder', torch_dtype=torch.float16).cpu()

# Initialize model generator placeholder
current_generator = None  # Will hold the currently active model generator

# Load models based on VRAM availability later

# Configure models
vae.eval()
text_encoder.eval()
text_encoder_2.eval()
image_encoder.eval()

if not high_vram:
    vae.enable_slicing()
    vae.enable_tiling()

vae.to(dtype=torch.float16)
image_encoder.to(dtype=torch.float16)
text_encoder.to(dtype=torch.float16)
text_encoder_2.to(dtype=torch.float16)

vae.requires_grad_(False)
text_encoder.requires_grad_(False)
text_encoder_2.requires_grad_(False)
image_encoder.requires_grad_(False)

# Create lora directory if it doesn't exist
lora_dir = os.path.join(os.path.dirname(__file__), 'loras')
os.makedirs(lora_dir, exist_ok=True)

# Initialize LoRA support - scanning moved to after settings load
lora_names = []
lora_values = []  # This seems unused for population, might be related to weights later

script_dir = os.path.dirname(os.path.abspath(__file__))

# Define default LoRA folder path relative to the script directory (used if setting is missing)
default_lora_folder = os.path.join(script_dir, "loras")
os.makedirs(default_lora_folder, exist_ok=True)  # Ensure default exists

if not high_vram:
    # DynamicSwapInstaller is same as huggingface's enable_sequential_offload but 3x faster
    DynamicSwapInstaller.install_model(text_encoder, device=gpu)
else:
    text_encoder.to(gpu)
    text_encoder_2.to(gpu)
    image_encoder.to(gpu)
    vae.to(gpu)

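# Note on memory strategy (descriptive): in low-VRAM mode only the Llama text
# encoder is wrapped with DynamicSwapInstaller; the other encoders stay on CPU
# and are presumably moved on demand by the worker (see load_model_as_complete /
# unload_complete_models imported above). In high-VRAM mode everything is kept
# resident on the GPU.
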
stream = AsyncStream()

outputs_folder = './outputs/'
os.makedirs(outputs_folder, exist_ok=True)

# Initialize settings
settings = Settings()

# NEW: auto-cleanup-on-startup option in Settings
if settings.get("auto_cleanup_on_startup", False):
    print("--- Running Automatic Startup Cleanup ---")

    # Import the processor instance
    from modules.toolbox_app import tb_processor

    # Call the single cleanup function and print its summary.
    cleanup_summary = tb_processor.tb_clear_temporary_files()
    print(f"{cleanup_summary}")  # This cleaner print handles the multiline string well

    print("--- Startup Cleanup Complete ---")

# --- Populate LoRA names AFTER settings are loaded ---
lora_folder_from_settings: str = settings.get("lora_dir", default_lora_folder)  # Use setting, fallback to default
print(f"Scanning for LoRAs in: {lora_folder_from_settings}")
if os.path.isdir(lora_folder_from_settings):
    try:
        for root, _, files in os.walk(lora_folder_from_settings):
            for file in files:
                if file.endswith('.safetensors') or file.endswith('.pt'):
                    lora_relative_path = os.path.relpath(os.path.join(root, file), lora_folder_from_settings)
                    lora_name = str(PurePath(lora_relative_path).with_suffix(''))
                    lora_names.append(lora_name)
        print(f"Found LoRAs: {lora_names}")
        # Temp solution for only 1 lora
        if len(lora_names) == 1:
            lora_names.append(DUMMY_LORA_NAME)
    except Exception as e:
        print(f"Error scanning LoRA directory '{lora_folder_from_settings}': {e}")
else:
    print(f"LoRA directory not found: {lora_folder_from_settings}")
# --- End LoRA population ---

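# Example (illustrative): with a folder layout like
#   loras/portrait.safetensors
#   loras/styles/anime_v2.safetensors
# the scan above yields lora_names == ['portrait', 'styles/anime_v2'], i.e. paths
# relative to the LoRA folder with the file extension stripped.
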
# Create job queue
job_queue = VideoJobQueue()


# Function to load a LoRA file
def load_lora_file(lora_file: str | PurePath):
    if not lora_file:
        return None, "No file selected"

    try:
        # Get the filename from the path
        lora_path = PurePath(lora_file)
        lora_name = lora_path.name

        # Copy the file into the lora directory. Use only the filename so that an
        # absolute upload path does not override the destination directory.
        lora_dest = PurePath(lora_dir, lora_name)
        shutil.copy(lora_file, lora_dest)

        # Load the LoRA
        global current_generator, lora_names
        if current_generator is None:
            return None, "Error: No model loaded to apply LoRA to. Generate something first."

        # Unload any existing LoRAs first
        current_generator.unload_loras()

        # Load the single LoRA
        selected_loras = [lora_path.stem]
        current_generator.load_loras(selected_loras, lora_dir, selected_loras)

        # Add to lora_names if not already there
        lora_base_name = lora_path.stem
        if lora_base_name not in lora_names:
            lora_names.append(lora_base_name)

        # Get the current device of the transformer
        device = next(current_generator.transformer.parameters()).device

        # Move all LoRA adapters to the same device as the base model
        current_generator.move_lora_adapters_to_device(device)

        print(f"Loaded LoRA: {lora_name} to {current_generator.get_model_name()} model")

        return gr.update(choices=lora_names), f"Successfully loaded LoRA: {lora_name}"
    except Exception as e:
        print(f"Error loading LoRA: {e}")
        return None, f"Error loading LoRA: {e}"

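# Note (descriptive): load_lora_file is wired into the UI below via
# create_interface(load_lora_file_fn=load_lora_file, ...). It copies an uploaded
# .safetensors/.pt file into lora_dir, hot-swaps it onto the currently loaded
# generator, and returns a gr.update() with the refreshed choices plus a status
# message.
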
@torch.no_grad()
def get_cached_or_encode_prompt(prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2, target_device):
    """
    Retrieves prompt embeddings from cache or encodes them if not found.
    Stores encoded embeddings (on CPU) in the cache.
    Returns embeddings moved to the target_device.
    """
    if prompt in prompt_embedding_cache:
        print(f"Cache hit for prompt: {prompt[:60]}...")
        llama_vec_cpu, llama_mask_cpu, clip_l_pooler_cpu = prompt_embedding_cache[prompt]
        # Move cached embeddings (from CPU) to the target device
        llama_vec = llama_vec_cpu.to(target_device)
        llama_attention_mask = llama_mask_cpu.to(target_device) if llama_mask_cpu is not None else None
        clip_l_pooler = clip_l_pooler_cpu.to(target_device)
        return llama_vec, llama_attention_mask, clip_l_pooler
    else:
        print(f"Cache miss for prompt: {prompt[:60]}...")
        llama_vec, clip_l_pooler = encode_prompt_conds(
            prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2
        )
        llama_vec, llama_attention_mask = crop_or_pad_yield_mask(llama_vec, length=512)
        # Store CPU copies in cache
        prompt_embedding_cache[prompt] = (llama_vec.cpu(), llama_attention_mask.cpu() if llama_attention_mask is not None else None, clip_l_pooler.cpu())
        # Return embeddings already on the target device (as encode_prompt_conds uses the model's device)
        return llama_vec, llama_attention_mask, clip_l_pooler

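# Note (descriptive): the cache keys on the raw prompt string and stores CPU
# tensors, so repeated jobs with the same prompt skip both text encoders; only
# the .to(target_device) copy is paid on a cache hit.
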
# Set the worker function for the job queue - using the imported worker from modules/pipelines/worker.py
job_queue.set_worker_function(worker)

def get_duration(model_type,
                 input_image,
                 end_frame_image,  # NEW
                 end_frame_strength,  # NEW
                 prompt_text,
                 n_prompt,
                 seed,
                 total_second_length,
                 latent_window_size,
                 steps,
                 cfg,
                 gs,
                 rs,
                 use_teacache,
                 teacache_num_steps,
                 teacache_rel_l1_thresh,
                 use_magcache,
                 magcache_threshold,
                 magcache_max_consecutive_skips,
                 magcache_retention_ratio,
                 blend_sections,
                 latent_type,
                 clean_up_videos,
                 selected_loras,
                 resolutionW,
                 resolutionH,
                 input_image_path,
                 combine_with_source,
                 num_cleaned_frames,
                 *lora_args,
                 save_metadata_checked=True):
    return total_second_length * 60
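
# Note (assumed behaviour of the decorator below): spaces.GPU can take a callable
# for `duration`; it appears to be called with the same arguments as the wrapped
# function and should return the GPU time budget in seconds. get_duration budgets
# 60 s of GPU time per requested second of video.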

@spaces.GPU(duration=get_duration)
def process(
    model_type,
    input_image,
    end_frame_image,  # NEW
    end_frame_strength,  # NEW
    prompt_text,
    n_prompt,
    seed,
    total_second_length,
    latent_window_size,
    steps,
    cfg,
    gs,
    rs,
    use_teacache,
    teacache_num_steps,
    teacache_rel_l1_thresh,
    use_magcache,
    magcache_threshold,
    magcache_max_consecutive_skips,
    magcache_retention_ratio,
    blend_sections,
    latent_type,
    clean_up_videos,
    selected_loras,
    resolutionW,
    resolutionH,
    input_image_path,
    combine_with_source,
    num_cleaned_frames,
    *lora_args,
    save_metadata_checked=True,  # NEW: Parameter to control metadata saving
):

    # If no input image is provided, create a default image based on the selected latent_type
    has_input_image = True
    if input_image is None:
        has_input_image = False
        default_height, default_width = resolutionH, resolutionW
        if latent_type == "White":
            # Create a white image
            input_image = np.ones((default_height, default_width, 3), dtype=np.uint8) * 255
            print("No input image provided. Using a blank white image.")

        elif latent_type == "Noise":
            # Create a noise image
            input_image = np.random.randint(0, 256, (default_height, default_width, 3), dtype=np.uint8)
            print("No input image provided. Using a random noise image.")

        elif latent_type == "Green Screen":
            # Create a green screen image with standard chroma key green (0, 177, 64)
            input_image = np.zeros((default_height, default_width, 3), dtype=np.uint8)
            input_image[:, :, 1] = 177  # Green channel
            input_image[:, :, 2] = 64   # Blue channel
            # Red channel remains 0
            print("No input image provided. Using a standard chroma key green screen.")

        else:  # Default to "Black" or any other value
            # Create a black image
            input_image = np.zeros((default_height, default_width, 3), dtype=np.uint8)
            print(f"No input image provided. Using a blank black image (latent_type: {latent_type}).")

    # Handle input files - copy to input_files_dir to prevent them from being deleted by temp cleanup
    input_files_dir = settings.get("input_files_dir")
    os.makedirs(input_files_dir, exist_ok=True)

    # Process input image (if it's a file path)
    input_image_path = None
    if isinstance(input_image, str) and os.path.exists(input_image):
        # It's a file path, copy it to input_files_dir
        filename = os.path.basename(input_image)
        input_image_path = os.path.join(input_files_dir, f"{generate_timestamp()}_{filename}")
        try:
            shutil.copy2(input_image, input_image_path)
            print(f"Copied input image to {input_image_path}")
            # For Video model, we'll use the path
            if model_type == "Video":
                input_image = input_image_path
        except Exception as e:
            print(f"Error copying input image: {e}")

    # Process end frame image (if it's a file path)
    end_frame_image_path = None
    if isinstance(end_frame_image, str) and os.path.exists(end_frame_image):
        # It's a file path, copy it to input_files_dir
        filename = os.path.basename(end_frame_image)
        end_frame_image_path = os.path.join(input_files_dir, f"{generate_timestamp()}_{filename}")
        try:
            shutil.copy2(end_frame_image, end_frame_image_path)
            print(f"Copied end frame image to {end_frame_image_path}")
        except Exception as e:
            print(f"Error copying end frame image: {e}")

    # Extract lora_loaded_names from lora_args
    lora_loaded_names = lora_args[0] if lora_args and len(lora_args) > 0 else []
    lora_values = lora_args[1:] if lora_args and len(lora_args) > 1 else []

    # Create job parameters
    job_params = {
        'model_type': model_type,
        'input_image': input_image.copy() if hasattr(input_image, 'copy') else input_image,  # Handle both image arrays and video paths
        'end_frame_image': end_frame_image.copy() if hasattr(end_frame_image, 'copy') else end_frame_image,  # Same handling: may be an image array, a path, or None
        'end_frame_strength': end_frame_strength,
        'prompt_text': prompt_text,
        'n_prompt': n_prompt,
        'seed': seed,
        'total_second_length': total_second_length,
        'latent_window_size': latent_window_size,
        'latent_type': latent_type,
        'steps': steps,
        'cfg': cfg,
        'gs': gs,
        'rs': rs,
        'blend_sections': blend_sections,
        'use_teacache': use_teacache,
        'teacache_num_steps': teacache_num_steps,
        'teacache_rel_l1_thresh': teacache_rel_l1_thresh,
        'use_magcache': use_magcache,
        'magcache_threshold': magcache_threshold,
        'magcache_max_consecutive_skips': magcache_max_consecutive_skips,
        'magcache_retention_ratio': magcache_retention_ratio,
        'selected_loras': selected_loras,
        'has_input_image': has_input_image,
        'output_dir': settings.get("output_dir"),
        'metadata_dir': settings.get("metadata_dir"),
        'input_files_dir': input_files_dir,  # Add input_files_dir to job parameters
        'input_image_path': input_image_path,  # Add the path to the copied input image
        'end_frame_image_path': end_frame_image_path,  # Add the path to the copied end frame image
        'resolutionW': resolutionW,  # Add resolution parameter
        'resolutionH': resolutionH,
        'lora_loaded_names': lora_loaded_names,
        'combine_with_source': combine_with_source,  # Add combine_with_source parameter
        'num_cleaned_frames': num_cleaned_frames,
        'save_metadata_checked': save_metadata_checked,  # NEW: Add save_metadata_checked parameter
    }

    # Print teacache parameters for debugging
    print(f"Teacache parameters: use_teacache={use_teacache}, teacache_num_steps={teacache_num_steps}, teacache_rel_l1_thresh={teacache_rel_l1_thresh}")

    # Add LoRA values if provided - extract them from the tuple
    if lora_values:
        # Convert tuple to list
        lora_values_list = list(lora_values)
        job_params['lora_values'] = lora_values_list

    # Add job to queue
    job_id = job_queue.add_job(job_params)

    # Set the generation_type attribute on the job object directly
    job = job_queue.get_job(job_id)
    if job:
        job.generation_type = model_type  # Set generation_type to model_type for display in queue
    print(f"Added job {job_id} to queue")

    queue_status = update_queue_status()
    # Return immediately after adding to queue
    # Return separate updates for start_button and end_button to prevent cross-contamination
    return None, job_id, None, '', f'Job added to queue. Job ID: {job_id}', gr.update(value="🚀 Add to Queue", interactive=True), gr.update(value="❌ Cancel Current Job", interactive=True)

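# Note (descriptive): process() only enqueues the job and returns immediately. The
# last two items of its return tuple are separate gr.update() objects for the start
# and cancel buttons, so updating one cannot clobber the other (hence the
# "cross-contamination" comment above). The exact output components are wired up in
# create_interface (modules/interface.py), which is not shown here.
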

def end_process():
    """Cancel the current running job and update the queue status"""
    print("Cancelling current job")
    with job_queue.lock:
        if job_queue.current_job:
            job_id = job_queue.current_job.id
            print(f"Cancelling job {job_id}")

            # Send the end signal to the job's stream
            if job_queue.current_job.stream:
                job_queue.current_job.stream.input_queue.push('end')

            # Mark the job as cancelled
            job_queue.current_job.status = JobStatus.CANCELLED
            job_queue.current_job.completed_at = time.time()  # Set completion time

    # Force an update to the queue status
    return update_queue_status()


def update_queue_status():
    """Update queue status and refresh job positions"""
    jobs = job_queue.get_all_jobs()
    for job in jobs:
        if job.status == JobStatus.PENDING:
            job.queue_position = job_queue.get_queue_position(job.id)

    # Make sure to update current running job info
    if job_queue.current_job:
        # Make sure the running job is showing status = RUNNING
        job_queue.current_job.status = JobStatus.RUNNING

    # Update the toolbar stats
    pending_count = 0
    running_count = 0
    completed_count = 0

    for job in jobs:
        if hasattr(job, 'status'):
            status = str(job.status)
            if status == "JobStatus.PENDING":
                pending_count += 1
            elif status == "JobStatus.RUNNING":
                running_count += 1
            elif status == "JobStatus.COMPLETED":
                completed_count += 1

    return format_queue_status(jobs)

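# Note (descriptive): the pending/running/completed counters above are tallied under
# the "toolbar stats" comment but are not returned; update_queue_status() only hands
# back the formatted table from format_queue_status(jobs).
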

def monitor_job(job_id=None):
    """
    Monitor a specific job and update the UI with the latest video segment as soon as it's available.
    If no job_id is provided, check if there's a current job in the queue.
    ALWAYS shows the current running job, regardless of the job_id provided.
    """
    last_video = None  # Track the last video file shown
    last_job_status = None  # Track the previous job status to detect status changes
    last_progress_update_time = time.time()  # Track when we last updated the progress
    last_preview = None  # Track the last preview image shown
    force_update = True  # Force an update on first iteration

    # Flag to indicate we're waiting for a job transition
    waiting_for_transition = False
    transition_start_time = None
    max_transition_wait = 5.0  # Maximum time to wait for transition in seconds

    def get_preview_updates(preview_value):
        """Create preview updates that respect the latents_display_top setting"""
        display_top = settings.get("latents_display_top", False)
        if display_top:
            # Top display enabled: update top preview with value, don't update right preview
            return gr.update(), preview_value if preview_value is not None else gr.update()
        else:
            # Right column display: update right preview with value, don't update top preview
            return preview_value if preview_value is not None else gr.update(), gr.update()

    while True:
        # ALWAYS check if there's a current running job that's different from our tracked job_id
        with job_queue.lock:
            current_job = job_queue.current_job
            if current_job and current_job.id != job_id and current_job.status == JobStatus.RUNNING:
                # Always switch to the current running job
                job_id = current_job.id
                waiting_for_transition = False
                force_update = True
                # Yield a temporary update to show we're switching jobs
                right_preview, top_preview = get_preview_updates(None)
                yield last_video, right_preview, top_preview, '', 'Switching to current job...', gr.update(interactive=True), gr.update(value="❌ Cancel Current Job", visible=True)
                continue

        # Check if we're waiting for a job transition
        if waiting_for_transition:
            current_time = time.time()
            # If we've been waiting too long, stop waiting
            if current_time - transition_start_time > max_transition_wait:
                waiting_for_transition = False

                # Check one more time for a current job
                with job_queue.lock:
                    current_job = job_queue.current_job
                    if current_job and current_job.status == JobStatus.RUNNING:
                        # Switch to whatever job is currently running
                        job_id = current_job.id
                        force_update = True
                        right_preview, top_preview = get_preview_updates(None)
                        yield last_video, right_preview, top_preview, '', 'Switching to current job...', gr.update(interactive=True), gr.update(value="❌ Cancel Current Job", visible=True)
                        continue
            else:
                # If still waiting, sleep briefly and continue
                time.sleep(0.1)
                continue

        job = job_queue.get_job(job_id)
        if not job:
            # Correctly yield 7 items for the startup/no-job case
            # This ensures the status text goes to the right component and the buttons are set correctly.
            yield None, None, None, 'No job ID provided', '', gr.update(value="🚀 Add to Queue", interactive=True, visible=True), gr.update(interactive=False, visible=False)
            return

        # If a new video file is available, yield it immediately
        if job.result and job.result != last_video:
            last_video = job.result
            # You can also update preview/progress here if desired
            right_preview, top_preview = get_preview_updates(None)
            yield last_video, right_preview, top_preview, '', '', gr.update(interactive=True), gr.update(interactive=True)

        # Handle job status and progress
        if job.status == JobStatus.PENDING:
            position = job_queue.get_queue_position(job_id)
            right_preview, top_preview = get_preview_updates(None)
            yield last_video, right_preview, top_preview, '', f'Waiting in queue. Position: {position}', gr.update(interactive=True), gr.update(interactive=True)

        elif job.status == JobStatus.RUNNING:
            # Only reset the cancel button when a job transitions from another state to RUNNING
            # This ensures we don't reset the button text during cancellation
            if last_job_status != JobStatus.RUNNING:
                # Check if the button text is already "Cancelling..." - if so, don't change it
                # This prevents the button from changing back to "Cancel Current Job" during cancellation
                button_update = gr.update(interactive=True, value="❌ Cancel Current Job", visible=True)
            else:
                # Keep current text and state - important to not override "Cancelling..." text
                button_update = gr.update(interactive=True, visible=True)

            # Check if we have progress data and if it's time to update
            current_time = time.time()
            update_needed = force_update or (current_time - last_progress_update_time > 0.05)  # More frequent updates

            # Always check for progress data, even if we don't have a preview yet
            if job.progress_data and update_needed:
                preview = job.progress_data.get('preview')
                desc = job.progress_data.get('desc', '')
                html = job.progress_data.get('html', '')

                # Only update the preview if it has changed or we're forcing an update
                # Ensure all components get an update
                current_preview_value = job.progress_data.get('preview') if job.progress_data else None
                current_desc_value = job.progress_data.get('desc', 'Processing...') if job.progress_data else 'Processing...'
                current_html_value = job.progress_data.get('html', make_progress_bar_html(0, 'Processing...')) if job.progress_data else make_progress_bar_html(0, 'Processing...')

                if current_preview_value is not None and (current_preview_value is not last_preview or force_update):
                    last_preview = current_preview_value
                # Always update if force_update is true, or if it's time for a periodic update
                if force_update or update_needed:
                    last_progress_update_time = current_time
                    force_update = False
                    right_preview, top_preview = get_preview_updates(last_preview)
                    yield job.result, right_preview, top_preview, current_desc_value, current_html_value, gr.update(interactive=True), button_update

            # Fallback for periodic update if no new progress data but job is still running
            elif current_time - last_progress_update_time > 0.5:  # More frequent fallback update
                last_progress_update_time = current_time
                force_update = False  # Reset force_update after a yield
                current_desc_value = job.progress_data.get('desc', 'Processing...') if job.progress_data else 'Processing...'
                current_html_value = job.progress_data.get('html', make_progress_bar_html(0, 'Processing...')) if job.progress_data else make_progress_bar_html(0, 'Processing...')
                right_preview, top_preview = get_preview_updates(last_preview)
                yield job.result, right_preview, top_preview, current_desc_value, current_html_value, gr.update(interactive=True), button_update

        elif job.status == JobStatus.COMPLETED:
            # Show the final video and reset the button text
            right_preview, top_preview = get_preview_updates(last_preview)
            yield job.result, right_preview, top_preview, 'Completed', make_progress_bar_html(100, 'Completed'), gr.update(value="🚀 Add to Queue"), gr.update(interactive=True, value="❌ Cancel Current Job", visible=False)
            break

        elif job.status == JobStatus.FAILED:
            # Show error and reset the button text
            right_preview, top_preview = get_preview_updates(last_preview)
            yield job.result, right_preview, top_preview, f'Error: {job.error}', make_progress_bar_html(0, 'Failed'), gr.update(value="🚀 Add to Queue"), gr.update(interactive=True, value="❌ Cancel Current Job", visible=False)
            break

        elif job.status == JobStatus.CANCELLED:
            # Show cancelled message and reset the button text
            right_preview, top_preview = get_preview_updates(last_preview)
            yield job.result, right_preview, top_preview, 'Job cancelled', make_progress_bar_html(0, 'Cancelled'), gr.update(interactive=True), gr.update(interactive=True, value="❌ Cancel Current Job", visible=False)
            break

        # Update last_job_status for the next iteration
        last_job_status = job.status

        # Wait a bit before checking again
        time.sleep(0.05)  # Reduced wait time for more responsive updates

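# Note (descriptive): monitor_job is a generator; each yield streams a 7-item update
# (video, right preview, top preview, status text, progress HTML, start-button update,
# cancel-button update) back to Gradio, and the loop only exits once the tracked job
# reaches COMPLETED, FAILED, or CANCELLED.
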

# Set Gradio temporary directory from settings
os.environ["GRADIO_TEMP_DIR"] = settings.get("gradio_temp_dir")

# Create the interface
interface = create_interface(
    process_fn=process,
    monitor_fn=monitor_job,
    end_process_fn=end_process,
    update_queue_status_fn=update_queue_status,
    load_lora_file_fn=load_lora_file,
    job_queue=job_queue,
    settings=settings,
    lora_names=lora_names  # Explicitly pass the found LoRA names
)

# Launch the interface
interface.launch(share=True)
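
# Note (descriptive): share=True is hard-coded above, so the --share, --server,
# --port and --inbrowser arguments parsed earlier are currently not forwarded to
# launch(). A sketch of how they could be passed through, assuming the standard
# gradio Blocks.launch keyword arguments:
#   interface.launch(
#       server_name=args.server,
#       server_port=args.port,
#       share=args.share,
#       inbrowser=args.inbrowser,
#   )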