multimodalart (HF Staff) committed · verified
Commit 4d68dfd · 1 Parent(s): 6ad4062

Update app.py

Files changed (1): app.py (+122, -29)
app.py CHANGED
@@ -11,7 +11,6 @@ import tempfile
 from PIL import Image
 from huggingface_hub import hf_hub_download
 import shutil
-import math # For math.round, though built-in round works for floats
 
 from inference import (
     create_ltx_video_pipeline,
@@ -89,13 +88,56 @@ if PIPELINE_CONFIG_YAML.get("spatial_upscaler_model_path"):
 target_inference_device = "cuda"
 print(f"Target inference device: {target_inference_device}")
 pipeline_instance.to(target_inference_device)
-if latent_upsampler_instance: # Check if it was created before moving
+if latent_upsampler_instance:
     latent_upsampler_instance.to(target_inference_device)
 
+
+# --- Helper function for dimension calculation ---
+MIN_DIM_SLIDER = 256 # As defined in the sliders minimum attribute
+TARGET_FIXED_SIDE = 512 # Desired fixed side length as per requirement
+
+def calculate_new_dimensions(orig_w, orig_h):
+    """
+    Calculates new dimensions for height and width sliders based on original media dimensions.
+    Ensures one side is TARGET_FIXED_SIDE, the other is scaled proportionally,
+    both are multiples of 32, and within [MIN_DIM_SLIDER, MAX_IMAGE_SIZE].
+    """
+    if orig_w == 0 or orig_h == 0:
+        # Default to TARGET_FIXED_SIDE square if original dimensions are invalid
+        return int(TARGET_FIXED_SIDE), int(TARGET_FIXED_SIDE)
+
+    if orig_w >= orig_h: # Landscape or square
+        new_h = TARGET_FIXED_SIDE
+        aspect_ratio = orig_w / orig_h
+        new_w_ideal = new_h * aspect_ratio
+
+        # Round to nearest multiple of 32
+        new_w = round(new_w_ideal / 32) * 32
+
+        # Clamp to [MIN_DIM_SLIDER, MAX_IMAGE_SIZE]
+        new_w = max(MIN_DIM_SLIDER, min(new_w, MAX_IMAGE_SIZE))
+        # Ensure new_h is also clamped (TARGET_FIXED_SIDE should be within these bounds if configured correctly)
+        new_h = max(MIN_DIM_SLIDER, min(new_h, MAX_IMAGE_SIZE))
+    else: # Portrait
+        new_w = TARGET_FIXED_SIDE
+        aspect_ratio = orig_h / orig_w # Use H/W ratio for portrait scaling
+        new_h_ideal = new_w * aspect_ratio
+
+        # Round to nearest multiple of 32
+        new_h = round(new_h_ideal / 32) * 32
+
+        # Clamp to [MIN_DIM_SLIDER, MAX_IMAGE_SIZE]
+        new_h = max(MIN_DIM_SLIDER, min(new_h, MAX_IMAGE_SIZE))
+        # Ensure new_w is also clamped
+        new_w = max(MIN_DIM_SLIDER, min(new_w, MAX_IMAGE_SIZE))
+
+    return int(new_h), int(new_w)
+
+
 @spaces.GPU
 def generate(prompt, negative_prompt, input_image_filepath, input_video_filepath,
              height_ui, width_ui, mode,
-             ui_steps, duration_ui, # << CHANGED from num_frames_ui
+             ui_steps, duration_ui,
              ui_frames_to_use,
             seed_ui, randomize_seed, ui_guidance_scale, improve_texture_flag,
             progress=gr.Progress(track_tqdm=True)):
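The `calculate_new_dimensions` helper added above fixes the shorter side of the uploaded media at `TARGET_FIXED_SIDE` (512), scales the other side by the original aspect ratio, snaps it to a multiple of 32, and clamps both sides to the slider range. A quick sanity check, meant to be run in app.py's module scope; the expected values assume `MAX_IMAGE_SIZE` is at least 896 so no clamping applies:

assert calculate_new_dimensions(1920, 1080) == (512, 896)   # landscape: height fixed at 512, 512*(1920/1080) ~ 910 -> 896
assert calculate_new_dimensions(1080, 1920) == (896, 512)   # portrait: width fixed at 512
assert calculate_new_dimensions(1024, 1024) == (512, 512)   # square input takes the landscape branch
assert calculate_new_dimensions(0, 0) == (512, 512)         # invalid input falls back to a 512x512 default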
@@ -104,33 +146,25 @@ def generate(prompt, negative_prompt, input_image_filepath, input_video_filepath
         seed_ui = random.randint(0, 2**32 - 1)
     seed_everething(int(seed_ui))
 
-    # Convert duration_ui (seconds) to actual_num_frames (N*8+1 format)
     target_frames_ideal = duration_ui * FPS
     target_frames_rounded = round(target_frames_ideal)
-    if target_frames_rounded < 1: # ensure positive for calculation
+    if target_frames_rounded < 1:
         target_frames_rounded = 1
 
-    # Calculate N for N*8+1, ensuring it's rounded to the nearest integer
-    # (target_frames_rounded - 1) could be float if target_frames_rounded is float
     n_val = round((float(target_frames_rounded) - 1.0) / 8.0)
     actual_num_frames = int(n_val * 8 + 1)
 
-    # Clamp to the allowed min (9) and max (MAX_NUM_FRAMES) N*8+1 values
     actual_num_frames = max(9, actual_num_frames)
     actual_num_frames = min(MAX_NUM_FRAMES, actual_num_frames)
 
     actual_height = int(height_ui)
     actual_width = int(width_ui)
-    # actual_num_frames is now calculated above
 
     height_padded = ((actual_height - 1) // 32 + 1) * 32
     width_padded = ((actual_width - 1) // 32 + 1) * 32
-    # This padding ensures the model gets a frame count that is N*8+1
-    # Since actual_num_frames is already N*8+1, this should preserve it.
     num_frames_padded = ((actual_num_frames - 2) // 8 + 1) * 8 + 1
     if num_frames_padded != actual_num_frames:
         print(f"Warning: actual_num_frames ({actual_num_frames}) and num_frames_padded ({num_frames_padded}) differ. Using num_frames_padded for pipeline.")
-    # This case should ideally not happen if actual_num_frames is correctly N*8+1 and >= 9.
 
     padding_values = calculate_padding(actual_height, actual_width, height_padded, width_padded)
 
@@ -139,7 +173,7 @@ def generate(prompt, negative_prompt, input_image_filepath, input_video_filepath
         "negative_prompt": negative_prompt,
         "height": height_padded,
         "width": width_padded,
-        "num_frames": num_frames_padded, # Use the padded value for the model
+        "num_frames": num_frames_padded,
         "frame_rate": int(FPS),
         "generator": torch.Generator(device=target_inference_device).manual_seed(int(seed_ui)),
         "output_type": "pt",
@@ -184,7 +218,7 @@ def generate(prompt, negative_prompt, input_image_filepath, input_video_filepath
                 media_path=input_video_filepath,
                 height=actual_height,
                 width=actual_width,
-                max_frames=int(ui_frames_to_use), # This is from a separate slider for V2V
+                max_frames=int(ui_frames_to_use),
                 padding=padding_values
             ).to(target_inference_device)
         except Exception as e:
@@ -192,15 +226,10 @@ def generate(prompt, negative_prompt, input_image_filepath, input_video_filepath
             raise gr.Error(f"Could not load video: {e}")
 
     print(f"Moving models to {target_inference_device} for inference (if not already there)...")
-    # Models are moved globally once, no need to move per call unless strategy changes.
-    # pipeline_instance.to(target_inference_device)
-    # if latent_upsampler_instance:
-    #     latent_upsampler_instance.to(target_inference_device)
 
     active_latent_upsampler = None
     if improve_texture_flag and latent_upsampler_instance:
         active_latent_upsampler = latent_upsampler_instance
-    #print("Models moved.")
 
     result_images_tensor = None
     if improve_texture_flag:
@@ -230,7 +259,6 @@ def generate(prompt, negative_prompt, input_image_filepath, input_video_filepath
         single_pass_call_kwargs = call_kwargs.copy()
         single_pass_call_kwargs["guidance_scale"] = float(ui_guidance_scale)
         single_pass_call_kwargs["num_inference_steps"] = int(ui_steps)
-        # These keys might not exist if improve_texture_flag is false from the start of call_kwargs
         single_pass_call_kwargs.pop("first_pass", None)
         single_pass_call_kwargs.pop("second_pass", None)
         single_pass_call_kwargs.pop("downscale_factor", None)
@@ -245,7 +273,6 @@ def generate(prompt, negative_prompt, input_image_filepath, input_video_filepath
     slice_h_end = -pad_bottom if pad_bottom > 0 else None
     slice_w_end = -pad_right if pad_right > 0 else None
 
-    # Crop to actual_num_frames, which is the desired output length
     result_images_tensor = result_images_tensor[
         :, :, :actual_num_frames, pad_top:slice_h_end, pad_left:slice_w_end
     ]
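The slicing above crops the generated tensor back to the requested size: height and width were rounded up to multiples of 32 before inference, and the pad is removed afterwards, with a zero pad mapped to None because `tensor[..., :-0]` would yield an empty slice. A minimal round-trip sketch; the symmetric top/bottom and left/right split below is a hypothetical stand-in for whatever `calculate_padding` returns in app.py:

import torch
import torch.nn.functional as F

h, w = 510, 700                                  # requested output size
h_pad = ((h - 1) // 32 + 1) * 32                 # 512
w_pad = ((w - 1) // 32 + 1) * 32                 # 704
pad_top = (h_pad - h) // 2                       # hypothetical symmetric split
pad_bottom = (h_pad - h) - pad_top
pad_left = (w_pad - w) // 2
pad_right = (w_pad - w) - pad_left

frames = torch.zeros(1, 3, 9, h, w)              # (batch, channels, frames, H, W)
padded = F.pad(frames, (pad_left, pad_right, pad_top, pad_bottom))
assert padded.shape[-2:] == (h_pad, w_pad)

# crop back, mirroring the slicing in the hunk above
slice_h_end = -pad_bottom if pad_bottom > 0 else None
slice_w_end = -pad_right if pad_right > 0 else None
cropped = padded[:, :, :9, pad_top:slice_h_end, pad_left:slice_w_end]
assert cropped.shape == frames.shape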
@@ -297,6 +324,7 @@ def generate(prompt, negative_prompt, input_image_filepath, input_video_filepath
 
     return output_video_path
 
+
 # --- Gradio UI Definition ---
 css="""
 #col-container {
@@ -308,6 +336,7 @@ css="""
 with gr.Blocks(css=css) as demo:
     gr.Markdown("# LTX Video 0.9.7 Distilled")
     gr.Markdown("Fast high quality video generation. [Model](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-2b-0.9.6-distilled-04-25.safetensors) [GitHub](https://github.com/Lightricks/LTX-Video) [Diffusers](#)")
+
     with gr.Row():
         with gr.Column():
             with gr.Tab("image-to-video") as image_tab:
@@ -322,7 +351,7 @@ with gr.Blocks(css=css) as demo:
                 t2v_button = gr.Button("Generate Text-to-Video", variant="primary")
             with gr.Tab("video-to-video") as video_tab:
                 image_v_hidden = gr.Textbox(label="image_v", visible=False, value=None)
-                video_v2v = gr.Video(label="Input Video", sources=["upload", "webcam"])
+                video_v2v = gr.Video(label="Input Video", sources=["upload", "webcam"]) # type defaults to filepath
                 frames_to_use = gr.Slider(label="Frames to use from input video", minimum=9, maximum=MAX_NUM_FRAMES, value=9, step=8, info="Number of initial frames to use for conditioning/transformation. Must be N*8+1.")
                 v2v_prompt = gr.Textbox(label="Prompt", value="Change the style to cinematic anime", lines=3)
                 v2v_button = gr.Button("Generate Video-to-Video", variant="primary")
@@ -347,26 +376,90 @@ with gr.Blocks(css=css) as demo:
            randomize_seed_input = gr.Checkbox(label="Randomize Seed", value=False)
        with gr.Row():
            guidance_scale_input = gr.Slider(label="Guidance Scale (CFG)", minimum=1.0, maximum=10.0, value=PIPELINE_CONFIG_YAML.get("first_pass", {}).get("guidance_scale", 1.0), step=0.1, info="Controls how much the prompt influences the output. Higher values = stronger influence.")
-           default_steps = len(PIPELINE_CONFIG_YAML.get("first_pass", {}).get("timesteps", [1]*7)) # Default to 7 if not found
+           default_steps = len(PIPELINE_CONFIG_YAML.get("first_pass", {}).get("timesteps", [1]*7))
            steps_input = gr.Slider(label="Inference Steps (for first pass if multi-scale)", minimum=1, maximum=30, value=default_steps, step=1, info="Number of denoising steps. More steps can improve quality but increase time. If YAML defines 'timesteps' for a pass, this UI value is ignored for that pass.")
        with gr.Row():
-           height_input = gr.Slider(label="Height", value=512, step=32, minimum=256, maximum=MAX_IMAGE_SIZE, info="Must be divisible by 32.")
-           width_input = gr.Slider(label="Width", value=704, step=32, minimum=256, maximum=MAX_IMAGE_SIZE, info="Must be divisible by 32.")
+           height_input = gr.Slider(label="Height", value=512, step=32, minimum=MIN_DIM_SLIDER, maximum=MAX_IMAGE_SIZE, info="Must be divisible by 32.")
+           width_input = gr.Slider(label="Width", value=704, step=32, minimum=MIN_DIM_SLIDER, maximum=MAX_IMAGE_SIZE, info="Must be divisible by 32.")
+
+
+    # --- Event handlers for updating dimensions on upload ---
+    def handle_image_upload_for_dims(image_filepath, current_h, current_w):
+        if not image_filepath: # Image cleared or no image initially
+            # Keep current slider values if image is cleared or no input
+            return gr.update(value=current_h), gr.update(value=current_w)
+        try:
+            img = Image.open(image_filepath)
+            orig_w, orig_h = img.size
+            new_h, new_w = calculate_new_dimensions(orig_w, orig_h)
+            return gr.update(value=new_h), gr.update(value=new_w)
+        except Exception as e:
+            print(f"Error processing image for dimension update: {e}")
+            # Keep current slider values on error
+            return gr.update(value=current_h), gr.update(value=current_w)
+
+    def handle_video_upload_for_dims(video_filepath, current_h, current_w):
+        if not video_filepath: # Video cleared or no video initially
+            return gr.update(value=current_h), gr.update(value=current_w)
+        try:
+            # Ensure video_filepath is a string for os.path.exists and imageio
+            video_filepath_str = str(video_filepath)
+            if not os.path.exists(video_filepath_str):
+                print(f"Video file path does not exist for dimension update: {video_filepath_str}")
+                return gr.update(value=current_h), gr.update(value=current_w)
+
+            orig_w, orig_h = -1, -1
+            with imageio.get_reader(video_filepath_str) as reader:
+                meta = reader.get_meta_data()
+                if 'size' in meta:
+                    orig_w, orig_h = meta['size']
+                else:
+                    # Fallback: read first frame if 'size' not in metadata
+                    try:
+                        first_frame = reader.get_data(0)
+                        # Shape is (h, w, c) for frames
+                        orig_h, orig_w = first_frame.shape[0], first_frame.shape[1]
+                    except Exception as e_frame:
+                        print(f"Could not get video size from metadata or first frame: {e_frame}")
+                        return gr.update(value=current_h), gr.update(value=current_w)
+
+            if orig_w == -1 or orig_h == -1: # If dimensions couldn't be determined
+                print(f"Could not determine dimensions for video: {video_filepath_str}")
+                return gr.update(value=current_h), gr.update(value=current_w)
+
+            new_h, new_w = calculate_new_dimensions(orig_w, orig_h)
+            return gr.update(value=new_h), gr.update(value=new_w)
+        except Exception as e:
+            # Log type of video_filepath for debugging if it's not a path-like string
+            print(f"Error processing video for dimension update: {e} (Path: {video_filepath}, Type: {type(video_filepath)})")
+            return gr.update(value=current_h), gr.update(value=current_w)
+
+    # Attach upload handlers
+    image_i2v.upload(
+        fn=handle_image_upload_for_dims,
+        inputs=[image_i2v, height_input, width_input],
+        outputs=[height_input, width_input]
+    )
+    video_v2v.upload(
+        fn=handle_video_upload_for_dims,
+        inputs=[video_v2v, height_input, width_input],
+        outputs=[height_input, width_input]
+    )
 
-    # --- UPDATED INPUT LISTS ---
+    # --- INPUT LISTS (remain the same structurally) ---
    t2v_inputs = [t2v_prompt, negative_prompt_input, image_n_hidden, video_n_hidden,
                  height_input, width_input, gr.State("text-to-video"),
-                 steps_input, duration_input, gr.State(0), # Replaced num_frames_input with duration_input
+                 steps_input, duration_input, gr.State(0),
                  seed_input, randomize_seed_input, guidance_scale_input, improve_texture]
 
    i2v_inputs = [i2v_prompt, negative_prompt_input, image_i2v, video_i_hidden,
                  height_input, width_input, gr.State("image-to-video"),
-                 steps_input, duration_input, gr.State(0), # Replaced num_frames_input with duration_input
+                 steps_input, duration_input, gr.State(0),
                  seed_input, randomize_seed_input, guidance_scale_input, improve_texture]
 
    v2v_inputs = [v2v_prompt, negative_prompt_input, image_v_hidden, video_v2v,
                  height_input, width_input, gr.State("video-to-video"),
-                 steps_input, duration_input, frames_to_use, # Replaced num_frames_input with duration_input
+                 steps_input, duration_input, frames_to_use,
                  seed_input, randomize_seed_input, guidance_scale_input, improve_texture]
 
    t2v_button.click(fn=generate, inputs=t2v_inputs, outputs=[output_video], api_name="text_to_video")
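The upload handlers in the last hunk follow a common Gradio pattern: a component's `.upload()` event reads the native size of the uploaded media and pushes new values into the height/width sliders via `gr.update`. A stripped-down, self-contained sketch of the same wiring for the image case; the component names, slider bounds, and the 512-pixel target below are illustrative, not the ones used in app.py:

import gradio as gr
from PIL import Image

def snap_to_32(value, lo=256, hi=1280):
    # round to the nearest multiple of 32 and clamp to the slider bounds (illustrative bounds)
    return max(lo, min(hi, round(value / 32) * 32))

def on_upload(image_path, cur_h, cur_w):
    if not image_path:                             # cleared input: keep current slider values
        return gr.update(value=cur_h), gr.update(value=cur_w)
    w, h = Image.open(image_path).size
    scale = 512 / min(w, h)                        # fix the shorter side at 512
    return gr.update(value=snap_to_32(h * scale)), gr.update(value=snap_to_32(w * scale))

with gr.Blocks() as demo:
    image_in = gr.Image(type="filepath", label="Input Image")
    height_sl = gr.Slider(256, 1280, value=512, step=32, label="Height")
    width_sl = gr.Slider(256, 1280, value=704, step=32, label="Width")
    image_in.upload(on_upload, inputs=[image_in, height_sl, width_sl],
                    outputs=[height_sl, width_sl])

demo.launch()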