multimodalart HF Staff committed on
Commit d8bb216 · verified · 1 Parent(s): 3947f33

Update app.py

Files changed (1)
  1. app.py +320 -407
app.py CHANGED
@@ -1,382 +1,302 @@
1
  import gradio as gr
2
- import spaces
3
  import torch
4
  import numpy as np
 
5
  import os
6
  import yaml
7
- import random
8
- from PIL import Image
9
- import imageio # For export_to_video and reading video frames
10
  from pathlib import Path
 
 
 
11
  from huggingface_hub import hf_hub_download
 
12
 
13
- # --- LTX-Video Imports (from your provided codebase) ---
14
- from ltx_video.pipelines.pipeline_ltx_video import (
15
- ConditioningItem,
16
- LTXVideoPipeline,
17
- LTXMultiScalePipeline,
18
- )
19
- from ltx_video.models.autoencoders.vae_encode import vae_decode, vae_encode, un_normalize_latents, normalize_latents
20
  from inference import (
21
  create_ltx_video_pipeline,
22
  create_latent_upsampler,
23
- load_image_to_tensor_with_resize_and_crop, # Re-using for image conditioning
24
- load_media_file, # Re-using for video conditioning
25
- get_device,
26
  seed_everething,
 
27
  calculate_padding,
 
28
  )
 
29
  from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy
30
- from ltx_video.models.autoencoders.latent_upsampler import LatentUpsampler
31
- # --- End LTX-Video Imports ---
32
-
33
- # --- Diffusers/Original utils (keeping export_to_video for convenience if it works) ---
34
- from diffusers.utils import export_to_video # Keep if it works with PIL list
35
- # ---
36
-
37
- # --- Global Configuration & Model Loading ---
38
- DEVICE = get_device()
39
- MODEL_DIR = "downloaded_models" # Directory to store downloaded models
40
- Path(MODEL_DIR).mkdir(parents=True, exist_ok=True)
41
 
42
- # Load YAML configuration
43
- YAML_CONFIG_PATH = "configs/ltxv-13b-0.9.7-distilled.yaml" # Place this file in the same directory
44
- with open(YAML_CONFIG_PATH, "r") as f:
45
- PIPELINE_CONFIG_YAML = yaml.safe_load(f)
46
 
47
- # Download and prepare model paths from YAML
48
- LTXV_MODEL_FILENAME = PIPELINE_CONFIG_YAML["checkpoint_path"]
49
  SPATIAL_UPSCALER_FILENAME = PIPELINE_CONFIG_YAML["spatial_upscaler_model_path"]
50
- TEXT_ENCODER_PATH = PIPELINE_CONFIG_YAML["text_encoder_model_name_or_path"] # This is usually a repo name
51
-
52
- try:
53
- # Main LTX-Video model
54
- if not os.path.isfile(os.path.join(MODEL_DIR, LTXV_MODEL_FILENAME)):
55
- print(f"Downloading {LTXV_MODEL_FILENAME}...")
56
- ltxv_checkpoint_path = hf_hub_download(
57
- repo_id="LTX-Colab/LTX-Video-Preview", # Assuming the distilled model is also here or adjust repo_id
58
- filename=LTXV_MODEL_FILENAME,
59
- local_dir=MODEL_DIR,
60
- repo_type="model",
61
- )
62
- else:
63
- ltxv_checkpoint_path = os.path.join(MODEL_DIR, LTXV_MODEL_FILENAME)
64
-
65
- # Spatial Upsampler model
66
- if not os.path.isfile(os.path.join(MODEL_DIR, SPATIAL_UPSCALER_FILENAME)):
67
- print(f"Downloading {SPATIAL_UPSCALER_FILENAME}...")
68
- spatial_upsampler_path = hf_hub_download(
69
- repo_id="Lightricks/LTX-Video",
70
- filename=SPATIAL_UPSCALER_FILENAME,
71
- local_dir=MODEL_DIR,
72
- repo_type="model",
73
- )
74
- else:
75
- spatial_upsampler_path = os.path.join(MODEL_DIR, SPATIAL_UPSCALER_FILENAME)
76
- except Exception as e:
77
- print(f"Error downloading models: {e}")
78
- print("Please ensure model files are correctly specified and accessible.")
79
- # Depending on severity, you might want to exit or disable GPU features
80
- # For now, we'll let it proceed and potentially fail later if paths are invalid.
81
- ltxv_checkpoint_path = LTXV_MODEL_FILENAME # Fallback to filename if download fails
82
- spatial_upsampler_path = SPATIAL_UPSCALER_FILENAME
83
-
84
-
85
- print(f"Using LTX-Video checkpoint: {ltxv_checkpoint_path}")
86
- print(f"Using Spatial Upsampler: {spatial_upsampler_path}")
87
- print(f"Using Text Encoder: {TEXT_ENCODER_PATH}")
88
-
89
- # Create LTX-Video pipeline
90
- pipe = create_ltx_video_pipeline(
91
- ckpt_path=ltxv_checkpoint_path,
92
- precision=PIPELINE_CONFIG_YAML["precision"],
93
- text_encoder_model_name_or_path=TEXT_ENCODER_PATH,
94
- sampler=PIPELINE_CONFIG_YAML["sampler"], # "from_checkpoint" or specific sampler
95
- device=DEVICE,
96
- enhance_prompt=False, # Assuming Gradio controls this, or set based on YAML later
97
- )#.to(torch.bfloat16)
98
-
99
- # Create Latent Upsampler
100
- latent_upsampler = create_latent_upsampler(
101
- latent_upsampler_model_path=spatial_upsampler_path,
102
- device=DEVICE
103
  )
104
- #latent_upsampler = latent_upsampler.to(torch.bfloat16)
 
105
 
106
-
107
- # Multi-scale pipeline (wrapper)
108
- multi_scale_pipe = LTXMultiScalePipeline(
109
- video_pipeline=pipe,
110
- latent_upsampler=latent_upsampler
111
  )
112
- # --- End Global Configuration & Model Loading ---
113
-
114
-
115
- MAX_SEED = np.iinfo(np.int32).max
116
- MAX_IMAGE_SIZE = 2048 # Not strictly used here, but good to keep in mind
117
-
118
-
119
- def round_to_nearest_resolution_acceptable_by_vae(height, width, vae_scale_factor):
120
- # print("before rounding",height, width)
121
- height = height - (height % vae_scale_factor)
122
- width = width - (width % vae_scale_factor)
123
- # print("after rounding",height, width)
124
- return height, width
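A quick illustration of this helper (hedged: assuming a spatial factor of 32, the granularity the rest of the app pads to):
# round_to_nearest_resolution_acceptable_by_vae(500, 720, 32) -> (480, 704)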
125
-
126
- @spaces.GPU
127
- def generate(prompt,
128
- negative_prompt,
129
- image_path, # Gradio gives filepath for Image component
130
- video_path, # Gradio gives filepath for Video component
131
- height,
132
- width,
133
- mode,
134
- steps, # This will map to num_inference_steps for the first pass
135
- num_frames,
136
- frames_to_use,
137
- seed,
138
- randomize_seed,
139
- guidance_scale,
140
- improve_texture=False, progress=gr.Progress(track_tqdm=True)):
141
 
142
- if randomize_seed:
143
- seed = random.randint(0, MAX_SEED)
144
- seed_everething(seed)
145
-
146
- generator = torch.Generator(device=DEVICE).manual_seed(seed)
 
 
147
 
148
- # --- Prepare conditioning items ---
149
- conditioning_items_list = []
150
- input_media_for_vid2vid = None # For the specific vid2vid mode in LTX pipeline
151
 
152
- # Pad target dimensions
153
- # VAE scale factor is typically 8 for spatial, but LTX might have its own specific factor.
154
- # CausalVideoAutoencoder has spatial_downscale_factor and temporal_downscale_factor
155
- vae_spatial_scale_factor = pipe.vae.spatial_downscale_factor
156
- vae_temporal_scale_factor = pipe.vae.temporal_downscale_factor
 
157
 
158
- # Ensure target height/width are multiples of VAE spatial scale factor
159
- height_padded_target = ((height - 1) // vae_spatial_scale_factor + 1) * vae_spatial_scale_factor
160
- width_padded_target = ((width - 1) // vae_spatial_scale_factor + 1) * vae_spatial_scale_factor
161
-
162
- # Ensure num_frames is multiple of VAE temporal scale factor + 1 (for causal VAE)
163
- # (num_frames - 1) should be multiple of temporal_scale_factor for non-causal parts
164
- # For CausalVAE, it's often (N * temporal_factor) + 1 frames.
165
- # The inference script uses: num_frames_padded = ((num_frames - 2) // 8 + 1) * 8 + 1
166
- # Assuming 8 is the temporal scale factor here for simplicity, adjust if different
167
- num_frames_padded_target = ((num_frames - 2) // vae_temporal_scale_factor + 1) * vae_temporal_scale_factor + 1
168
-
169
-
170
- padding_target = calculate_padding(height, width, height_padded_target, width_padded_target)
171
-
172
-
173
- if mode == "video-to-video" and video_path:
174
- # LTX pipeline's vid2vid uses `media_items` argument for the full video to transform
175
- # and `conditioning_items` for specific keyframes if needed.
176
- # Here, the Gradio's "video-to-video" seems to imply transforming the input video.
177
- input_media_for_vid2vid = load_media_file(
178
- media_path=video_path,
179
- height=height, # Original height before padding for loading
180
- width=width, # Original width
181
- max_frames=min(num_frames_padded_target, frames_to_use if frames_to_use > 0 else num_frames_padded_target),
182
- padding=padding_target, # Padding to make it compatible with VAE of target size
183
- )
184
- # If we also want to strongly condition on the first frame(s) of this video:
185
- conditioning_media = load_media_file(
186
- media_path=video_path,
187
- height=height, width=width,
188
- max_frames=min(frames_to_use if frames_to_use > 0 else 1, num_frames_padded_target), # Use specified frames or just the first
189
- padding=padding_target,
190
- just_crop=True # Crop to aspect ratio, then resize
191
- )
192
- conditioning_items_list.append(ConditioningItem(media_item=conditioning_media, media_frame_number=0, conditioning_strength=1.0))
193
-
194
- elif mode == "image-to-video" and image_path:
195
- conditioning_media = load_image_to_tensor_with_resize_and_crop(
196
- image_input=image_path,
197
- target_height=height, # Original height
198
- target_width=width # Original width
199
- )
200
- # Apply padding to the loaded tensor
201
- conditioning_media = torch.nn.functional.pad(conditioning_media, padding_target)
202
- conditioning_items_list.append(ConditioningItem(media_item=conditioning_media, media_frame_number=0, conditioning_strength=1.0))
203
 
204
- # else mode is "text-to-video", no explicit conditioning items unless defined elsewhere
205
-
206
- # --- Get pipeline parameters from YAML ---
207
- first_pass_config = PIPELINE_CONFIG_YAML.get("first_pass", {})
208
- second_pass_config = PIPELINE_CONFIG_YAML.get("second_pass", {})
209
- downscale_factor = PIPELINE_CONFIG_YAML.get("downscale_factor", 2/3)
210
-
211
- # Override steps from Gradio if provided, for the first pass
212
- if steps:
213
- # The YAML timesteps are specific, so overriding num_inference_steps might not be what we want
214
- # If YAML has `timesteps`, `num_inference_steps` is ignored by LTXVideoPipeline.
215
- # If YAML does not have `timesteps`, then `num_inference_steps` from Gradio will be used for the first pass.
216
- first_pass_config["num_inference_steps"] = steps
217
- # For distilled model, the second pass steps are usually very few, defined by its timesteps.
218
- # We won't override second_pass_config["num_inference_steps"] from the Gradio `steps`
219
- # as it's meant for the primary generation.
220
-
221
- # Determine initial generation dimensions (downscaled)
222
- # These are the dimensions for the *first pass* of the multi-scale pipeline
223
- initial_gen_height = int(height_padded_target * downscale_factor)
224
- initial_gen_width = int(width_padded_target * downscale_factor)
225
-
226
- initial_gen_height, initial_gen_width = round_to_nearest_resolution_acceptable_by_vae(
227
- initial_gen_height, initial_gen_width, vae_spatial_scale_factor
228
- )
229
 
230
- shared_pipeline_args = {
 
 
231
  "prompt": prompt,
232
  "negative_prompt": negative_prompt,
233
- "num_frames": num_frames_padded_target, # Always generate padded num_frames
234
- "frame_rate": 30, # Example, or get from UI if available
235
- "guidance_scale": guidance_scale,
236
- "generator": generator,
237
- "conditioning_items": conditioning_items_list if conditioning_items_list else None,
238
- "skip_layer_strategy": SkipLayerStrategy.AttentionValues, # Default or from YAML
239
- "offload_to_cpu": False, # Managed by global DEVICE
240
- "is_video": True,
241
- "vae_per_channel_normalize": True, # Common default
242
- "mixed_precision": (PIPELINE_CONFIG_YAML["precision"] == "bfloat16"),
243
- "enhance_prompt": False, # Controlled by Gradio app logic if needed for full LTX script
244
- "image_cond_noise_scale": 0.025, # from YAML decode_noise_scale, or make it a param
245
- "media_items": input_media_for_vid2vid if mode == "video-to-video" else None,
246
- # "decode_timestep" and "decode_noise_scale" are part of first_pass/second_pass or direct call
 
 
 
247
  }
248
 
249
- # --- Generation ---
250
- if improve_texture:
251
- print("Using LTXMultiScalePipeline for generation...")
252
- # Ensure first_pass_config and second_pass_config have necessary overrides
253
- # The 'steps' from Gradio applies to the first pass's num_inference_steps if timesteps not set
254
- if "timesteps" not in first_pass_config:
255
- first_pass_config["num_inference_steps"] = steps
256
 
257
- first_pass_config.setdefault("decode_timestep", PIPELINE_CONFIG_YAML.get("decode_timestep", 0.05))
258
- first_pass_config.setdefault("decode_noise_scale", PIPELINE_CONFIG_YAML.get("decode_noise_scale", 0.025))
259
- second_pass_config.setdefault("decode_timestep", PIPELINE_CONFIG_YAML.get("decode_timestep", 0.05))
260
- second_pass_config.setdefault("decode_noise_scale", PIPELINE_CONFIG_YAML.get("decode_noise_scale", 0.025))
261
-
262
- # The multi_scale_pipe's __call__ expects width and height for the *initial* (downscaled) generation
263
- result_frames_tensor = multi_scale_pipe(
264
- **shared_pipeline_args,
265
- width=initial_gen_width,
266
- height=initial_gen_height,
267
- downscale_factor=downscale_factor, # This might be used internally by multi_scale_pipe
268
- first_pass=first_pass_config,
269
- second_pass=second_pass_config,
270
- output_type="pt" # Get tensor for further processing
271
- ).images
272
 
273
- # LTXMultiScalePipeline should return images at 2x the initial_gen_width/height
274
- # So, result_frames_tensor is at initial_gen_width*2, initial_gen_height*2
275
-
276
  else:
277
- print("Using LTXVideoPipeline (first pass) + Manual Upsample + Decode...")
278
- # 1. First pass generation at downscaled resolution
279
- if "timesteps" not in first_pass_config:
280
- first_pass_config["num_inference_steps"] = steps
281
-
282
- first_pass_args = {
283
- **shared_pipeline_args,
284
- **first_pass_config,
285
- "width": initial_gen_width,
286
- "height": initial_gen_height,
287
- "output_type": "latent"
288
- }
289
- latents = pipe(**first_pass_args).images # .images here is actually latents
290
- print("First pass done!")
291
- # 2. Upsample latents manually
292
- # Need to handle normalization around latent upsampler if it expects unnormalized latents
293
- latents_unnorm = un_normalize_latents(latents, pipe.vae, vae_per_channel_normalize=True)
294
- upsampled_latents_unnorm = latent_upsampler(latents_unnorm)
295
- upsampled_latents = normalize_latents(upsampled_latents_unnorm, pipe.vae, vae_per_channel_normalize=True)
296
 
297
- # 3. Decode upsampled latents
298
- # The upsampler typically doubles the spatial dimensions
299
- upscaled_height_for_decode = initial_gen_height * 2
300
- upscaled_width_for_decode = initial_gen_width * 2
301
 
302
- # Prepare target_shape for VAE decoder
303
- # batch_size, channels, num_frames, height, width
304
- # Latents are (B, C, F_latent, H_latent, W_latent)
305
- # Target shape for vae.decode is pixel space
306
- # num_video_frames_final = upsampled_latents.shape[2] * pipe.vae.temporal_downscale_factor
307
- # if causal, it might be (F_latent - 1) * factor + 1
308
- num_video_frames_final = (upsampled_latents.shape[2] -1) * pipe.vae.temporal_downscale_factor + 1
309
-
310
-
311
- decode_kwargs = {
312
- "target_shape": (
313
- upsampled_latents.shape[0], # batch
314
- 3, # out channels
315
- num_video_frames_final,
316
- upscaled_height_for_decode,
317
- upscaled_width_for_decode
318
- )
319
- }
320
- if pipe.vae.decoder.timestep_conditioning:
321
- decode_kwargs["timestep"] = torch.tensor([PIPELINE_CONFIG_YAML.get("decode_timestep", 0.05)] * upsampled_latents.shape[0]).to(DEVICE)
322
- # Add noise for decode if specified, similar to LTXVideoPipeline's call
323
- noise = torch.randn_like(upsampled_latents)
324
- decode_noise_val = PIPELINE_CONFIG_YAML.get("decode_noise_scale", 0.025)
325
- upsampled_latents = upsampled_latents * (1 - decode_noise_val) + noise * decode_noise_val
326
-
327
- print("before vae decoding")
328
- result_frames_tensor = pipe.vae.decode(upsampled_latents, **decode_kwargs).sample
329
- print("after vae decoding?")
330
- # result_frames_tensor shape: (B, C, F_video, H_video, W_video)
331
-
332
- # --- Post-processing: Cropping and Converting to PIL ---
333
- # Crop to original num_frames (before padding)
334
- result_frames_tensor = result_frames_tensor[:, :, :num_frames, :, :]
335
-
336
- # Unpad to target height and width
337
- _, _, _, current_h, current_w = result_frames_tensor.shape
338
-
339
- # Calculate crop needed if current dimensions are larger than padded_target
340
- # This happens if multi_scale_pipe output is larger than height_padded_target
341
- crop_y_start = (current_h - height_padded_target) // 2
342
- crop_x_start = (current_w - width_padded_target) // 2
343
-
344
- result_frames_tensor = result_frames_tensor[
345
- :, :, :,
346
- crop_y_start : crop_y_start + height_padded_target,
347
- crop_x_start : crop_x_start + width_padded_target
348
- ]
349
 
350
- # Now remove the padding added for VAE compatibility
351
- pad_left, pad_right, pad_top, pad_bottom = padding_target
352
- unpad_bottom = -pad_bottom if pad_bottom > 0 else result_frames_tensor.shape[3]
353
- unpad_right = -pad_right if pad_right > 0 else result_frames_tensor.shape[4]
354
-
355
- result_frames_tensor = result_frames_tensor[
356
- :, :, :,
357
- pad_top : unpad_bottom,
358
- pad_left : unpad_right
359
  ]
360
 
 
 
 
 
361
 
362
- # Convert tensor to list of PIL Images
363
- video_pil_list = []
364
- # result_frames_tensor shape: (B, C, F, H, W)
365
- # We expect B=1 from typical generation
366
- video_single_batch = result_frames_tensor[0] # Shape: (C, F, H, W)
367
- video_single_batch = (video_single_batch / 2 + 0.5).clamp(0, 1) # Normalize to [0,1]
368
- video_single_batch = video_single_batch.permute(1, 2, 3, 0).cpu().float().numpy() # F, H, W, C
369
 
370
- for frame_idx in range(video_single_batch.shape[0]):
371
- frame_np = (video_single_batch[frame_idx] * 255).astype(np.uint8)
372
- video_pil_list.append(Image.fromarray(frame_np))
373
-
374
- # Save video
375
- output_video_path = "output.mp4" # Gradio handles temp files
376
- export_to_video(video_pil_list, output_video_path, fps=24) # Assuming fps from original script
377
  return output_video_path
378
 
379
-
380
  css="""
381
  #col-container {
382
  margin: 0 auto;
@@ -384,89 +304,82 @@ css="""
384
  }
385
  """
386
 
387
- with gr.Blocks(css=css, theme=gr.themes.Ocean()) as demo:
388
  gr.Markdown("# LTX Video 0.9.7 Distilled (using LTX-Video lib)")
 
389
  with gr.Row():
390
  with gr.Column():
391
  with gr.Group():
392
  with gr.Tab("text-to-video") as text_tab:
393
- image_n = gr.Image(label="", visible=False, value=None) # Ensure None for path
394
- video_n = gr.Video(label="", visible=False, value=None) # Ensure None for path
395
- t2v_prompt = gr.Textbox(label="prompt", value="A majestic dragon flying over a medieval castle")
396
- t2v_button = gr.Button("Generate Text-to-Video")
 
397
  with gr.Tab("image-to-video") as image_tab:
398
- video_i = gr.Video(label="", visible=False, value=None)
399
- image_i2v = gr.Image(label="input image", type="filepath")
400
- i2v_prompt = gr.Textbox(label="prompt", value="The creature from the image starts to move")
401
- i2v_button = gr.Button("Generate Image-to-Video")
402
  with gr.Tab("video-to-video") as video_tab:
403
- image_v = gr.Image(label="", visible=False, value=None)
404
- video_v2v = gr.Video(label="input video")
405
- frames_to_use = gr.Number(label="num frames to use",info="first # of frames to use from the input video for conditioning/transformation", value=9)
406
- v2v_prompt = gr.Textbox(label="prompt", value="Change the style to cinematic anime")
407
- v2v_button = gr.Button("Generate Video-to-Video")
408
 
409
- improve_texture = gr.Checkbox(label="improve texture (multi-scale)", value=True, info="Uses a two-pass generation for better quality, but is slower.")
410
 
411
  with gr.Column():
412
- output = gr.Video(interactive=False)
 
413
 
414
  with gr.Accordion("Advanced settings", open=False):
415
- negative_prompt_input = gr.Textbox(label="negative prompt", value="worst quality, inconsistent motion, blurry, jittery, distorted")
416
  with gr.Row():
417
- seed_input = gr.Number(label="seed", value=42, precision=0)
418
- randomize_seed_input = gr.Checkbox(label="randomize seed", value=False)
419
  with gr.Row():
420
- guidance_scale_input = gr.Slider(label="guidance scale", minimum=0, maximum=10, value=1.0, step=0.1, info="For distilled models, CFG is often 1.0 (disabled) or very low.") # Distilled model might not need high CFG
421
- steps_input = gr.Slider(label="Steps (for first pass if multi-scale)", minimum=1, maximum=30, value=PIPELINE_CONFIG_YAML.get("first_pass", {}).get("timesteps", [1]*8).__len__(), step=1, info="Number of inference steps. If YAML defines timesteps, this is ignored for that pass.") # Default to length of first_pass timesteps
422
- num_frames_input = gr.Slider(label="# frames", minimum=9, maximum=121, value=25, step=8, info="Should be N*8+1, e.g., 9, 17, 25...") # Adjusted for LTX structure
 
 
423
  with gr.Row():
424
- height_input = gr.Slider(label="height", value=512, step=8, minimum=256, maximum=MAX_IMAGE_SIZE) # Step by VAE factor
425
- width_input = gr.Slider(label="width", value=704, step=8, minimum=256, maximum=MAX_IMAGE_SIZE) # Step by VAE factor
426
-
427
- t2v_button.click(fn=generate,
428
- inputs=[t2v_prompt,
429
- negative_prompt_input,
430
- image_n, # Pass None for image
431
- video_n, # Pass None for video
432
- height_input,
433
- width_input,
434
- gr.State("text-to-video"),
435
- steps_input,
436
- num_frames_input,
437
- gr.State(0), # frames_to_use not relevant for t2v
438
- seed_input,
439
- randomize_seed_input, guidance_scale_input, improve_texture],
440
- outputs=[output])
441
-
442
- i2v_button.click(fn=generate,
443
- inputs=[i2v_prompt,
444
- negative_prompt_input,
445
- image_i2v,
446
- video_i, # Pass None for video
447
- height_input,
448
- width_input,
449
- gr.State("image-to-video"),
450
- steps_input,
451
- num_frames_input,
452
- gr.State(0), # frames_to_use not relevant for i2v initial frame
453
- seed_input,
454
- randomize_seed_input, guidance_scale_input, improve_texture],
455
- outputs=[output])
456
-
457
- v2v_button.click(fn=generate,
458
- inputs=[v2v_prompt,
459
- negative_prompt_input,
460
- image_v, # Pass None for image
461
- video_v2v,
462
- height_input,
463
- width_input,
464
- gr.State("video-to-video"),
465
- steps_input,
466
- num_frames_input,
467
- frames_to_use,
468
- seed_input,
469
- randomize_seed_input, guidance_scale_input, improve_texture],
470
- outputs=[output])
471
-
472
- demo.launch()
 
1
  import gradio as gr
 
2
  import torch
3
  import numpy as np
4
+ import random
5
  import os
6
  import yaml
 
 
 
7
  from pathlib import Path
8
+ import imageio
9
+ import tempfile
10
+ from PIL import Image
11
  from huggingface_hub import hf_hub_download
12
+ import shutil
13
 
14
+ # --- Import necessary classes from the provided files ---
15
  from inference import (
16
  create_ltx_video_pipeline,
17
  create_latent_upsampler,
18
+ load_image_to_tensor_with_resize_and_crop,
 
 
19
  seed_everething,
20
+ get_device,
21
  calculate_padding,
22
+ load_media_file
23
  )
24
+ from ltx_video.pipelines.pipeline_ltx_video import ConditioningItem, LTXMultiScalePipeline, LTXVideoPipeline
25
  from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy
26
 
27
+ # --- Global constants from user's request and YAML ---
28
+ YAML_CONFIG_STRING = """
29
+ pipeline_type: multi-scale
30
+ checkpoint_path: "ltxv-13b-0.9.7-distilled.safetensors" # This will be replaced by the rc3 version
31
+ downscale_factor: 0.6666666
32
+ spatial_upscaler_model_path: "ltxv-spatial-upscaler-0.9.7.safetensors"
33
+ stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block"
34
+ decode_timestep: 0.05
35
+ decode_noise_scale: 0.025
36
+ text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS"
37
+ precision: "bfloat16"
38
+ sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint"
39
+ prompt_enhancement_words_threshold: 120
40
+ prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
41
+ prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
42
+ stochastic_sampling: false
43
+
44
+ first_pass:
45
+ timesteps: [1.0000, 0.9937, 0.9875, 0.9812, 0.9750, 0.9094, 0.7250]
46
+ guidance_scale: 1
47
+ stg_scale: 0
48
+ rescaling_scale: 1
49
+ skip_block_list: [42]
50
+
51
+ second_pass:
52
+ timesteps: [0.9094, 0.7250, 0.4219]
53
+ guidance_scale: 1
54
+ stg_scale: 0
55
+ rescaling_scale: 1
56
+ skip_block_list: [42]
57
+ """
58
+ PIPELINE_CONFIG_YAML = yaml.safe_load(YAML_CONFIG_STRING)
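Illustrative sketch of probing the parsed config; the values follow directly from the YAML string above.
# Illustrative only: counts follow from the embedded YAML above.
assert PIPELINE_CONFIG_YAML["precision"] == "bfloat16"
print(len(PIPELINE_CONFIG_YAML["first_pass"]["timesteps"]))   # 7 first-pass steps
print(len(PIPELINE_CONFIG_YAML["second_pass"]["timesteps"]))  # 3 second-pass steps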
59
+
60
+ # Model specific paths (to be downloaded)
61
+ DISTILLED_MODEL_REPO = "LTX-Colab/LTX-Video-Preview"
62
+ DISTILLED_MODEL_FILENAME = "ltxv-13b-0.9.7-distilled-rc3.safetensors"
63
+
64
+ UPSCALER_REPO = "Lightricks/LTX-Video"
65
+ # SPATIAL_UPSCALER_FILENAME will be taken from PIPELINE_CONFIG_YAML after it's loaded
66
+
67
+ MAX_IMAGE_SIZE = PIPELINE_CONFIG_YAML.get("max_resolution", 1280) # Max width/height from UI
68
+ MAX_NUM_FRAMES = 257 # From inference.py
69
+
70
+ # --- Global variables for loaded models ---
71
+ pipeline_instance = None
72
+ latent_upsampler_instance = None
73
+ current_device = get_device()
74
+ models_dir = "downloaded_models_gradio" # Use a distinct name
75
+ Path(models_dir).mkdir(parents=True, exist_ok=True)
76
+
77
+ # Download models and update config paths
78
+ print(f"Using device: {current_device}")
79
+ print("Downloading models...")
80
+
81
+ distilled_model_actual_path = hf_hub_download(
82
+ repo_id=DISTILLED_MODEL_REPO,
83
+ filename=DISTILLED_MODEL_FILENAME,
84
+ local_dir=models_dir,
85
+ local_dir_use_symlinks=False
86
+ )
87
+ PIPELINE_CONFIG_YAML["checkpoint_path"] = distilled_model_actual_path
88
+ print(f"Distilled model downloaded to: {distilled_model_actual_path}")
89
 
 
 
90
  SPATIAL_UPSCALER_FILENAME = PIPELINE_CONFIG_YAML["spatial_upscaler_model_path"]
91
+ spatial_upscaler_actual_path = hf_hub_download(
92
+ repo_id=UPSCALER_REPO,
93
+ filename=SPATIAL_UPSCALER_FILENAME,
94
+ local_dir=models_dir,
95
+ local_dir_use_symlinks=False
96
  )
97
+ PIPELINE_CONFIG_YAML["spatial_upscaler_model_path"] = spatial_upscaler_actual_path
98
+ print(f"Spatial upscaler model downloaded to: {spatial_upscaler_actual_path}")
99
 
100
+ # Load pipelines
101
+ print("Creating LTX Video pipeline...")
102
+ pipeline_instance = create_ltx_video_pipeline(
103
+ ckpt_path=PIPELINE_CONFIG_YAML["checkpoint_path"],
104
+ precision=PIPELINE_CONFIG_YAML["precision"],
105
+ text_encoder_model_name_or_path=PIPELINE_CONFIG_YAML["text_encoder_model_name_or_path"],
106
+ sampler=PIPELINE_CONFIG_YAML["sampler"],
107
+ device=current_device,
108
+ enhance_prompt=False, # Prompt enhancement handled by UI choice / Gradio logic if desired
109
+ prompt_enhancer_image_caption_model_name_or_path=PIPELINE_CONFIG_YAML["prompt_enhancer_image_caption_model_name_or_path"],
110
+ prompt_enhancer_llm_model_name_or_path=PIPELINE_CONFIG_YAML["prompt_enhancer_llm_model_name_or_path"],
111
  )
112
+ print("LTX Video pipeline created.")
113
 
114
+ if PIPELINE_CONFIG_YAML.get("spatial_upscaler_model_path"):
115
+ print("Creating latent upsampler...")
116
+ latent_upsampler_instance = create_latent_upsampler(
117
+ PIPELINE_CONFIG_YAML["spatial_upscaler_model_path"],
118
+ device=current_device
119
+ )
120
+ print("Latent upsampler created.")
121
 
 
 
 
122
 
123
+ def generate(prompt, negative_prompt, input_image_filepath, input_video_filepath,
124
+ height_ui, width_ui, mode,
125
+ ui_steps, num_frames_ui,
126
+ ui_frames_to_use,
127
+ seed_ui, randomize_seed, ui_guidance_scale, improve_texture_flag,
128
+ progress=gr.Progress(track_tqdm=True)):
129
 
130
+ if randomize_seed:
131
+ seed_ui = random.randint(0, 2**32 - 1)
132
+ seed_everething(int(seed_ui))
133
 
134
+ actual_height = int(height_ui)
135
+ actual_width = int(width_ui)
136
+ actual_num_frames = int(num_frames_ui)
137
+
138
+ # Padded dimensions for pipeline
139
+ height_padded = ((actual_height - 1) // 32 + 1) * 32
140
+ width_padded = ((actual_width - 1) // 32 + 1) * 32
141
+ num_frames_padded = ((actual_num_frames - 2) // 8 + 1) * 8 + 1
142
 
143
+ padding_values = calculate_padding(actual_height, actual_width, height_padded, width_padded)
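Worked example of the rounding and padding above (the (left, right, top, bottom) ordering is the one unpacked from padding_values further down): the 512x704, 25-frame defaults pass through unchanged, while odd sizes are rounded up.
# Illustrative arithmetic for the formulas above:
h, w, f = 480, 721, 24
((h - 1) // 32 + 1) * 32      # -> 480 (already a multiple of 32)
((w - 1) // 32 + 1) * 32      # -> 736
((f - 2) // 8 + 1) * 8 + 1    # -> 25 (N*8 + 1)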
144
+
145
+ call_kwargs = {
146
  "prompt": prompt,
147
  "negative_prompt": negative_prompt,
148
+ "height": height_padded, # Use padded for pipeline
149
+ "width": width_padded, # Use padded for pipeline
150
+ "num_frames": num_frames_padded, # Use padded for pipeline
151
+ "frame_rate": 30,
152
+ "generator": torch.Generator(device=current_device).manual_seed(int(seed_ui)),
153
+ "output_type": "pt",
154
+ "conditioning_items": None,
155
+ "media_items": None,
156
+ "decode_timestep": PIPELINE_CONFIG_YAML["decode_timestep"],
157
+ "decode_noise_scale": PIPELINE_CONFIG_YAML["decode_noise_scale"],
158
+ "stochastic_sampling": PIPELINE_CONFIG_YAML["stochastic_sampling"],
159
+ "image_cond_noise_scale": 0.15, # from inference.py defaults
160
+ "is_video": True, # Assume video output
161
+ "vae_per_channel_normalize": True, # from inference.py defaults
162
+ "mixed_precision": (PIPELINE_CONFIG_YAML["precision"] == "mixed_precision"),
163
+ "offload_to_cpu": False, # For Gradio, keep on device
164
+ "enhance_prompt": False, # Assuming no UI for this yet, stick to YAML or handle separately
165
  }
166
 
167
+ stg_mode_str = PIPELINE_CONFIG_YAML.get("stg_mode", "attention_values")
168
+ if stg_mode_str.lower() in ["stg_av", "attention_values"]:
169
+ call_kwargs["skip_layer_strategy"] = SkipLayerStrategy.AttentionValues
170
+ elif stg_mode_str.lower() in ["stg_as", "attention_skip"]:
171
+ call_kwargs["skip_layer_strategy"] = SkipLayerStrategy.AttentionSkip
172
+ elif stg_mode_str.lower() in ["stg_r", "residual"]:
173
+ call_kwargs["skip_layer_strategy"] = SkipLayerStrategy.Residual
174
+ elif stg_mode_str.lower() in ["stg_t", "transformer_block"]:
175
+ call_kwargs["skip_layer_strategy"] = SkipLayerStrategy.TransformerBlock
176
+ else:
177
+ raise ValueError(f"Invalid stg_mode: {stg_mode_str}")
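The dispatch above can equivalently be written as a lookup table; a minimal sketch using the same SkipLayerStrategy members (the _STG_MAP name is illustrative):
# Equivalent sketch of the stg_mode dispatch above (same members, same aliases).
_STG_MAP = {
    "stg_av": SkipLayerStrategy.AttentionValues, "attention_values": SkipLayerStrategy.AttentionValues,
    "stg_as": SkipLayerStrategy.AttentionSkip, "attention_skip": SkipLayerStrategy.AttentionSkip,
    "stg_r": SkipLayerStrategy.Residual, "residual": SkipLayerStrategy.Residual,
    "stg_t": SkipLayerStrategy.TransformerBlock, "transformer_block": SkipLayerStrategy.TransformerBlock,
}
# call_kwargs["skip_layer_strategy"] = _STG_MAP[stg_mode_str.lower()]  # KeyError plays the role of the ValueError above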
178
+
179
+ if mode == "image-to-video" and input_image_filepath:
180
+ try:
181
+ # Ensure the input image is loaded with original H/W for correct aspect ratio handling by the function
182
+ media_tensor = load_image_to_tensor_with_resize_and_crop(
183
+ input_image_filepath, actual_height, actual_width
184
+ )
185
+ media_tensor = torch.nn.functional.pad(media_tensor, padding_values)
186
+ call_kwargs["conditioning_items"] = [ConditioningItem(media_tensor.to(current_device), 0, 1.0)]
187
+ except Exception as e:
188
+ print(f"Error loading image {input_image_filepath}: {e}")
189
+ raise gr.Error(f"Could not load image: {e}")
190
+
191
+
192
+ elif mode == "video-to-video" and input_video_filepath:
193
+ try:
194
+ call_kwargs["media_items"] = load_media_file(
195
+ media_path=input_video_filepath,
196
+ height=actual_height,
197
+ width=actual_width,
198
+ max_frames=int(ui_frames_to_use),
199
+ padding=padding_values
200
+ ).to(current_device)
201
+ except Exception as e:
202
+ print(f"Error loading video {input_video_filepath}: {e}")
203
+ raise gr.Error(f"Could not load video: {e}")
204
+
205
+ # Multi-scale or single-scale pipeline call
206
+ if improve_texture_flag:
207
+ if not latent_upsampler_instance:
208
+ raise gr.Error("Spatial upscaler model not loaded, cannot use multi-scale.")
209
 
210
+ multi_scale_pipeline_obj = LTXMultiScalePipeline(pipeline_instance, latent_upsampler_instance)
211
 
212
+ # Prepare pass-specific arguments, overriding with UI inputs where appropriate
213
+ first_pass_args = PIPELINE_CONFIG_YAML.get("first_pass", {}).copy()
214
+ first_pass_args["guidance_scale"] = float(ui_guidance_scale)
215
+ if "timesteps" not in first_pass_args: # Only if YAML doesn't define timesteps
216
+ first_pass_args["num_inference_steps"] = int(ui_steps)
217
+
218
+ second_pass_args = PIPELINE_CONFIG_YAML.get("second_pass", {}).copy()
219
+ second_pass_args["guidance_scale"] = float(ui_guidance_scale)
220
+ # num_inference_steps for second pass is typically determined by its YAML timesteps
221
+
222
+ multi_scale_call_kwargs = call_kwargs.copy()
223
+ multi_scale_call_kwargs.update({
224
+ "downscale_factor": PIPELINE_CONFIG_YAML["downscale_factor"],
225
+ "first_pass": first_pass_args,
226
+ "second_pass": second_pass_args,
227
+ })
228
+
229
+ print(f"Calling multi-scale pipeline with effective height={actual_height}, width={actual_width}")
230
+ result_images_tensor = multi_scale_pipeline_obj(**multi_scale_call_kwargs).images
231
  else:
232
+ # Single pass call (using base pipeline)
233
+ single_pass_call_kwargs = call_kwargs.copy()
234
+ single_pass_call_kwargs["guidance_scale"] = float(ui_guidance_scale)
235
 
236
+ # For single pass, if YAML doesn't have top-level timesteps, use ui_steps
237
+ # The current YAML is multi-scale focused, so it lacks top-level step control.
238
+ # We'll assume for a base call, num_inference_steps is directly taken from UI.
239
+ single_pass_call_kwargs["num_inference_steps"] = int(ui_steps)
240
+ # Remove pass-specific args if they accidentally slipped in
241
+ single_pass_call_kwargs.pop("first_pass", None)
242
+ single_pass_call_kwargs.pop("second_pass", None)
243
+ single_pass_call_kwargs.pop("downscale_factor", None)
244
 
245
+ print(f"Calling base pipeline with height={height_padded}, width={width_padded}")
246
+ result_images_tensor = pipeline_instance(**single_pass_call_kwargs).images
247
+
248
+ # Crop to original requested dimensions (num_frames, height, width)
249
+ # Padding: (pad_left, pad_right, pad_top, pad_bottom)
250
+ pad_left, pad_right, pad_top, pad_bottom = padding_values
251
 
252
+ # Calculate slice indices, ensuring they don't go negative if padding was zero
253
+ slice_h_end = -pad_bottom if pad_bottom > 0 else None
254
+ slice_w_end = -pad_right if pad_right > 0 else None
255
+
256
+ result_images_tensor = result_images_tensor[
257
+ :, :, :actual_num_frames, pad_top:slice_h_end, pad_left:slice_w_end
 
 
 
258
  ]
259
 
260
+ # Convert tensor to video file
261
+ video_np = result_images_tensor[0].permute(1, 2, 3, 0).cpu().float().numpy()
262
+ video_np = np.clip(video_np * 0.5 + 0.5, 0, 1) # from [-1,1] to [0,1]
263
+ video_np = (video_np * 255).astype(np.uint8)
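A small sanity check that should hold at this point (assuming the decoder returned at least the requested number of frames): after the crop and permute, the array layout is (frames, height, width, 3).
# Illustrative check of the (F, H, W, C) layout produced above.
assert video_np.shape[0] == actual_num_frames and video_np.shape[-1] == 3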
264
 
265
+ temp_dir = tempfile.mkdtemp()
266
+ timestamp = random.randint(10000,99999) # Add timestamp to avoid caching issues
267
+ output_video_path = os.path.join(temp_dir, f"output_{timestamp}.mp4")
 
 
 
 
268
 
269
+ try:
270
+ with imageio.get_writer(output_video_path, fps=call_kwargs["frame_rate"], macro_block_size=1) as video_writer:
271
+ for frame_idx in range(video_np.shape[0]):
272
+ progress(frame_idx / video_np.shape[0], desc="Saving video")
273
+ video_writer.append_data(video_np[frame_idx])
274
+ except Exception as e:
275
+ print(f"Error saving video: {e}")
276
+ # Fallback to saving frame by frame if container issue
277
+ try:
278
+ with imageio.get_writer(output_video_path, fps=call_kwargs["frame_rate"], format='FFMPEG', codec='libx264', quality=8, macro_block_size=None) as video_writer:
279
+ for frame_idx in range(video_np.shape[0]):
280
+ progress(frame_idx / video_np.shape[0], desc="Saving video (fallback)")
281
+ video_writer.append_data(video_np[frame_idx])
282
+ except Exception as e2:
283
+ print(f"Fallback video saving error: {e2}")
284
+ raise gr.Error(f"Failed to save video: {e2}")
285
+
286
+
287
+ # Clean up temporary image/video files if they were created by Gradio
288
+ if isinstance(input_image_filepath, tempfile._TemporaryFileWrapper):
289
+ input_image_filepath.close()
290
+ if os.path.exists(input_image_filepath.name):
291
+ os.remove(input_image_filepath.name)
292
+ if isinstance(input_video_filepath, tempfile._TemporaryFileWrapper):
293
+ input_video_filepath.close()
294
+ if os.path.exists(input_video_filepath.name):
295
+ os.remove(input_video_filepath.name)
296
+
297
  return output_video_path
298
 
299
+ # --- Gradio UI Definition (from user) ---
300
  css="""
301
  #col-container {
302
  margin: 0 auto;
 
304
  }
305
  """
306
 
307
+ with gr.Blocks(css=css, theme=gr.themes.Glass()) as demo: # Changed theme for variety
308
  gr.Markdown("# LTX Video 0.9.7 Distilled (using LTX-Video lib)")
309
+ gr.Markdown("Generates a short video based on text prompt, image, or existing video.")
310
  with gr.Row():
311
  with gr.Column():
312
  with gr.Group():
313
  with gr.Tab("text-to-video") as text_tab:
314
+ # Hidden inputs for consistent generate() signature
315
+ image_n_hidden = gr.Textbox(label="image_n", visible=False, value=None)
316
+ video_n_hidden = gr.Textbox(label="video_n", visible=False, value=None)
317
+ t2v_prompt = gr.Textbox(label="Prompt", value="A majestic dragon flying over a medieval castle", lines=3)
318
+ t2v_button = gr.Button("Generate Text-to-Video", variant="primary")
319
  with gr.Tab("image-to-video") as image_tab:
320
+ video_i_hidden = gr.Textbox(label="video_i", visible=False, value=None)
321
+ image_i2v = gr.Image(label="Input Image", type="filepath", sources=["upload", "webcam"])
322
+ i2v_prompt = gr.Textbox(label="Prompt", value="The creature from the image starts to move", lines=3)
323
+ i2v_button = gr.Button("Generate Image-to-Video", variant="primary")
324
  with gr.Tab("video-to-video") as video_tab:
325
+ image_v_hidden = gr.Textbox(label="image_v", visible=False, value=None)
326
+ video_v2v = gr.Video(label="Input Video", sources=["upload", "webcam"])
327
+ frames_to_use = gr.Slider(label="Frames to use from input video", minimum=9, maximum=MAX_NUM_FRAMES, value=9, step=8, info="Number of initial frames to use for conditioning/transformation. Must be N*8+1.")
328
+ v2v_prompt = gr.Textbox(label="Prompt", value="Change the style to cinematic anime", lines=3)
329
+ v2v_button = gr.Button("Generate Video-to-Video", variant="primary")
330
 
331
+ improve_texture = gr.Checkbox(label="Improve Texture (multi-scale)", value=True, info="Uses a two-pass generation for better quality, but is slower. Recommended for final output.")
332
 
333
  with gr.Column():
334
+ output_video = gr.Video(label="Generated Video", interactive=False)
335
+ gr.Markdown("Note: Generation can take a few minutes depending on settings and hardware.")
336
 
337
  with gr.Accordion("Advanced settings", open=False):
338
+ negative_prompt_input = gr.Textbox(label="Negative Prompt", value="worst quality, inconsistent motion, blurry, jittery, distorted", lines=2)
339
  with gr.Row():
340
+ seed_input = gr.Number(label="Seed", value=42, precision=0, minimum=0, maximum=2**32-1)
341
+ randomize_seed_input = gr.Checkbox(label="Randomize Seed", value=False)
342
  with gr.Row():
343
+ # For distilled models, CFG is often 1.0 (disabled) or very low.
344
+ guidance_scale_input = gr.Slider(label="Guidance Scale (CFG)", minimum=1.0, maximum=10.0, value=PIPELINE_CONFIG_YAML.get("first_pass", {}).get("guidance_scale", 1.0), step=0.1, info="Controls how much the prompt influences the output. Higher values = stronger influence.")
345
+ # Default to length of first_pass timesteps, if available
346
+ default_steps = len(PIPELINE_CONFIG_YAML.get("first_pass", {}).get("timesteps", [1]*7)) # Fallback to 7 if not defined
347
+ steps_input = gr.Slider(label="Inference Steps (for first pass if multi-scale)", minimum=1, maximum=30, value=default_steps, step=1, info="Number of denoising steps. More steps can improve quality but increase time. If YAML defines 'timesteps' for a pass, this UI value is ignored for that pass.")
348
  with gr.Row():
349
+ num_frames_input = gr.Slider(label="Number of Frames to Generate", minimum=9, maximum=MAX_NUM_FRAMES, value=25, step=8, info="Total frames in the output video. Should be N*8+1 (e.g., 9, 17, 25...).")
350
+ with gr.Row():
351
+ height_input = gr.Slider(label="Height", value=512, step=32, minimum=256, maximum=MAX_IMAGE_SIZE, info="Must be divisible by 32.")
352
+ width_input = gr.Slider(label="Width", value=704, step=32, minimum=256, maximum=MAX_IMAGE_SIZE, info="Must be divisible by 32.")
353
+
354
+ # Define click actions
355
+ # Note: gr.State passes the current value of the component without creating a UI element for it.
356
+ # We use hidden Textbox inputs for image_n, video_n etc. and pass their `value` (which is None)
357
+ # to ensure the `generate` function always receives these arguments.
358
+
359
+ t2v_inputs = [t2v_prompt, negative_prompt_input, image_n_hidden, video_n_hidden,
360
+ height_input, width_input, gr.State("text-to-video"),
361
+ steps_input, num_frames_input, gr.State(0), # frames_to_use not relevant for t2v
362
+ seed_input, randomize_seed_input, guidance_scale_input, improve_texture]
363
+
364
+ i2v_inputs = [i2v_prompt, negative_prompt_input, image_i2v, video_i_hidden,
365
+ height_input, width_input, gr.State("image-to-video"),
366
+ steps_input, num_frames_input, gr.State(0), # frames_to_use not relevant for i2v initial frame
367
+ seed_input, randomize_seed_input, guidance_scale_input, improve_texture]
368
+
369
+ v2v_inputs = [v2v_prompt, negative_prompt_input, image_v_hidden, video_v2v,
370
+ height_input, width_input, gr.State("video-to-video"),
371
+ steps_input, num_frames_input, frames_to_use,
372
+ seed_input, randomize_seed_input, guidance_scale_input, improve_texture]
373
+
374
+ t2v_button.click(fn=generate, inputs=t2v_inputs, outputs=[output_video])
375
+ i2v_button.click(fn=generate, inputs=i2v_inputs, outputs=[output_video])
376
+ v2v_button.click(fn=generate, inputs=v2v_inputs, outputs=[output_video])
377
+
378
+ if __name__ == "__main__":
379
+ # Clean up old model directory if it exists from previous runs
380
+ if os.path.exists(models_dir) and os.path.isdir(models_dir):
381
+ print(f"Cleaning up old model directory: {models_dir}")
382
+ # shutil.rmtree(models_dir) # Optional: uncomment to force re-download on every run
383
+ Path(models_dir).mkdir(parents=True, exist_ok=True)
384
+
385
+ demo.queue().launch(debug=True, share=False)