Fabrice-TIERCELIN committed
Commit bddcb24 · verified · 1 parent: ee1b342

Upload 4 files

Files changed (4):
  1. README.md +10 -18
  2. app.py +486 -864
  3. inference.py +774 -0
  4. requirements.txt +15 -48
README.md CHANGED
@@ -1,21 +1,13 @@
  ---
- title: SUPIR Image Upscaler
  sdk: gradio
- emoji: 📷
- sdk_version: 4.38.1
  app_file: app.py
- license: mit
- colorFrom: blue
- colorTo: pink
- tags:
- - Upscaling
- - Restoring
- - Image-to-Image
- - Image-2-Image
- - Img-to-Img
- - Img-2-Img
- - language models
- - LLMs
- short_description: Restore blurred or small images with prompt
- suggested_hardware: zero-a10g
- ---

  ---
+ title: LTX Video Fast
+ emoji: 🎥
+ colorFrom: yellow
+ colorTo: pink
  sdk: gradio
+ sdk_version: 5.29.1
  app_file: app.py
+ pinned: false
+ short_description: ultra-fast video model, LTX 0.9.7 13B distilled
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference

app.py CHANGED
@@ -1,864 +1,486 @@
1
- import os
2
- import gradio as gr
3
- import argparse
4
- import numpy as np
5
- import torch
6
- import einops
7
- import copy
8
- import math
9
- import time
10
- import random
11
- import spaces
12
- import re
13
- import uuid
14
-
15
- from gradio_imageslider import ImageSlider
16
- from PIL import Image
17
- from SUPIR.util import HWC3, upscale_image, fix_resize, convert_dtype, create_SUPIR_model, load_QF_ckpt
18
- from huggingface_hub import hf_hub_download
19
- from pillow_heif import register_heif_opener
20
-
21
- register_heif_opener()
22
-
23
- max_64_bit_int = np.iinfo(np.int32).max
24
-
25
- hf_hub_download(repo_id="laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", filename="open_clip_pytorch_model.bin", local_dir="laion_CLIP-ViT-bigG-14-laion2B-39B-b160k")
26
- hf_hub_download(repo_id="camenduru/SUPIR", filename="sd_xl_base_1.0_0.9vae.safetensors", local_dir="yushan777_SUPIR")
27
- hf_hub_download(repo_id="camenduru/SUPIR", filename="SUPIR-v0F.ckpt", local_dir="yushan777_SUPIR")
28
- hf_hub_download(repo_id="camenduru/SUPIR", filename="SUPIR-v0Q.ckpt", local_dir="yushan777_SUPIR")
29
- hf_hub_download(repo_id="RunDiffusion/Juggernaut-XL-Lightning", filename="Juggernaut_RunDiffusionPhoto2_Lightning_4Steps.safetensors", local_dir="RunDiffusion_Juggernaut-XL-Lightning")
30
-
31
- parser = argparse.ArgumentParser()
32
- parser.add_argument("--opt", type=str, default='options/SUPIR_v0.yaml')
33
- parser.add_argument("--ip", type=str, default='127.0.0.1')
34
- parser.add_argument("--port", type=int, default='6688')
35
- parser.add_argument("--no_llava", action='store_true', default=True)#False
36
- parser.add_argument("--use_image_slider", action='store_true', default=False)#False
37
- parser.add_argument("--log_history", action='store_true', default=False)
38
- parser.add_argument("--loading_half_params", action='store_true', default=False)#False
39
- parser.add_argument("--use_tile_vae", action='store_true', default=True)#False
40
- parser.add_argument("--encoder_tile_size", type=int, default=512)
41
- parser.add_argument("--decoder_tile_size", type=int, default=64)
42
- parser.add_argument("--load_8bit_llava", action='store_true', default=False)
43
- args = parser.parse_args()
44
-
45
- if torch.cuda.device_count() > 0:
46
- SUPIR_device = 'cuda:0'
47
-
48
- # Load SUPIR
49
- model, default_setting = create_SUPIR_model(args.opt, SUPIR_sign='Q', load_default_setting=True)
50
- if args.loading_half_params:
51
- model = model.half()
52
- if args.use_tile_vae:
53
- model.init_tile_vae(encoder_tile_size=args.encoder_tile_size, decoder_tile_size=args.decoder_tile_size)
54
- model = model.to(SUPIR_device)
55
- model.first_stage_model.denoise_encoder_s1 = copy.deepcopy(model.first_stage_model.denoise_encoder)
56
- model.current_model = 'v0-Q'
57
- ckpt_Q, ckpt_F = load_QF_ckpt(args.opt)
58
-
59
- def check_upload(input_image):
60
- if input_image is None:
61
- raise gr.Error("Please provide an image to restore.")
62
- return gr.update(visible = True)
63
-
64
- def update_seed(is_randomize_seed, seed):
65
- if is_randomize_seed:
66
- return random.randint(0, max_64_bit_int)
67
- return seed
68
-
69
- def reset():
70
- return [
71
- None,
72
- 0,
73
- None,
74
- None,
75
- "Cinematic, High Contrast, highly detailed, taken using a Canon EOS R camera, hyper detailed photo - realistic maximum detail, 32k, Color Grading, ultra HD, extreme meticulous detailing, skin pore detailing, hyper sharpness, perfect without deformations.",
76
- "painting, oil painting, illustration, drawing, art, sketch, anime, cartoon, CG Style, 3D render, unreal engine, blurring, aliasing, pixel, unsharp, weird textures, ugly, dirty, messy, worst quality, low quality, frames, watermark, signature, jpeg artifacts, deformed, lowres, over-smooth",
77
- 1,
78
- 1024,
79
- 1,
80
- 2,
81
- 50,
82
- -1.0,
83
- 1.,
84
- default_setting.s_cfg_Quality if torch.cuda.device_count() > 0 else 1.0,
85
- True,
86
- random.randint(0, max_64_bit_int),
87
- 5,
88
- 1.003,
89
- "Wavelet",
90
- "fp32",
91
- "fp32",
92
- 1.0,
93
- True,
94
- False,
95
- default_setting.spt_linear_CFG_Quality if torch.cuda.device_count() > 0 else 1.0,
96
- 0.,
97
- "v0-Q",
98
- "input",
99
- 179
100
- ]
101
-
102
- def check_and_update(input_image):
103
- if input_image is None:
104
- raise gr.Error("Please provide an image to restore.")
105
- return gr.update(visible = True)
106
-
107
- @spaces.GPU(duration=420)
108
- def stage1_process(
109
- input_image,
110
- gamma_correction,
111
- diff_dtype,
112
- ae_dtype
113
- ):
114
- print('stage1_process ==>>')
115
- if torch.cuda.device_count() == 0:
116
- gr.Warning('Set this space to GPU config to make it work.')
117
- return None, None
118
- torch.cuda.set_device(SUPIR_device)
119
- LQ = HWC3(np.array(Image.open(input_image)))
120
- LQ = fix_resize(LQ, 512)
121
- # stage1
122
- LQ = np.array(LQ) / 255 * 2 - 1
123
- LQ = torch.tensor(LQ, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0).to(SUPIR_device)[:, :3, :, :]
124
-
125
- model.ae_dtype = convert_dtype(ae_dtype)
126
- model.model.dtype = convert_dtype(diff_dtype)
127
-
128
- LQ = model.batchify_denoise(LQ, is_stage1=True)
129
- LQ = (LQ[0].permute(1, 2, 0) * 127.5 + 127.5).cpu().numpy().round().clip(0, 255).astype(np.uint8)
130
- # gamma correction
131
- LQ = LQ / 255.0
132
- LQ = np.power(LQ, gamma_correction)
133
- LQ *= 255.0
134
- LQ = LQ.round().clip(0, 255).astype(np.uint8)
135
- print('<<== stage1_process')
136
- return LQ, gr.update(visible = True)
137
-
138
- def stage2_process(*args, **kwargs):
139
- try:
140
- return restore_in_Xmin(*args, **kwargs)
141
- except Exception as e:
142
- # NO_GPU_MESSAGE_INQUEUE
143
- print("gradio.exceptions.Error 'No GPU is currently available for you after 60s'")
144
- print('str(type(e)): ' + str(type(e))) # <class 'gradio.exceptions.Error'>
145
- print('str(e): ' + str(e)) # You have exceeded your GPU quota...
146
- try:
147
- print('e.message: ' + e.message) # No GPU is currently available for you after 60s
148
- except Exception as e2:
149
- print('Failure')
150
- if str(e).startswith("No GPU is currently available for you after 60s"):
151
- print('Exception identified!!!')
152
- #if str(type(e)) == "<class 'gradio.exceptions.Error'>":
153
- #print('Exception of name ' + type(e).__name__)
154
- raise e
155
-
156
- def restore_in_Xmin(
157
- noisy_image,
158
- rotation,
159
- denoise_image,
160
- prompt,
161
- a_prompt,
162
- n_prompt,
163
- num_samples,
164
- min_size,
165
- downscale,
166
- upscale,
167
- edm_steps,
168
- s_stage1,
169
- s_stage2,
170
- s_cfg,
171
- randomize_seed,
172
- seed,
173
- s_churn,
174
- s_noise,
175
- color_fix_type,
176
- diff_dtype,
177
- ae_dtype,
178
- gamma_correction,
179
- linear_CFG,
180
- linear_s_stage2,
181
- spt_linear_CFG,
182
- spt_linear_s_stage2,
183
- model_select,
184
- output_format,
185
- allocation
186
- ):
187
- print("noisy_image:\n" + str(noisy_image))
188
- print("denoise_image:\n" + str(denoise_image))
189
- print("rotation: " + str(rotation))
190
- print("prompt: " + str(prompt))
191
- print("a_prompt: " + str(a_prompt))
192
- print("n_prompt: " + str(n_prompt))
193
- print("num_samples: " + str(num_samples))
194
- print("min_size: " + str(min_size))
195
- print("downscale: " + str(downscale))
196
- print("upscale: " + str(upscale))
197
- print("edm_steps: " + str(edm_steps))
198
- print("s_stage1: " + str(s_stage1))
199
- print("s_stage2: " + str(s_stage2))
200
- print("s_cfg: " + str(s_cfg))
201
- print("randomize_seed: " + str(randomize_seed))
202
- print("seed: " + str(seed))
203
- print("s_churn: " + str(s_churn))
204
- print("s_noise: " + str(s_noise))
205
- print("color_fix_type: " + str(color_fix_type))
206
- print("diff_dtype: " + str(diff_dtype))
207
- print("ae_dtype: " + str(ae_dtype))
208
- print("gamma_correction: " + str(gamma_correction))
209
- print("linear_CFG: " + str(linear_CFG))
210
- print("linear_s_stage2: " + str(linear_s_stage2))
211
- print("spt_linear_CFG: " + str(spt_linear_CFG))
212
- print("spt_linear_s_stage2: " + str(spt_linear_s_stage2))
213
- print("model_select: " + str(model_select))
214
- print("GPU time allocation: " + str(allocation) + " min")
215
- print("output_format: " + str(output_format))
216
-
217
- input_format = re.sub(r"^.*\.([^\.]+)$", r"\1", noisy_image)
218
-
219
- if input_format not in ['png', 'webp', 'jpg', 'jpeg', 'gif', 'bmp', 'heic']:
220
- gr.Warning('Invalid image format. Please first convert into *.png, *.webp, *.jpg, *.jpeg, *.gif, *.bmp or *.heic.')
221
- return None, None, None, None
222
-
223
- if output_format == "input":
224
- if noisy_image is None:
225
- output_format = "png"
226
- else:
227
- output_format = input_format
228
- print("final output_format: " + str(output_format))
229
-
230
- if prompt is None:
231
- prompt = ""
232
-
233
- if a_prompt is None:
234
- a_prompt = ""
235
-
236
- if n_prompt is None:
237
- n_prompt = ""
238
-
239
- if prompt != "" and a_prompt != "":
240
- a_prompt = prompt + ", " + a_prompt
241
- else:
242
- a_prompt = prompt + a_prompt
243
- print("Final prompt: " + str(a_prompt))
244
-
245
- denoise_image = np.array(Image.open(noisy_image if denoise_image is None else denoise_image))
246
-
247
- if rotation == 90:
248
- denoise_image = np.array(list(zip(*denoise_image[::-1])))
249
- elif rotation == 180:
250
- denoise_image = np.array(list(zip(*denoise_image[::-1])))
251
- denoise_image = np.array(list(zip(*denoise_image[::-1])))
252
- elif rotation == -90:
253
- denoise_image = np.array(list(zip(*denoise_image))[::-1])
254
-
255
- if 1 < downscale:
256
- input_height, input_width, input_channel = denoise_image.shape
257
- denoise_image = np.array(Image.fromarray(denoise_image).resize((input_width // downscale, input_height // downscale), Image.LANCZOS))
258
-
259
- denoise_image = HWC3(denoise_image)
260
-
261
- if torch.cuda.device_count() == 0:
262
- gr.Warning('Set this space to GPU config to make it work.')
263
- return [noisy_image, denoise_image], gr.update(label="Downloadable results in *." + output_format + " format", format = output_format, value = [denoise_image]), None, gr.update(visible=True)
264
-
265
- if model_select != model.current_model:
266
- print('load ' + model_select)
267
- if model_select == 'v0-Q':
268
- model.load_state_dict(ckpt_Q, strict=False)
269
- elif model_select == 'v0-F':
270
- model.load_state_dict(ckpt_F, strict=False)
271
- model.current_model = model_select
272
-
273
- model.ae_dtype = convert_dtype(ae_dtype)
274
- model.model.dtype = convert_dtype(diff_dtype)
275
-
276
- return restore_on_gpu(
277
- noisy_image, denoise_image, prompt, a_prompt, n_prompt, num_samples, min_size, downscale, upscale, edm_steps, s_stage1, s_stage2, s_cfg, randomize_seed, seed, s_churn, s_noise, color_fix_type, diff_dtype, ae_dtype, gamma_correction, linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2, model_select, output_format, allocation
278
- )
279
-
280
- def get_duration(
281
- noisy_image,
282
- input_image,
283
- prompt,
284
- a_prompt,
285
- n_prompt,
286
- num_samples,
287
- min_size,
288
- downscale,
289
- upscale,
290
- edm_steps,
291
- s_stage1,
292
- s_stage2,
293
- s_cfg,
294
- randomize_seed,
295
- seed,
296
- s_churn,
297
- s_noise,
298
- color_fix_type,
299
- diff_dtype,
300
- ae_dtype,
301
- gamma_correction,
302
- linear_CFG,
303
- linear_s_stage2,
304
- spt_linear_CFG,
305
- spt_linear_s_stage2,
306
- model_select,
307
- output_format,
308
- allocation
309
- ):
310
- return allocation
311
-
312
- @spaces.GPU(duration=get_duration)
313
- def restore_on_gpu(
314
- noisy_image,
315
- input_image,
316
- prompt,
317
- a_prompt,
318
- n_prompt,
319
- num_samples,
320
- min_size,
321
- downscale,
322
- upscale,
323
- edm_steps,
324
- s_stage1,
325
- s_stage2,
326
- s_cfg,
327
- randomize_seed,
328
- seed,
329
- s_churn,
330
- s_noise,
331
- color_fix_type,
332
- diff_dtype,
333
- ae_dtype,
334
- gamma_correction,
335
- linear_CFG,
336
- linear_s_stage2,
337
- spt_linear_CFG,
338
- spt_linear_s_stage2,
339
- model_select,
340
- output_format,
341
- allocation
342
- ):
343
- start = time.time()
344
- print('restore ==>>')
345
-
346
- torch.cuda.set_device(SUPIR_device)
347
-
348
- with torch.no_grad():
349
- input_image = upscale_image(input_image, upscale, unit_resolution=32, min_size=min_size)
350
- LQ = np.array(input_image) / 255.0
351
- LQ = np.power(LQ, gamma_correction)
352
- LQ *= 255.0
353
- LQ = LQ.round().clip(0, 255).astype(np.uint8)
354
- LQ = LQ / 255 * 2 - 1
355
- LQ = torch.tensor(LQ, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0).to(SUPIR_device)[:, :3, :, :]
356
- captions = ['']
357
-
358
- samples = model.batchify_sample(LQ, captions, num_steps=edm_steps, restoration_scale=s_stage1, s_churn=s_churn,
359
- s_noise=s_noise, cfg_scale=s_cfg, control_scale=s_stage2, seed=seed,
360
- num_samples=num_samples, p_p=a_prompt, n_p=n_prompt, color_fix_type=color_fix_type,
361
- use_linear_CFG=linear_CFG, use_linear_control_scale=linear_s_stage2,
362
- cfg_scale_start=spt_linear_CFG, control_scale_start=spt_linear_s_stage2)
363
-
364
- x_samples = (einops.rearrange(samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().round().clip(
365
- 0, 255).astype(np.uint8)
366
- results = [x_samples[i] for i in range(num_samples)]
367
- torch.cuda.empty_cache()
368
-
369
- # All the results have the same size
370
- input_height, input_width, input_channel = np.array(input_image).shape
371
- result_height, result_width, result_channel = np.array(results[0]).shape
372
-
373
- print('<<== restore')
374
- end = time.time()
375
- secondes = int(end - start)
376
- minutes = math.floor(secondes / 60)
377
- secondes = secondes - (minutes * 60)
378
- hours = math.floor(minutes / 60)
379
- minutes = minutes - (hours * 60)
380
- information = ("Start the process again if you want a different result. " if randomize_seed else "") + \
381
- "If you don't get the image you wanted, add more details in the « Image description ». " + \
382
- "Wait " + str(allocation) + " min before a new run to avoid quota penalty or use another computer. " + \
383
- "The image" + (" has" if len(results) == 1 else "s have") + " been generated in " + \
384
- ((str(hours) + " h, ") if hours != 0 else "") + \
385
- ((str(minutes) + " min, ") if hours != 0 or minutes != 0 else "") + \
386
- str(secondes) + " sec. " + \
387
- "The new image resolution is " + str(result_width) + \
388
- " pixels large and " + str(result_height) + \
389
- " pixels high, so a resolution of " + f'{result_width * result_height:,}' + " pixels."
390
- print(information)
391
- try:
392
- print("Initial resolution: " + f'{input_width * input_height:,}')
393
- print("Final resolution: " + f'{result_width * result_height:,}')
394
- print("edm_steps: " + str(edm_steps))
395
- print("num_samples: " + str(num_samples))
396
- print("downscale: " + str(downscale))
397
- print("Estimated minutes: " + f'{(((result_width * result_height**(1/1.75)) * input_width * input_height * (edm_steps**(1/2)) * (num_samples**(1/2.5)))**(1/2.5)) / 25000:,}')
398
- except Exception as e:
399
- print('Exception of Estimation')
400
-
401
- # Only one image can be shown in the slider
402
- return [noisy_image] + [results[0]], gr.update(label="Downloadable results in *." + output_format + " format", format = output_format, value = results), gr.update(value = information, visible = True), gr.update(visible=True)
403
-
404
- def load_and_reset(param_setting):
405
- print('load_and_reset ==>>')
406
- if torch.cuda.device_count() == 0:
407
- gr.Warning('Set this space to GPU config to make it work.')
408
- return None, None, None, None, None, None, None, None, None, None, None, None, None, None
409
- edm_steps = default_setting.edm_steps
410
- s_stage2 = 1.0
411
- s_stage1 = -1.0
412
- s_churn = 5
413
- s_noise = 1.003
414
- a_prompt = 'Cinematic, High Contrast, highly detailed, taken using a Canon EOS R camera, hyper detailed photo - ' \
415
- 'realistic maximum detail, 32k, Color Grading, ultra HD, extreme meticulous detailing, skin pore ' \
416
- 'detailing, hyper sharpness, perfect without deformations.'
417
- n_prompt = 'painting, oil painting, illustration, drawing, art, sketch, anime, cartoon, CG Style, ' \
418
- '3D render, unreal engine, blurring, dirty, messy, worst quality, low quality, frames, watermark, ' \
419
- 'signature, jpeg artifacts, deformed, lowres, over-smooth'
420
- color_fix_type = 'Wavelet'
421
- spt_linear_s_stage2 = 0.0
422
- linear_s_stage2 = False
423
- linear_CFG = True
424
- if param_setting == "Quality":
425
- s_cfg = default_setting.s_cfg_Quality
426
- spt_linear_CFG = default_setting.spt_linear_CFG_Quality
427
- model_select = "v0-Q"
428
- elif param_setting == "Fidelity":
429
- s_cfg = default_setting.s_cfg_Fidelity
430
- spt_linear_CFG = default_setting.spt_linear_CFG_Fidelity
431
- model_select = "v0-F"
432
- else:
433
- raise NotImplementedError
434
- gr.Info('The parameters are reset.')
435
- print('<<== load_and_reset')
436
- return edm_steps, s_cfg, s_stage2, s_stage1, s_churn, s_noise, a_prompt, n_prompt, color_fix_type, linear_CFG, \
437
- linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2, model_select
438
-
439
- def log_information(result_gallery):
440
- print('log_information')
441
- if result_gallery is not None:
442
- for i, result in enumerate(result_gallery):
443
- print(result[0])
444
-
445
- def on_select_result(result_slider, result_gallery, evt: gr.SelectData):
446
- print('on_select_result')
447
- if result_gallery is not None:
448
- for i, result in enumerate(result_gallery):
449
- print(result[0])
450
- return [result_slider[0], result_gallery[evt.index][0]]
451
-
452
- title_html = """
453
- <h1><center>SUPIR</center></h1>
454
- <big><center>Upscale your images up to x10 freely, without account, without watermark and download it</center></big>
455
- <center><big><big>🤸<big><big><big><big><big><big>🤸</big></big></big></big></big></big></big></big></center>
456
-
457
- <p>This is an online demo of SUPIR, a practicing model scaling for photo-realistic image restoration.
458
- The content added by SUPIR is <b><u>imagination, not real-world information</u></b>.
459
- SUPIR is for beauty and illustration only.
460
- Most of the processes last few minutes.
461
- If you want to upscale AI-generated images, be noticed that <i>PixArt Sigma</i> space can directly generate 5984x5984 images.
462
- Due to Gradio issues, the generated image is slightly less satured than the original.
463
- Please leave a <a href="https://huggingface.co/spaces/Fabrice-TIERCELIN/SUPIR/discussions/new">message in discussion</a> if you encounter issues.
464
- You can also use <a href="https://huggingface.co/spaces/gokaygokay/AuraSR">AuraSR</a> to upscale x4.
465
-
466
- <p><center><a href="https://arxiv.org/abs/2401.13627">Paper</a> &emsp; <a href="http://supir.xpixel.group/">Project Page</a> &emsp; <a href="https://huggingface.co/blog/MonsterMMORPG/supir-sota-image-upscale-better-than-magnific-ai">Local Install Guide</a></center></p>
467
- <p><center><a style="display:inline-block" href='https://github.com/Fanghua-Yu/SUPIR'><img alt="GitHub Repo stars" src="https://img.shields.io/github/stars/Fanghua-Yu/SUPIR?style=social"></a></center></p>
468
- """
469
-
470
-
471
- claim_md = """
472
- ## **Piracy**
473
- The images are not stored but the logs are saved during a month.
474
- ## **How to get SUPIR**
475
- You can get SUPIR on HuggingFace by [duplicating this space](https://huggingface.co/spaces/Fabrice-TIERCELIN/SUPIR?duplicate=true) and set GPU.
476
- You can also install SUPIR on your computer following [this tutorial](https://huggingface.co/blog/MonsterMMORPG/supir-sota-image-upscale-better-than-magnific-ai).
477
- You can install _Pinokio_ on your computer and then install _SUPIR_ into it. It should be quite easy if you have an Nvidia GPU.
478
- ## **Terms of use**
479
- By using this service, users are required to agree to the following terms: The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research. Please submit a feedback to us if you get any inappropriate answer! We will collect those to keep improving our models. For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
480
- ## **License**
481
- The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/Fanghua-Yu/SUPIR) of SUPIR.
482
- """
483
-
484
- # Gradio interface
485
- with gr.Blocks() as interface:
486
- if torch.cuda.device_count() == 0:
487
- with gr.Row():
488
- gr.HTML("""
489
- <p style="background-color: red;"><big><big><big><b>⚠️To use SUPIR, <a href="https://huggingface.co/spaces/Fabrice-TIERCELIN/SUPIR?duplicate=true">duplicate this space</a> and set a GPU with 30 GB VRAM.</b>
490
-
491
- You can't use SUPIR directly here because this space runs on a CPU, which is not enough for SUPIR. Please provide <a href="https://huggingface.co/spaces/Fabrice-TIERCELIN/SUPIR/discussions/new">feedback</a> if you have issues.
492
- </big></big></big></p>
493
- """)
494
- gr.HTML(title_html)
495
-
496
- input_image = gr.Image(label="Input (*.png, *.webp, *.jpeg, *.jpg, *.gif, *.bmp, *.heic)", show_label=True, type="filepath", height=600, elem_id="image-input")
497
- rotation = gr.Radio([["No rotation", 0], ["⤵ Rotate +90°", 90], ["↩ Return 180°", 180], ["⤴ Rotate -90°", -90]], label="Orientation correction", info="Will apply the following rotation before restoring the image; the AI needs a good orientation to understand the content", value=0, interactive=True, visible=False)
498
- with gr.Group():
499
- prompt = gr.Textbox(label="Image description", info="Help the AI understand what the image represents; describe as much as possible, especially the details we can't see on the original image; you can write in any language", value="", placeholder="A 33 years old man, walking, in the street, Santiago, morning, Summer, photorealistic", lines=3)
500
- prompt_hint = gr.HTML("You can use a <a href='"'https://huggingface.co/spaces/badayvedat/LLaVA'"'>LlaVa space</a> to auto-generate the description of your image.")
501
- upscale = gr.Radio([["x1", 1], ["x2", 2], ["x3", 3], ["x4", 4], ["x5", 5], ["x6", 6], ["x7", 7], ["x8", 8], ["x9", 9], ["x10", 10]], label="Upscale factor", info="Resolution x1 to x10", value=2, interactive=True)
502
- output_format = gr.Radio([["As input", "input"], ["*.png", "png"], ["*.webp", "webp"], ["*.jpeg", "jpeg"], ["*.gif", "gif"], ["*.bmp", "bmp"]], label="Image format for result", info="File extention", value="input", interactive=True)
503
- allocation = gr.Slider(label="GPU allocation time (in seconds)", info="lower=May abort run, higher=Quota penalty for next runs", value=179, minimum=59, maximum=320, step=1)
504
-
505
- with gr.Accordion("Pre-denoising (optional)", open=False):
506
- gamma_correction = gr.Slider(label="Gamma Correction", info = "lower=lighter, higher=darker", minimum=0.1, maximum=2.0, value=1.0, step=0.1)
507
- denoise_button = gr.Button(value="Pre-denoise")
508
- denoise_image = gr.Image(label="Denoised image", show_label=True, type="filepath", sources=[], interactive = False, height=600, elem_id="image-s1")
509
- denoise_information = gr.HTML(value="If present, the denoised image will be used for the restoration instead of the input image.", visible=False)
510
-
511
- with gr.Accordion("Advanced options", open=False):
512
- a_prompt = gr.Textbox(label="Additional image description",
513
- info="Completes the main image description",
514
- value='Cinematic, High Contrast, highly detailed, taken using a Canon EOS R '
515
- 'camera, hyper detailed photo - realistic maximum detail, 32k, Color '
516
- 'Grading, ultra HD, extreme meticulous detailing, skin pore detailing, clothing fabric detailing, '
517
- 'hyper sharpness, perfect without deformations.',
518
- lines=3)
519
- n_prompt = gr.Textbox(label="Negative image description",
520
- info="Disambiguate by listing what the image does NOT represent",
521
- value='painting, oil painting, illustration, drawing, art, sketch, anime, '
522
- 'cartoon, CG Style, 3D render, unreal engine, blurring, aliasing, pixel, unsharp, weird textures, ugly, dirty, messy, '
523
- 'worst quality, low quality, frames, watermark, signature, jpeg artifacts, '
524
- 'deformed, lowres, over-smooth',
525
- lines=3)
526
- edm_steps = gr.Slider(label="Steps", info="lower=faster, higher=more details; too many steps create a checker effect", minimum=1, maximum=200, value=default_setting.edm_steps if torch.cuda.device_count() > 0 else 1, step=1)
527
- num_samples = gr.Slider(label="Num Samples", info="Number of generated results", minimum=1, maximum=4 if not args.use_image_slider else 1
528
- , value=1, step=1)
529
- min_size = gr.Slider(label="Minimum size", info="Minimum height, minimum width of the result", minimum=32, maximum=4096, value=1024, step=32)
530
- downscale = gr.Radio([["/1", 1], ["/2", 2], ["/3", 3], ["/4", 4], ["/5", 5], ["/6", 6], ["/7", 7], ["/8", 8], ["/9", 9], ["/10", 10]], label="Pre-downscale factor", info="Reducing blurred image reduce the process time", value=1, interactive=True)
531
- with gr.Row():
532
- with gr.Column():
533
- model_select = gr.Radio([["💃 Quality (v0-Q)", "v0-Q"], ["🎯 Fidelity (v0-F)", "v0-F"]], label="Model Selection", info="Pretrained model", value="v0-Q",
534
- interactive=True)
535
- with gr.Column():
536
- color_fix_type = gr.Radio([["None", "None"], ["AdaIn (improve as a photo)", "AdaIn"], ["Wavelet (for JPEG artifacts)", "Wavelet"]], label="Color-Fix Type", info="AdaIn=Improve following a style, Wavelet=For JPEG artifacts", value="AdaIn",
537
- interactive=True)
538
- s_cfg = gr.Slider(label="Text Guidance Scale", info="lower=follow the image, higher=follow the prompt", minimum=1.0, maximum=15.0,
539
- value=default_setting.s_cfg_Quality if torch.cuda.device_count() > 0 else 1.0, step=0.1)
540
- s_stage2 = gr.Slider(label="Restoring Guidance Strength", minimum=0., maximum=1., value=1., step=0.05)
541
- s_stage1 = gr.Slider(label="Pre-denoising Guidance Strength", minimum=-1.0, maximum=6.0, value=-1.0, step=1.0)
542
- s_churn = gr.Slider(label="S-Churn", minimum=0, maximum=40, value=5, step=1)
543
- s_noise = gr.Slider(label="S-Noise", minimum=1.0, maximum=1.1, value=1.003, step=0.001)
544
- with gr.Row():
545
- with gr.Column():
546
- linear_CFG = gr.Checkbox(label="Linear CFG", value=True)
547
- spt_linear_CFG = gr.Slider(label="CFG Start", minimum=1.0,
548
- maximum=9.0, value=default_setting.spt_linear_CFG_Quality if torch.cuda.device_count() > 0 else 1.0, step=0.5)
549
- with gr.Column():
550
- linear_s_stage2 = gr.Checkbox(label="Linear Restoring Guidance", value=False)
551
- spt_linear_s_stage2 = gr.Slider(label="Guidance Start", minimum=0.,
552
- maximum=1., value=0., step=0.05)
553
- with gr.Column():
554
- diff_dtype = gr.Radio([["fp32 (precision)", "fp32"], ["fp16 (medium)", "fp16"], ["bf16 (speed)", "bf16"]], label="Diffusion Data Type", value="fp32",
555
- interactive=True)
556
- with gr.Column():
557
- ae_dtype = gr.Radio([["fp32 (precision)", "fp32"], ["bf16 (speed)", "bf16"]], label="Auto-Encoder Data Type", value="fp32",
558
- interactive=True)
559
- randomize_seed = gr.Checkbox(label = "\U0001F3B2 Randomize seed", value = True, info = "If checked, result is always different")
560
- seed = gr.Slider(label="Seed", minimum=0, maximum=max_64_bit_int, step=1, randomize=True)
561
- with gr.Group():
562
- param_setting = gr.Radio(["Quality", "Fidelity"], interactive=True, label="Presetting", value = "Quality")
563
- restart_button = gr.Button(value="Apply presetting")
564
-
565
- with gr.Column():
566
- diffusion_button = gr.Button(value="🚀 Upscale/Restore", variant = "primary", elem_id = "process_button")
567
- reset_btn = gr.Button(value="🧹 Reinit page", variant="stop", elem_id="reset_button", visible = False)
568
-
569
- warning = gr.HTML(value = "<center><big>Your computer must <u>not</u> enter into standby mode.</big><br/>On Chrome, you can force to keep a tab alive in <code>chrome://discards/</code></center>", visible = False)
570
- restore_information = gr.HTML(value = "Restart the process to get another result.", visible = False)
571
- result_slider = ImageSlider(label = 'Comparator', show_label = False, interactive = False, elem_id = "slider1", show_download_button = False)
572
- result_gallery = gr.Gallery(label = 'Downloadable results', show_label = True, interactive = False, elem_id = "gallery1")
573
-
574
- gr.Examples(
575
- examples = [
576
- [
577
- "./Examples/Example1.png",
578
- 0,
579
- None,
580
- "Group of people, walking, happy, in the street, photorealistic, 8k, extremely detailled",
581
- "Cinematic, High Contrast, highly detailed, taken using a Canon EOS R camera, hyper detailed photo - realistic maximum detail, 32k, Color Grading, ultra HD, extreme meticulous detailing, skin pore detailing, hyper sharpness, perfect without deformations.",
582
- "painting, oil painting, illustration, drawing, art, sketch, anime, cartoon, CG Style, 3D render, unreal engine, blurring, aliasing, pixel, unsharp, weird textures, ugly, dirty, messy, worst quality, low quality, frames, watermark, signature, jpeg artifacts, deformed, lowres, over-smooth",
583
- 2,
584
- 1024,
585
- 1,
586
- 8,
587
- 100,
588
- -1,
589
- 1,
590
- 7.5,
591
- False,
592
- 42,
593
- 5,
594
- 1.003,
595
- "AdaIn",
596
- "fp16",
597
- "bf16",
598
- 1.0,
599
- True,
600
- 4,
601
- False,
602
- 0.,
603
- "v0-Q",
604
- "input",
605
- 179
606
- ],
607
- [
608
- "./Examples/Example2.jpeg",
609
- 0,
610
- None,
611
- "La cabeza de un gato atigrado, en una casa, fotorrealista, 8k, extremadamente detallada",
612
- "Cinematic, High Contrast, highly detailed, taken using a Canon EOS R camera, hyper detailed photo - realistic maximum detail, 32k, Color Grading, ultra HD, extreme meticulous detailing, skin pore detailing, hyper sharpness, perfect without deformations.",
613
- "painting, oil painting, illustration, drawing, art, sketch, anime, cartoon, CG Style, 3D render, unreal engine, blurring, aliasing, pixel, unsharp, weird textures, ugly, dirty, messy, worst quality, low quality, frames, watermark, signature, jpeg artifacts, deformed, lowres, over-smooth",
614
- 1,
615
- 1024,
616
- 1,
617
- 1,
618
- 200,
619
- -1,
620
- 1,
621
- 7.5,
622
- False,
623
- 42,
624
- 5,
625
- 1.003,
626
- "Wavelet",
627
- "fp16",
628
- "bf16",
629
- 1.0,
630
- True,
631
- 4,
632
- False,
633
- 0.,
634
- "v0-Q",
635
- "input",
636
- 179
637
- ],
638
- [
639
- "./Examples/Example3.webp",
640
- 0,
641
- None,
642
- "A red apple",
643
- "Cinematic, High Contrast, highly detailed, taken using a Canon EOS R camera, hyper detailed photo - realistic maximum detail, 32k, Color Grading, ultra HD, extreme meticulous detailing, skin pore detailing, hyper sharpness, perfect without deformations.",
644
- "painting, oil painting, illustration, drawing, art, sketch, anime, cartoon, CG Style, 3D render, unreal engine, blurring, aliasing, pixel, unsharp, weird textures, ugly, dirty, messy, worst quality, low quality, frames, watermark, signature, jpeg artifacts, deformed, lowres, over-smooth",
645
- 1,
646
- 1024,
647
- 1,
648
- 1,
649
- 200,
650
- -1,
651
- 1,
652
- 7.5,
653
- False,
654
- 42,
655
- 5,
656
- 1.003,
657
- "Wavelet",
658
- "fp16",
659
- "bf16",
660
- 1.0,
661
- True,
662
- 4,
663
- False,
664
- 0.,
665
- "v0-Q",
666
- "input",
667
- 179
668
- ],
669
- [
670
- "./Examples/Example3.webp",
671
- 0,
672
- None,
673
- "A red marble",
674
- "Cinematic, High Contrast, highly detailed, taken using a Canon EOS R camera, hyper detailed photo - realistic maximum detail, 32k, Color Grading, ultra HD, extreme meticulous detailing, skin pore detailing, hyper sharpness, perfect without deformations.",
675
- "painting, oil painting, illustration, drawing, art, sketch, anime, cartoon, CG Style, 3D render, unreal engine, blurring, aliasing, pixel, unsharp, weird textures, ugly, dirty, messy, worst quality, low quality, frames, watermark, signature, jpeg artifacts, deformed, lowres, over-smooth",
676
- 1,
677
- 1024,
678
- 1,
679
- 1,
680
- 200,
681
- -1,
682
- 1,
683
- 7.5,
684
- False,
685
- 42,
686
- 5,
687
- 1.003,
688
- "Wavelet",
689
- "fp16",
690
- "bf16",
691
- 1.0,
692
- True,
693
- 4,
694
- False,
695
- 0.,
696
- "v0-Q",
697
- "input",
698
- 179
699
- ],
700
- ],
701
- run_on_click = True,
702
- fn = stage2_process,
703
- inputs = [
704
- input_image,
705
- rotation,
706
- denoise_image,
707
- prompt,
708
- a_prompt,
709
- n_prompt,
710
- num_samples,
711
- min_size,
712
- downscale,
713
- upscale,
714
- edm_steps,
715
- s_stage1,
716
- s_stage2,
717
- s_cfg,
718
- randomize_seed,
719
- seed,
720
- s_churn,
721
- s_noise,
722
- color_fix_type,
723
- diff_dtype,
724
- ae_dtype,
725
- gamma_correction,
726
- linear_CFG,
727
- linear_s_stage2,
728
- spt_linear_CFG,
729
- spt_linear_s_stage2,
730
- model_select,
731
- output_format,
732
- allocation
733
- ],
734
- outputs = [
735
- result_slider,
736
- result_gallery,
737
- restore_information,
738
- reset_btn
739
- ],
740
- cache_examples = False,
741
- )
742
-
743
- with gr.Row():
744
- gr.Markdown(claim_md)
745
-
746
- input_image.upload(fn = check_upload, inputs = [
747
- input_image
748
- ], outputs = [
749
- rotation
750
- ], queue = False, show_progress = False)
751
-
752
- denoise_button.click(fn = check_and_update, inputs = [
753
- input_image
754
- ], outputs = [warning], queue = False, show_progress = False).success(fn = stage1_process, inputs = [
755
- input_image,
756
- gamma_correction,
757
- diff_dtype,
758
- ae_dtype
759
- ], outputs=[
760
- denoise_image,
761
- denoise_information
762
- ])
763
-
764
- diffusion_button.click(fn = update_seed, inputs = [
765
- randomize_seed,
766
- seed
767
- ], outputs = [
768
- seed
769
- ], queue = False, show_progress = False).then(fn = check_and_update, inputs = [
770
- input_image
771
- ], outputs = [warning], queue = False, show_progress = False).success(fn=stage2_process, inputs = [
772
- input_image,
773
- rotation,
774
- denoise_image,
775
- prompt,
776
- a_prompt,
777
- n_prompt,
778
- num_samples,
779
- min_size,
780
- downscale,
781
- upscale,
782
- edm_steps,
783
- s_stage1,
784
- s_stage2,
785
- s_cfg,
786
- randomize_seed,
787
- seed,
788
- s_churn,
789
- s_noise,
790
- color_fix_type,
791
- diff_dtype,
792
- ae_dtype,
793
- gamma_correction,
794
- linear_CFG,
795
- linear_s_stage2,
796
- spt_linear_CFG,
797
- spt_linear_s_stage2,
798
- model_select,
799
- output_format,
800
- allocation
801
- ], outputs = [
802
- result_slider,
803
- result_gallery,
804
- restore_information,
805
- reset_btn
806
- ]).success(fn = log_information, inputs = [
807
- result_gallery
808
- ], outputs = [], queue = False, show_progress = False)
809
-
810
- result_gallery.change(on_select_result, [result_slider, result_gallery], result_slider)
811
- result_gallery.select(on_select_result, [result_slider, result_gallery], result_slider)
812
-
813
- restart_button.click(fn = load_and_reset, inputs = [
814
- param_setting
815
- ], outputs = [
816
- edm_steps,
817
- s_cfg,
818
- s_stage2,
819
- s_stage1,
820
- s_churn,
821
- s_noise,
822
- a_prompt,
823
- n_prompt,
824
- color_fix_type,
825
- linear_CFG,
826
- linear_s_stage2,
827
- spt_linear_CFG,
828
- spt_linear_s_stage2,
829
- model_select
830
- ])
831
-
832
- reset_btn.click(fn = reset, inputs = [], outputs = [
833
- input_image,
834
- rotation,
835
- denoise_image,
836
- prompt,
837
- a_prompt,
838
- n_prompt,
839
- num_samples,
840
- min_size,
841
- downscale,
842
- upscale,
843
- edm_steps,
844
- s_stage1,
845
- s_stage2,
846
- s_cfg,
847
- randomize_seed,
848
- seed,
849
- s_churn,
850
- s_noise,
851
- color_fix_type,
852
- diff_dtype,
853
- ae_dtype,
854
- gamma_correction,
855
- linear_CFG,
856
- linear_s_stage2,
857
- spt_linear_CFG,
858
- spt_linear_s_stage2,
859
- model_select,
860
- output_format,
861
- allocation
862
- ], queue = False, show_progress = False)
863
-
864
- interface.queue(10).launch()
 
1
+ import gradio as gr
2
+ import torch
3
+ import spaces
4
+ import numpy as np
5
+ import random
6
+ import os
7
+ import yaml
8
+ from pathlib import Path
9
+ import imageio
10
+ import tempfile
11
+ from PIL import Image
12
+ from huggingface_hub import hf_hub_download
13
+ import shutil
14
+
15
+ from inference import (
16
+ create_ltx_video_pipeline,
17
+ create_latent_upsampler,
18
+ load_image_to_tensor_with_resize_and_crop,
19
+ seed_everething,
20
+ get_device,
21
+ calculate_padding,
22
+ load_media_file
23
+ )
24
+ from ltx_video.pipelines.pipeline_ltx_video import ConditioningItem, LTXMultiScalePipeline, LTXVideoPipeline
25
+ from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy
26
+
27
+ config_file_path = "configs/ltxv-13b-0.9.7-distilled.yaml"
28
+ with open(config_file_path, "r") as file:
29
+ PIPELINE_CONFIG_YAML = yaml.safe_load(file)
30
+
31
+ LTX_REPO = "Lightricks/LTX-Video"
32
+ MAX_IMAGE_SIZE = PIPELINE_CONFIG_YAML.get("max_resolution", 1280)
33
+ MAX_NUM_FRAMES = 257
34
+
35
+ FPS = 30.0
36
+
37
+ # --- Global variables for loaded models ---
38
+ pipeline_instance = None
39
+ latent_upsampler_instance = None
40
+ models_dir = "downloaded_models_gradio_cpu_init"
41
+ Path(models_dir).mkdir(parents=True, exist_ok=True)
42
+
43
+ print("Downloading models (if not present)...")
44
+ distilled_model_actual_path = hf_hub_download(
45
+ repo_id=LTX_REPO,
46
+ filename=PIPELINE_CONFIG_YAML["checkpoint_path"],
47
+ local_dir=models_dir,
48
+ local_dir_use_symlinks=False
49
+ )
50
+ PIPELINE_CONFIG_YAML["checkpoint_path"] = distilled_model_actual_path
51
+ print(f"Distilled model path: {distilled_model_actual_path}")
52
+
53
+ SPATIAL_UPSCALER_FILENAME = PIPELINE_CONFIG_YAML["spatial_upscaler_model_path"]
54
+ spatial_upscaler_actual_path = hf_hub_download(
55
+ repo_id=LTX_REPO,
56
+ filename=SPATIAL_UPSCALER_FILENAME,
57
+ local_dir=models_dir,
58
+ local_dir_use_symlinks=False
59
+ )
60
+ PIPELINE_CONFIG_YAML["spatial_upscaler_model_path"] = spatial_upscaler_actual_path
61
+ print(f"Spatial upscaler model path: {spatial_upscaler_actual_path}")
62
+
63
+ print("Creating LTX Video pipeline on CPU...")
64
+ pipeline_instance = create_ltx_video_pipeline(
65
+ ckpt_path=PIPELINE_CONFIG_YAML["checkpoint_path"],
66
+ precision=PIPELINE_CONFIG_YAML["precision"],
67
+ text_encoder_model_name_or_path=PIPELINE_CONFIG_YAML["text_encoder_model_name_or_path"],
68
+ sampler=PIPELINE_CONFIG_YAML["sampler"],
69
+ device="cpu",
70
+ enhance_prompt=False,
71
+ prompt_enhancer_image_caption_model_name_or_path=PIPELINE_CONFIG_YAML["prompt_enhancer_image_caption_model_name_or_path"],
72
+ prompt_enhancer_llm_model_name_or_path=PIPELINE_CONFIG_YAML["prompt_enhancer_llm_model_name_or_path"],
73
+ )
74
+ print("LTX Video pipeline created on CPU.")
75
+
76
+ if PIPELINE_CONFIG_YAML.get("spatial_upscaler_model_path"):
77
+ print("Creating latent upsampler on CPU...")
78
+ latent_upsampler_instance = create_latent_upsampler(
79
+ PIPELINE_CONFIG_YAML["spatial_upscaler_model_path"],
80
+ device="cpu"
81
+ )
82
+ print("Latent upsampler created on CPU.")
83
+
84
+ target_inference_device = "cuda"
85
+ print(f"Target inference device: {target_inference_device}")
86
+ pipeline_instance.to(target_inference_device)
87
+ if latent_upsampler_instance:
88
+ latent_upsampler_instance.to(target_inference_device)
89
+
90
+
91
+ # --- Helper function for dimension calculation ---
92
+ MIN_DIM_SLIDER = 256 # As defined in the sliders minimum attribute
93
+ TARGET_FIXED_SIDE = 768 # Desired fixed side length as per requirement
94
+
95
+ def calculate_new_dimensions(orig_w, orig_h):
96
+ """
97
+ Calculates new dimensions for height and width sliders based on original media dimensions.
98
+ Ensures one side is TARGET_FIXED_SIDE, the other is scaled proportionally,
99
+ both are multiples of 32, and within [MIN_DIM_SLIDER, MAX_IMAGE_SIZE].
100
+ """
101
+ if orig_w == 0 or orig_h == 0:
102
+ # Default to TARGET_FIXED_SIDE square if original dimensions are invalid
103
+ return int(TARGET_FIXED_SIDE), int(TARGET_FIXED_SIDE)
104
+
105
+ if orig_w >= orig_h: # Landscape or square
106
+ new_h = TARGET_FIXED_SIDE
107
+ aspect_ratio = orig_w / orig_h
108
+ new_w_ideal = new_h * aspect_ratio
109
+
110
+ # Round to nearest multiple of 32
111
+ new_w = round(new_w_ideal / 32) * 32
112
+
113
+ # Clamp to [MIN_DIM_SLIDER, MAX_IMAGE_SIZE]
114
+ new_w = max(MIN_DIM_SLIDER, min(new_w, MAX_IMAGE_SIZE))
115
+ # Ensure new_h is also clamped (TARGET_FIXED_SIDE should be within these bounds if configured correctly)
116
+ new_h = max(MIN_DIM_SLIDER, min(new_h, MAX_IMAGE_SIZE))
117
+ else: # Portrait
118
+ new_w = TARGET_FIXED_SIDE
119
+ aspect_ratio = orig_h / orig_w # Use H/W ratio for portrait scaling
120
+ new_h_ideal = new_w * aspect_ratio
121
+
122
+ # Round to nearest multiple of 32
123
+ new_h = round(new_h_ideal / 32) * 32
124
+
125
+ # Clamp to [MIN_DIM_SLIDER, MAX_IMAGE_SIZE]
126
+ new_h = max(MIN_DIM_SLIDER, min(new_h, MAX_IMAGE_SIZE))
127
+ # Ensure new_w is also clamped
128
+ new_w = max(MIN_DIM_SLIDER, min(new_w, MAX_IMAGE_SIZE))
129
+
130
+ return int(new_h), int(new_w)
131
+
132
+ def get_duration(prompt, negative_prompt, input_image_filepath, input_video_filepath,
133
+ height_ui, width_ui, mode,
134
+ duration_ui, # Removed ui_steps
135
+ ui_frames_to_use,
136
+ seed_ui, randomize_seed, ui_guidance_scale, improve_texture_flag,
137
+ progress):
138
+ if duration_ui > 7:
139
+ return 75
140
+ else:
141
+ return 60
142
+
143
+ @spaces.GPU(duration=get_duration)
144
+ def generate(prompt, negative_prompt, input_image_filepath, input_video_filepath,
145
+ height_ui, width_ui, mode,
146
+ duration_ui,
147
+ ui_frames_to_use,
148
+ seed_ui, randomize_seed, ui_guidance_scale, improve_texture_flag,
149
+ progress=gr.Progress(track_tqdm=True)):
150
+
151
+ if randomize_seed:
152
+ seed_ui = random.randint(0, 2**32 - 1)
153
+ seed_everething(int(seed_ui))
154
+
155
+ target_frames_ideal = duration_ui * FPS
156
+ target_frames_rounded = round(target_frames_ideal)
157
+ if target_frames_rounded < 1:
158
+ target_frames_rounded = 1
159
+
160
+ n_val = round((float(target_frames_rounded) - 1.0) / 8.0)
161
+ actual_num_frames = int(n_val * 8 + 1)
162
+
163
+ actual_num_frames = max(9, actual_num_frames)
164
+ actual_num_frames = min(MAX_NUM_FRAMES, actual_num_frames)
165
+
166
+ actual_height = int(height_ui)
167
+ actual_width = int(width_ui)
168
+
169
+ height_padded = ((actual_height - 1) // 32 + 1) * 32
170
+ width_padded = ((actual_width - 1) // 32 + 1) * 32
171
+ num_frames_padded = ((actual_num_frames - 2) // 8 + 1) * 8 + 1
172
+ if num_frames_padded != actual_num_frames:
173
+ print(f"Warning: actual_num_frames ({actual_num_frames}) and num_frames_padded ({num_frames_padded}) differ. Using num_frames_padded for pipeline.")
174
+
175
+ padding_values = calculate_padding(actual_height, actual_width, height_padded, width_padded)
176
+
177
+ call_kwargs = {
178
+ "prompt": prompt,
179
+ "negative_prompt": negative_prompt,
180
+ "height": height_padded,
181
+ "width": width_padded,
182
+ "num_frames": num_frames_padded,
183
+ "frame_rate": int(FPS),
184
+ "generator": torch.Generator(device=target_inference_device).manual_seed(int(seed_ui)),
185
+ "output_type": "pt",
186
+ "conditioning_items": None,
187
+ "media_items": None,
188
+ "decode_timestep": PIPELINE_CONFIG_YAML["decode_timestep"],
189
+ "decode_noise_scale": PIPELINE_CONFIG_YAML["decode_noise_scale"],
190
+ "stochastic_sampling": PIPELINE_CONFIG_YAML["stochastic_sampling"],
191
+ "image_cond_noise_scale": 0.15,
192
+ "is_video": True,
193
+ "vae_per_channel_normalize": True,
194
+ "mixed_precision": (PIPELINE_CONFIG_YAML["precision"] == "mixed_precision"),
195
+ "offload_to_cpu": False,
196
+ "enhance_prompt": False,
197
+ }
198
+
199
+ stg_mode_str = PIPELINE_CONFIG_YAML.get("stg_mode", "attention_values")
200
+ if stg_mode_str.lower() in ["stg_av", "attention_values"]:
201
+ call_kwargs["skip_layer_strategy"] = SkipLayerStrategy.AttentionValues
202
+ elif stg_mode_str.lower() in ["stg_as", "attention_skip"]:
203
+ call_kwargs["skip_layer_strategy"] = SkipLayerStrategy.AttentionSkip
204
+ elif stg_mode_str.lower() in ["stg_r", "residual"]:
205
+ call_kwargs["skip_layer_strategy"] = SkipLayerStrategy.Residual
206
+ elif stg_mode_str.lower() in ["stg_t", "transformer_block"]:
207
+ call_kwargs["skip_layer_strategy"] = SkipLayerStrategy.TransformerBlock
208
+ else:
209
+ raise ValueError(f"Invalid stg_mode: {stg_mode_str}")
210
+
211
+ if mode == "image-to-video" and input_image_filepath:
212
+ try:
213
+ media_tensor = load_image_to_tensor_with_resize_and_crop(
214
+ input_image_filepath, actual_height, actual_width
215
+ )
216
+ media_tensor = torch.nn.functional.pad(media_tensor, padding_values)
217
+ call_kwargs["conditioning_items"] = [ConditioningItem(media_tensor.to(target_inference_device), 0, 1.0)]
218
+ except Exception as e:
219
+ print(f"Error loading image {input_image_filepath}: {e}")
220
+ raise gr.Error(f"Could not load image: {e}")
221
+ elif mode == "video-to-video" and input_video_filepath:
222
+ try:
223
+ call_kwargs["media_items"] = load_media_file(
224
+ media_path=input_video_filepath,
225
+ height=actual_height,
226
+ width=actual_width,
227
+ max_frames=int(ui_frames_to_use),
228
+ padding=padding_values
229
+ ).to(target_inference_device)
230
+ except Exception as e:
231
+ print(f"Error loading video {input_video_filepath}: {e}")
232
+ raise gr.Error(f"Could not load video: {e}")
233
+
234
+ print(f"Moving models to {target_inference_device} for inference (if not already there)...")
235
+
236
+ active_latent_upsampler = None
237
+ if improve_texture_flag and latent_upsampler_instance:
238
+ active_latent_upsampler = latent_upsampler_instance
239
+
240
+ result_images_tensor = None
241
+ if improve_texture_flag:
242
+ if not active_latent_upsampler:
243
+ raise gr.Error("Spatial upscaler model not loaded or improve_texture not selected, cannot use multi-scale.")
244
+
245
+ multi_scale_pipeline_obj = LTXMultiScalePipeline(pipeline_instance, active_latent_upsampler)
246
+
247
+ first_pass_args = PIPELINE_CONFIG_YAML.get("first_pass", {}).copy()
248
+ first_pass_args["guidance_scale"] = float(ui_guidance_scale) # UI overrides YAML
249
+ # num_inference_steps will be derived from len(timesteps) in the pipeline
250
+ first_pass_args.pop("num_inference_steps", None)
251
+
252
+
253
+ second_pass_args = PIPELINE_CONFIG_YAML.get("second_pass", {}).copy()
254
+ second_pass_args["guidance_scale"] = float(ui_guidance_scale) # UI overrides YAML
255
+ # num_inference_steps will be derived from len(timesteps) in the pipeline
256
+ second_pass_args.pop("num_inference_steps", None)
257
+
258
+ multi_scale_call_kwargs = call_kwargs.copy()
259
+ multi_scale_call_kwargs.update({
260
+ "downscale_factor": PIPELINE_CONFIG_YAML["downscale_factor"],
261
+ "first_pass": first_pass_args,
262
+ "second_pass": second_pass_args,
263
+ })
264
+
265
+ print(f"Calling multi-scale pipeline (eff. HxW: {actual_height}x{actual_width}, Frames: {actual_num_frames} -> Padded: {num_frames_padded}) on {target_inference_device}")
266
+ result_images_tensor = multi_scale_pipeline_obj(**multi_scale_call_kwargs).images
267
+ else:
268
+ single_pass_call_kwargs = call_kwargs.copy()
269
+ first_pass_config_from_yaml = PIPELINE_CONFIG_YAML.get("first_pass", {})
270
+
271
+ single_pass_call_kwargs["timesteps"] = first_pass_config_from_yaml.get("timesteps")
272
+ single_pass_call_kwargs["guidance_scale"] = float(ui_guidance_scale) # UI overrides YAML
273
+ single_pass_call_kwargs["stg_scale"] = first_pass_config_from_yaml.get("stg_scale")
274
+ single_pass_call_kwargs["rescaling_scale"] = first_pass_config_from_yaml.get("rescaling_scale")
275
+ single_pass_call_kwargs["skip_block_list"] = first_pass_config_from_yaml.get("skip_block_list")
276
+
277
+ # Remove keys that might conflict or are not used in single pass / handled by above
278
+ single_pass_call_kwargs.pop("num_inference_steps", None)
279
+ single_pass_call_kwargs.pop("first_pass", None)
280
+ single_pass_call_kwargs.pop("second_pass", None)
281
+ single_pass_call_kwargs.pop("downscale_factor", None)
282
+
283
+ print(f"Calling base pipeline (padded HxW: {height_padded}x{width_padded}, Frames: {actual_num_frames} -> Padded: {num_frames_padded}) on {target_inference_device}")
284
+ result_images_tensor = pipeline_instance(**single_pass_call_kwargs).images
285
+
286
+ if result_images_tensor is None:
287
+ raise gr.Error("Generation failed.")
288
+
289
+ pad_left, pad_right, pad_top, pad_bottom = padding_values
290
+ slice_h_end = -pad_bottom if pad_bottom > 0 else None
291
+ slice_w_end = -pad_right if pad_right > 0 else None
292
+
293
+ result_images_tensor = result_images_tensor[
294
+ :, :, :actual_num_frames, pad_top:slice_h_end, pad_left:slice_w_end
295
+ ]
296
+
297
+ video_np = result_images_tensor[0].permute(1, 2, 3, 0).cpu().float().numpy()
298
+
299
+ video_np = np.clip(video_np, 0, 1)
300
+ video_np = (video_np * 255).astype(np.uint8)
301
+
302
+ temp_dir = tempfile.mkdtemp()
303
+ timestamp = random.randint(10000,99999)
304
+ output_video_path = os.path.join(temp_dir, f"output_{timestamp}.mp4")
305
+
306
+ try:
307
+ with imageio.get_writer(output_video_path, fps=call_kwargs["frame_rate"], macro_block_size=1) as video_writer:
308
+ for frame_idx in range(video_np.shape[0]):
309
+ progress(frame_idx / video_np.shape[0], desc="Saving video")
310
+ video_writer.append_data(video_np[frame_idx])
311
+ except Exception as e:
312
+ print(f"Error saving video with macro_block_size=1: {e}")
313
+ try:
314
+ with imageio.get_writer(output_video_path, fps=call_kwargs["frame_rate"], format='FFMPEG', codec='libx264', quality=8) as video_writer:
315
+ for frame_idx in range(video_np.shape[0]):
316
+ progress(frame_idx / video_np.shape[0], desc="Saving video (fallback ffmpeg)")
317
+ video_writer.append_data(video_np[frame_idx])
318
+ except Exception as e2:
319
+ print(f"Fallback video saving error: {e2}")
320
+ raise gr.Error(f"Failed to save video: {e2}")
321
+
322
+ return output_video_path, seed_ui
323
+
324
+ def update_task_image():
325
+ return "image-to-video"
326
+
327
+ def update_task_text():
328
+ return "text-to-video"
329
+
330
+ def update_task_video():
331
+ return "video-to-video"
332
+
333
+ # --- Gradio UI Definition ---
334
+ css="""
335
+ #col-container {
336
+ margin: 0 auto;
337
+ max-width: 900px;
338
+ }
339
+ """
340
+
341
+ with gr.Blocks(css=css) as demo:
342
+ gr.Markdown("# LTX Video 0.9.7 Distilled")
343
+ gr.Markdown("Fast high quality video generation. [Model](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-13b-0.9.7-distilled.safetensors) [GitHub](https://github.com/Lightricks/LTX-Video) [Diffusers](#)")
344
+
345
+ with gr.Row():
346
+ with gr.Column():
347
+ with gr.Tab("image-to-video") as image_tab:
348
+ video_i_hidden = gr.Textbox(label="video_i", visible=False, value=None)
349
+ image_i2v = gr.Image(label="Input Image", type="filepath", sources=["upload", "webcam", "clipboard"])
350
+ i2v_prompt = gr.Textbox(label="Prompt", value="The creature from the image starts to move", lines=3)
351
+ i2v_button = gr.Button("Generate Image-to-Video", variant="primary")
352
+ with gr.Tab("text-to-video") as text_tab:
353
+ image_n_hidden = gr.Textbox(label="image_n", visible=False, value=None)
354
+ video_n_hidden = gr.Textbox(label="video_n", visible=False, value=None)
355
+ t2v_prompt = gr.Textbox(label="Prompt", value="A majestic dragon flying over a medieval castle", lines=3)
356
+ t2v_button = gr.Button("Generate Text-to-Video", variant="primary")
357
+ with gr.Tab("video-to-video", visible=False) as video_tab:
358
+ image_v_hidden = gr.Textbox(label="image_v", visible=False, value=None)
359
+ video_v2v = gr.Video(label="Input Video", sources=["upload", "webcam"]) # type defaults to filepath
360
+ frames_to_use = gr.Slider(label="Frames to use from input video", minimum=9, maximum=MAX_NUM_FRAMES, value=9, step=8, info="Number of initial frames to use for conditioning/transformation. Must be N*8+1.")
361
+ v2v_prompt = gr.Textbox(label="Prompt", value="Change the style to cinematic anime", lines=3)
362
+ v2v_button = gr.Button("Generate Video-to-Video", variant="primary")
363
+
364
+ duration_input = gr.Slider(
365
+ label="Video Duration (seconds)",
366
+ minimum=0.3,
367
+ maximum=8.5,
368
+ value=2,
369
+ step=0.1,
370
+ info=f"Target video duration (0.3s to 8.5s)"
371
+ )
372
+ improve_texture = gr.Checkbox(label="Improve Texture (multi-scale)", value=True, info="Uses a two-pass generation for better quality, but is slower. Recommended for final output.")
373
+
374
+ with gr.Column():
375
+ output_video = gr.Video(label="Generated Video", interactive=False)
376
+ # gr.DeepLinkButton()
377
+
378
+ with gr.Accordion("Advanced settings", open=False):
379
+ mode = gr.Dropdown(["text-to-video", "image-to-video", "video-to-video"], label="task", value="image-to-video", visible=False)
380
+ negative_prompt_input = gr.Textbox(label="Negative Prompt", value="worst quality, inconsistent motion, blurry, jittery, distorted", lines=2)
381
+ with gr.Row():
382
+ seed_input = gr.Number(label="Seed", value=42, precision=0, minimum=0, maximum=2**32-1)
383
+ randomize_seed_input = gr.Checkbox(label="Randomize Seed", value=True)
384
+ with gr.Row():
385
+ guidance_scale_input = gr.Slider(label="Guidance Scale (CFG)", minimum=1.0, maximum=10.0, value=PIPELINE_CONFIG_YAML.get("first_pass", {}).get("guidance_scale", 1.0), step=0.1, info="Controls how much the prompt influences the output. Higher values = stronger influence.")
386
+ with gr.Row():
387
+ height_input = gr.Slider(label="Height", value=512, step=32, minimum=MIN_DIM_SLIDER, maximum=MAX_IMAGE_SIZE, info="Must be divisible by 32.")
388
+ width_input = gr.Slider(label="Width", value=704, step=32, minimum=MIN_DIM_SLIDER, maximum=MAX_IMAGE_SIZE, info="Must be divisible by 32.")
389
+
390
+
391
+ # --- Event handlers for updating dimensions on upload ---
392
+ def handle_image_upload_for_dims(image_filepath, current_h, current_w):
393
+ if not image_filepath: # Image cleared or no image initially
394
+ # Keep current slider values if image is cleared or no input
395
+ return gr.update(value=current_h), gr.update(value=current_w)
396
+ try:
397
+ img = Image.open(image_filepath)
398
+ orig_w, orig_h = img.size
399
+ new_h, new_w = calculate_new_dimensions(orig_w, orig_h)
400
+ return gr.update(value=new_h), gr.update(value=new_w)
401
+ except Exception as e:
402
+ print(f"Error processing image for dimension update: {e}")
403
+ # Keep current slider values on error
404
+ return gr.update(value=current_h), gr.update(value=current_w)
405
+
406
+ def handle_video_upload_for_dims(video_filepath, current_h, current_w):
407
+ if not video_filepath: # Video cleared or no video initially
408
+ return gr.update(value=current_h), gr.update(value=current_w)
409
+ try:
410
+ # Ensure video_filepath is a string for os.path.exists and imageio
411
+ video_filepath_str = str(video_filepath)
412
+ if not os.path.exists(video_filepath_str):
413
+ print(f"Video file path does not exist for dimension update: {video_filepath_str}")
414
+ return gr.update(value=current_h), gr.update(value=current_w)
415
+
416
+ orig_w, orig_h = -1, -1
417
+ with imageio.get_reader(video_filepath_str) as reader:
418
+ meta = reader.get_meta_data()
419
+ if 'size' in meta:
420
+ orig_w, orig_h = meta['size']
421
+ else:
422
+ # Fallback: read first frame if 'size' not in metadata
423
+ try:
424
+ first_frame = reader.get_data(0)
425
+ # Shape is (h, w, c) for frames
426
+ orig_h, orig_w = first_frame.shape[0], first_frame.shape[1]
427
+ except Exception as e_frame:
428
+ print(f"Could not get video size from metadata or first frame: {e_frame}")
429
+ return gr.update(value=current_h), gr.update(value=current_w)
430
+
431
+ if orig_w == -1 or orig_h == -1: # If dimensions couldn't be determined
432
+ print(f"Could not determine dimensions for video: {video_filepath_str}")
433
+ return gr.update(value=current_h), gr.update(value=current_w)
434
+
435
+ new_h, new_w = calculate_new_dimensions(orig_w, orig_h)
436
+ return gr.update(value=new_h), gr.update(value=new_w)
437
+ except Exception as e:
438
+ # Log type of video_filepath for debugging if it's not a path-like string
439
+ print(f"Error processing video for dimension update: {e} (Path: {video_filepath}, Type: {type(video_filepath)})")
440
+ return gr.update(value=current_h), gr.update(value=current_w)
441
+
442
+
443
+ image_i2v.upload(
444
+ fn=handle_image_upload_for_dims,
445
+ inputs=[image_i2v, height_input, width_input],
446
+ outputs=[height_input, width_input]
447
+ )
448
+ video_v2v.upload(
449
+ fn=handle_video_upload_for_dims,
450
+ inputs=[video_v2v, height_input, width_input],
451
+ outputs=[height_input, width_input]
452
+ )
453
+
454
+ image_tab.select(
455
+ fn=update_task_image,
456
+ outputs=[mode]
457
+ )
458
+ text_tab.select(
459
+ fn=update_task_text,
460
+ outputs=[mode]
461
+ )
462
+
463
+ t2v_inputs = [t2v_prompt, negative_prompt_input, image_n_hidden, video_n_hidden,
464
+ height_input, width_input, mode,
465
+ duration_input, frames_to_use,
466
+ seed_input, randomize_seed_input, guidance_scale_input, improve_texture]
467
+
468
+ i2v_inputs = [i2v_prompt, negative_prompt_input, image_i2v, video_i_hidden,
469
+ height_input, width_input, mode,
470
+ duration_input, frames_to_use,
471
+ seed_input, randomize_seed_input, guidance_scale_input, improve_texture]
472
+
473
+ v2v_inputs = [v2v_prompt, negative_prompt_input, image_v_hidden, video_v2v,
474
+ height_input, width_input, mode,
475
+ duration_input, frames_to_use,
476
+ seed_input, randomize_seed_input, guidance_scale_input, improve_texture]
477
+
478
+ t2v_button.click(fn=generate, inputs=t2v_inputs, outputs=[output_video, seed_input], api_name="text_to_video")
479
+ i2v_button.click(fn=generate, inputs=i2v_inputs, outputs=[output_video, seed_input], api_name="image_to_video")
480
+ v2v_button.click(fn=generate, inputs=v2v_inputs, outputs=[output_video, seed_input], api_name="video_to_video")
481
+
482
+ if __name__ == "__main__":
483
+ if os.path.exists(models_dir) and os.path.isdir(models_dir):
484
+ print(f"Model directory: {Path(models_dir).resolve()}")
485
+
486
+ demo.queue().launch(debug=True, share=False, mcp_server=True)
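+
+ # Example client call (a sketch, not executed by the app): the click handlers above expose
+ # named API endpoints, so the Space can be driven programmatically. The URL and the positional
+ # argument order (which mirrors t2v_inputs) are illustrative assumptions.
+ # from gradio_client import Client
+ # client = Client("http://127.0.0.1:7860/")
+ # video, used_seed = client.predict(
+ #     "A koala plays a guitar", "worst quality, blurry", None, None,  # prompt, negative prompt, image, video
+ #     512, 704, "text-to-video", 2, 9,                                # height, width, mode, duration, frames_to_use
+ #     42, True, 1.0, True,                                            # seed, randomize seed, guidance, improve texture
+ #     api_name="/text_to_video",
+ # )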
inference.py ADDED
@@ -0,0 +1,774 @@
1
+ import argparse
2
+ import os
3
+ import random
4
+ from datetime import datetime
5
+ from pathlib import Path
6
+ from diffusers.utils import logging
7
+ from typing import Optional, List, Union
8
+ import yaml
9
+
10
+ import imageio
11
+ import json
12
+ import numpy as np
13
+ import torch
14
+ import cv2
15
+ from safetensors import safe_open
16
+ from PIL import Image
17
+ from transformers import (
18
+ T5EncoderModel,
19
+ T5Tokenizer,
20
+ AutoModelForCausalLM,
21
+ AutoProcessor,
22
+ AutoTokenizer,
23
+ )
24
+ from huggingface_hub import hf_hub_download
25
+
26
+ from ltx_video.models.autoencoders.causal_video_autoencoder import (
27
+ CausalVideoAutoencoder,
28
+ )
29
+ from ltx_video.models.transformers.symmetric_patchifier import SymmetricPatchifier
30
+ from ltx_video.models.transformers.transformer3d import Transformer3DModel
31
+ from ltx_video.pipelines.pipeline_ltx_video import (
32
+ ConditioningItem,
33
+ LTXVideoPipeline,
34
+ LTXMultiScalePipeline,
35
+ )
36
+ from ltx_video.schedulers.rf import RectifiedFlowScheduler
37
+ from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy
38
+ from ltx_video.models.autoencoders.latent_upsampler import LatentUpsampler
39
+ import ltx_video.pipelines.crf_compressor as crf_compressor
40
+
41
+ MAX_HEIGHT = 720
42
+ MAX_WIDTH = 1280
43
+ MAX_NUM_FRAMES = 257
44
+
45
+ logger = logging.get_logger("LTX-Video")
46
+
47
+
48
+ def get_total_gpu_memory():
49
+ if torch.cuda.is_available():
50
+ total_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)
51
+ return total_memory
52
+ return 0
53
+
54
+
55
+ def get_device():
56
+ if torch.cuda.is_available():
57
+ return "cuda"
58
+ elif torch.backends.mps.is_available():
59
+ return "mps"
60
+ return "cpu"
61
+
62
+
63
+ def load_image_to_tensor_with_resize_and_crop(
64
+ image_input: Union[str, Image.Image],
65
+ target_height: int = 512,
66
+ target_width: int = 768,
67
+ just_crop: bool = False,
68
+ ) -> torch.Tensor:
69
+ """Load and process an image into a tensor.
70
+
71
+ Args:
72
+ image_input: Either a file path (str) or a PIL Image object
73
+ target_height: Desired height of output tensor
74
+ target_width: Desired width of output tensor
75
+ just_crop: If True, only crop the image to the target size without resizing
76
+ """
77
+ if isinstance(image_input, str):
78
+ image = Image.open(image_input).convert("RGB")
79
+ elif isinstance(image_input, Image.Image):
80
+ image = image_input
81
+ else:
82
+ raise ValueError("image_input must be either a file path or a PIL Image object")
83
+
84
+ input_width, input_height = image.size
85
+ aspect_ratio_target = target_width / target_height
86
+ aspect_ratio_frame = input_width / input_height
87
+ if aspect_ratio_frame > aspect_ratio_target:
88
+ new_width = int(input_height * aspect_ratio_target)
89
+ new_height = input_height
90
+ x_start = (input_width - new_width) // 2
91
+ y_start = 0
92
+ else:
93
+ new_width = input_width
94
+ new_height = int(input_width / aspect_ratio_target)
95
+ x_start = 0
96
+ y_start = (input_height - new_height) // 2
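+ # e.g. a 1920x1080 input with the default 768x512 target (aspect 1.5) yields a centered
+ # 1620x1080 crop below, which is then resized to 768x512 unless just_crop=True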
97
+
98
+ image = image.crop((x_start, y_start, x_start + new_width, y_start + new_height))
99
+ if not just_crop:
100
+ image = image.resize((target_width, target_height))
101
+
102
+ image = np.array(image)
103
+ image = cv2.GaussianBlur(image, (3, 3), 0)
104
+ frame_tensor = torch.from_numpy(image).float()
105
+ frame_tensor = crf_compressor.compress(frame_tensor / 255.0) * 255.0
106
+ frame_tensor = frame_tensor.permute(2, 0, 1)
107
+ frame_tensor = (frame_tensor / 127.5) - 1.0
108
+ # Create 5D tensor: (batch_size=1, channels=3, num_frames=1, height, width)
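+ # e.g. with the default 512x768 target this returns a tensor of shape (1, 3, 1, 512, 768)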
109
+ return frame_tensor.unsqueeze(0).unsqueeze(2)
110
+
111
+
112
+ def calculate_padding(
113
+ source_height: int, source_width: int, target_height: int, target_width: int
114
+ ) -> tuple[int, int, int, int]:
115
+
116
+ # Calculate total padding needed
117
+ pad_height = target_height - source_height
118
+ pad_width = target_width - source_width
119
+
120
+ # Calculate padding for each side
121
+ pad_top = pad_height // 2
122
+ pad_bottom = pad_height - pad_top # Handles odd padding
123
+ pad_left = pad_width // 2
124
+ pad_right = pad_width - pad_left # Handles odd padding
125
+
126
+ # Return padded tensor
127
+ # Padding format is (left, right, top, bottom)
128
+ padding = (pad_left, pad_right, pad_top, pad_bottom)
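+ # e.g. padding a 480x704 source to 512x704 gives (pad_left, pad_right, pad_top, pad_bottom) = (0, 0, 16, 16)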
129
+ return padding
130
+
131
+
132
+ def convert_prompt_to_filename(text: str, max_len: int = 20) -> str:
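+ # e.g. convert_prompt_to_filename("A red car drives by") -> "a-red-car-drives-by"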
133
+ # Remove non-letters and convert to lowercase
134
+ clean_text = "".join(
135
+ char.lower() for char in text if char.isalpha() or char.isspace()
136
+ )
137
+
138
+ # Split into words
139
+ words = clean_text.split()
140
+
141
+ # Build result string keeping track of length
142
+ result = []
143
+ current_length = 0
144
+
145
+ for word in words:
146
+ # Add the word's length (separator characters are not counted toward max_len)
147
+ new_length = current_length + len(word)
148
+
149
+ if new_length <= max_len:
150
+ result.append(word)
151
+ current_length += len(word)
152
+ else:
153
+ break
154
+
155
+ return "-".join(result)
156
+
157
+
158
+ # Generate output video name
159
+ def get_unique_filename(
160
+ base: str,
161
+ ext: str,
162
+ prompt: str,
163
+ seed: int,
164
+ resolution: tuple[int, int, int],
165
+ dir: Path,
166
+ endswith=None,
167
+ index_range=1000,
168
+ ) -> Path:
169
+ base_filename = f"{base}_{convert_prompt_to_filename(prompt, max_len=30)}_{seed}_{resolution[0]}x{resolution[1]}x{resolution[2]}"
170
+ for i in range(index_range):
171
+ filename = dir / f"{base_filename}_{i}{endswith if endswith else ''}{ext}"
172
+ if not os.path.exists(filename):
173
+ return filename
174
+ raise FileExistsError(
175
+ f"Could not find a unique filename after {index_range} attempts."
176
+ )
177
+
178
+
179
+ def seed_everething(seed: int):
180
+ random.seed(seed)
181
+ np.random.seed(seed)
182
+ torch.manual_seed(seed)
183
+ if torch.cuda.is_available():
184
+ torch.cuda.manual_seed(seed)
185
+ if torch.backends.mps.is_available():
186
+ torch.mps.manual_seed(seed)
187
+
188
+
189
+ def main():
190
+ parser = argparse.ArgumentParser(
191
+ description="Load models from separate directories and run the pipeline."
192
+ )
193
+
194
+ # Directories
195
+ parser.add_argument(
196
+ "--output_path",
197
+ type=str,
198
+ default=None,
199
+ help="Path to the folder to save output video, if None will save in outputs/ directory.",
200
+ )
201
+ parser.add_argument("--seed", type=int, default="171198")
202
+
203
+ # Pipeline parameters
204
+ parser.add_argument(
205
+ "--num_images_per_prompt",
206
+ type=int,
207
+ default=1,
208
+ help="Number of images per prompt",
209
+ )
210
+ parser.add_argument(
211
+ "--image_cond_noise_scale",
212
+ type=float,
213
+ default=0.15,
214
+ help="Amount of noise to add to the conditioned image",
215
+ )
216
+ parser.add_argument(
217
+ "--height",
218
+ type=int,
219
+ default=704,
220
+ help="Height of the output video frames. Optional if an input image provided.",
221
+ )
222
+ parser.add_argument(
223
+ "--width",
224
+ type=int,
225
+ default=1216,
226
+ help="Width of the output video frames. If None will infer from input image.",
227
+ )
228
+ parser.add_argument(
229
+ "--num_frames",
230
+ type=int,
231
+ default=121,
232
+ help="Number of frames to generate in the output video",
233
+ )
234
+ parser.add_argument(
235
+ "--frame_rate", type=int, default=30, help="Frame rate for the output video"
236
+ )
237
+ parser.add_argument(
238
+ "--device",
239
+ default=None,
240
+ help="Device to run inference on. If not specified, will automatically detect and use CUDA or MPS if available, else CPU.",
241
+ )
242
+ parser.add_argument(
243
+ "--pipeline_config",
244
+ type=str,
245
+ default="configs/ltxv-13b-0.9.7-dev.yaml",
246
+ help="The path to the config file for the pipeline, which contains the parameters for the pipeline",
247
+ )
248
+
249
+ # Prompts
250
+ parser.add_argument(
251
+ "--prompt",
252
+ type=str,
253
+ help="Text prompt to guide generation",
254
+ )
255
+ parser.add_argument(
256
+ "--negative_prompt",
257
+ type=str,
258
+ default="worst quality, inconsistent motion, blurry, jittery, distorted",
259
+ help="Negative prompt for undesired features",
260
+ )
261
+
262
+ parser.add_argument(
263
+ "--offload_to_cpu",
264
+ action="store_true",
265
+ help="Offloading unnecessary computations to CPU.",
266
+ )
267
+
268
+ # video-to-video arguments:
269
+ parser.add_argument(
270
+ "--input_media_path",
271
+ type=str,
272
+ default=None,
273
+ help="Path to the input video (or imaage) to be modified using the video-to-video pipeline",
274
+ )
275
+
276
+ # Conditioning arguments
277
+ parser.add_argument(
278
+ "--conditioning_media_paths",
279
+ type=str,
280
+ nargs="*",
281
+ help="List of paths to conditioning media (images or videos). Each path will be used as a conditioning item.",
282
+ )
283
+ parser.add_argument(
284
+ "--conditioning_strengths",
285
+ type=float,
286
+ nargs="*",
287
+ help="List of conditioning strengths (between 0 and 1) for each conditioning item. Must match the number of conditioning items.",
288
+ )
289
+ parser.add_argument(
290
+ "--conditioning_start_frames",
291
+ type=int,
292
+ nargs="*",
293
+ help="List of frame indices where each conditioning item should be applied. Must match the number of conditioning items.",
294
+ )
295
+
296
+ args = parser.parse_args()
297
+ logger.warning(f"Running generation with arguments: {args}")
298
+ infer(**vars(args))
299
+
300
+
301
+ def create_ltx_video_pipeline(
302
+ ckpt_path: str,
303
+ precision: str,
304
+ text_encoder_model_name_or_path: str,
305
+ sampler: Optional[str] = None,
306
+ device: Optional[str] = None,
307
+ enhance_prompt: bool = False,
308
+ prompt_enhancer_image_caption_model_name_or_path: Optional[str] = None,
309
+ prompt_enhancer_llm_model_name_or_path: Optional[str] = None,
310
+ ) -> LTXVideoPipeline:
311
+ ckpt_path = Path(ckpt_path)
312
+ assert os.path.exists(
313
+ ckpt_path
314
+ ), f"Ckpt path provided (--ckpt_path) {ckpt_path} does not exist"
315
+
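+ # The single-file checkpoint embeds its model config as JSON in the safetensors metadata;
+ # allowed_inference_steps (if present) is read from it and handed to the pipeline below.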
316
+ with safe_open(ckpt_path, framework="pt") as f:
317
+ metadata = f.metadata()
318
+ config_str = metadata.get("config")
319
+ configs = json.loads(config_str)
320
+ allowed_inference_steps = configs.get("allowed_inference_steps", None)
321
+
322
+ vae = CausalVideoAutoencoder.from_pretrained(ckpt_path)
323
+ transformer = Transformer3DModel.from_pretrained(ckpt_path)
324
+
325
+ # Use constructor if sampler is specified, otherwise use from_pretrained
326
+ if sampler == "from_checkpoint" or not sampler:
327
+ scheduler = RectifiedFlowScheduler.from_pretrained(ckpt_path)
328
+ else:
329
+ scheduler = RectifiedFlowScheduler(
330
+ sampler=("Uniform" if sampler.lower() == "uniform" else "LinearQuadratic")
331
+ )
332
+
333
+ text_encoder = T5EncoderModel.from_pretrained(
334
+ text_encoder_model_name_or_path, subfolder="text_encoder"
335
+ )
336
+ patchifier = SymmetricPatchifier(patch_size=1)
337
+ tokenizer = T5Tokenizer.from_pretrained(
338
+ text_encoder_model_name_or_path, subfolder="tokenizer"
339
+ )
340
+
341
+ transformer = transformer.to(device)
342
+ vae = vae.to(device)
343
+ text_encoder = text_encoder.to(device)
344
+
345
+ if enhance_prompt:
346
+ prompt_enhancer_image_caption_model = AutoModelForCausalLM.from_pretrained(
347
+ prompt_enhancer_image_caption_model_name_or_path, trust_remote_code=True
348
+ )
349
+ prompt_enhancer_image_caption_processor = AutoProcessor.from_pretrained(
350
+ prompt_enhancer_image_caption_model_name_or_path, trust_remote_code=True
351
+ )
352
+ prompt_enhancer_llm_model = AutoModelForCausalLM.from_pretrained(
353
+ prompt_enhancer_llm_model_name_or_path,
354
+ torch_dtype="bfloat16",
355
+ )
356
+ prompt_enhancer_llm_tokenizer = AutoTokenizer.from_pretrained(
357
+ prompt_enhancer_llm_model_name_or_path,
358
+ )
359
+ else:
360
+ prompt_enhancer_image_caption_model = None
361
+ prompt_enhancer_image_caption_processor = None
362
+ prompt_enhancer_llm_model = None
363
+ prompt_enhancer_llm_tokenizer = None
364
+
365
+ vae = vae.to(torch.bfloat16)
366
+ if precision == "bfloat16" and transformer.dtype != torch.bfloat16:
367
+ transformer = transformer.to(torch.bfloat16)
368
+ text_encoder = text_encoder.to(torch.bfloat16)
369
+
370
+ # Use submodels for the pipeline
371
+ submodel_dict = {
372
+ "transformer": transformer,
373
+ "patchifier": patchifier,
374
+ "text_encoder": text_encoder,
375
+ "tokenizer": tokenizer,
376
+ "scheduler": scheduler,
377
+ "vae": vae,
378
+ "prompt_enhancer_image_caption_model": prompt_enhancer_image_caption_model,
379
+ "prompt_enhancer_image_caption_processor": prompt_enhancer_image_caption_processor,
380
+ "prompt_enhancer_llm_model": prompt_enhancer_llm_model,
381
+ "prompt_enhancer_llm_tokenizer": prompt_enhancer_llm_tokenizer,
382
+ "allowed_inference_steps": allowed_inference_steps,
383
+ }
384
+
385
+ pipeline = LTXVideoPipeline(**submodel_dict)
386
+ pipeline = pipeline.to(device)
387
+ return pipeline
388
+
389
+
390
+ def create_latent_upsampler(latent_upsampler_model_path: str, device: str):
391
+ latent_upsampler = LatentUpsampler.from_pretrained(latent_upsampler_model_path)
392
+ latent_upsampler.to(device)
393
+ latent_upsampler.eval()
394
+ return latent_upsampler
395
+
396
+
397
+ def infer(
398
+ output_path: Optional[str],
399
+ seed: int,
400
+ pipeline_config: str,
401
+ image_cond_noise_scale: float,
402
+ height: Optional[int],
403
+ width: Optional[int],
404
+ num_frames: int,
405
+ frame_rate: int,
406
+ prompt: str,
407
+ negative_prompt: str,
408
+ offload_to_cpu: bool,
409
+ input_media_path: Optional[str] = None,
410
+ conditioning_media_paths: Optional[List[str]] = None,
411
+ conditioning_strengths: Optional[List[float]] = None,
412
+ conditioning_start_frames: Optional[List[int]] = None,
413
+ device: Optional[str] = None,
414
+ **kwargs,
415
+ ):
416
+ # check if pipeline_config is a file
417
+ if not os.path.isfile(pipeline_config):
418
+ raise ValueError(f"Pipeline config file {pipeline_config} does not exist")
419
+ with open(pipeline_config, "r") as f:
420
+ pipeline_config = yaml.safe_load(f)
421
+
422
+ models_dir = "MODEL_DIR"
423
+
424
+ ltxv_model_name_or_path = pipeline_config["checkpoint_path"]
425
+ if not os.path.isfile(ltxv_model_name_or_path):
426
+ ltxv_model_path = hf_hub_download(
427
+ repo_id="Lightricks/LTX-Video",
428
+ filename=ltxv_model_name_or_path,
429
+ local_dir=models_dir,
430
+ repo_type="model",
431
+ )
432
+ else:
433
+ ltxv_model_path = ltxv_model_name_or_path
434
+
435
+ spatial_upscaler_model_name_or_path = pipeline_config.get(
436
+ "spatial_upscaler_model_path"
437
+ )
438
+ if spatial_upscaler_model_name_or_path and not os.path.isfile(
439
+ spatial_upscaler_model_name_or_path
440
+ ):
441
+ spatial_upscaler_model_path = hf_hub_download(
442
+ repo_id="Lightricks/LTX-Video",
443
+ filename=spatial_upscaler_model_name_or_path,
444
+ local_dir=models_dir,
445
+ repo_type="model",
446
+ )
447
+ else:
448
+ spatial_upscaler_model_path = spatial_upscaler_model_name_or_path
449
+
450
+ if kwargs.get("input_image_path", None):
451
+ logger.warning(
452
+ "Please use conditioning_media_paths instead of input_image_path."
453
+ )
454
+ assert not conditioning_media_paths and not conditioning_start_frames
455
+ conditioning_media_paths = [kwargs["input_image_path"]]
456
+ conditioning_start_frames = [0]
457
+
458
+ # Validate conditioning arguments
459
+ if conditioning_media_paths:
460
+ # Use default strengths of 1.0
461
+ if not conditioning_strengths:
462
+ conditioning_strengths = [1.0] * len(conditioning_media_paths)
463
+ if not conditioning_start_frames:
464
+ raise ValueError(
465
+ "If `conditioning_media_paths` is provided, "
466
+ "`conditioning_start_frames` must also be provided"
467
+ )
468
+ if len(conditioning_media_paths) != len(conditioning_strengths) or len(
469
+ conditioning_media_paths
470
+ ) != len(conditioning_start_frames):
471
+ raise ValueError(
472
+ "`conditioning_media_paths`, `conditioning_strengths`, "
473
+ "and `conditioning_start_frames` must have the same length"
474
+ )
475
+ if any(s < 0 or s > 1 for s in conditioning_strengths):
476
+ raise ValueError("All conditioning strengths must be between 0 and 1")
477
+ if any(f < 0 or f >= num_frames for f in conditioning_start_frames):
478
+ raise ValueError(
479
+ f"All conditioning start frames must be between 0 and {num_frames-1}"
480
+ )
481
+
482
+ seed_everething(seed)
483
+ if offload_to_cpu and not torch.cuda.is_available():
484
+ logger.warning(
485
+ "offload_to_cpu is set to True, but offloading will not occur since the model is already running on CPU."
486
+ )
487
+ offload_to_cpu = False
488
+ else:
489
+ offload_to_cpu = offload_to_cpu and get_total_gpu_memory() < 30
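+ # Even when requested, only offload if the GPU reports less than 30 GiB of total memory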
490
+
491
+ output_dir = (
492
+ Path(output_path)
493
+ if output_path
494
+ else Path(f"outputs/{datetime.today().strftime('%Y-%m-%d')}")
495
+ )
496
+ output_dir.mkdir(parents=True, exist_ok=True)
497
+
498
+ # Adjust dimensions to be divisible by 32 and num_frames to be (N * 8 + 1)
499
+ height_padded = ((height - 1) // 32 + 1) * 32
500
+ width_padded = ((width - 1) // 32 + 1) * 32
501
+ num_frames_padded = ((num_frames - 2) // 8 + 1) * 8 + 1
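+ # e.g. height=500 -> 512, width=704 -> 704, num_frames=100 -> 105 (already-conforming values are unchanged)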
502
+
503
+ padding = calculate_padding(height, width, height_padded, width_padded)
504
+
505
+ logger.warning(
506
+ f"Padded dimensions: {height_padded}x{width_padded}x{num_frames_padded}"
507
+ )
508
+
509
+ prompt_enhancement_words_threshold = pipeline_config[
510
+ "prompt_enhancement_words_threshold"
511
+ ]
512
+
513
+ prompt_word_count = len(prompt.split())
514
+ enhance_prompt = (
515
+ prompt_enhancement_words_threshold > 0
516
+ and prompt_word_count < prompt_enhancement_words_threshold
517
+ )
518
+
519
+ if prompt_enhancement_words_threshold > 0 and not enhance_prompt:
520
+ logger.info(
521
+ f"Prompt has {prompt_word_count} words, which exceeds the threshold of {prompt_enhancement_words_threshold}. Prompt enhancement disabled."
522
+ )
523
+
524
+ precision = pipeline_config["precision"]
525
+ text_encoder_model_name_or_path = pipeline_config["text_encoder_model_name_or_path"]
526
+ sampler = pipeline_config["sampler"]
527
+ prompt_enhancer_image_caption_model_name_or_path = pipeline_config[
528
+ "prompt_enhancer_image_caption_model_name_or_path"
529
+ ]
530
+ prompt_enhancer_llm_model_name_or_path = pipeline_config[
531
+ "prompt_enhancer_llm_model_name_or_path"
532
+ ]
533
+
534
+ pipeline = create_ltx_video_pipeline(
535
+ ckpt_path=ltxv_model_path,
536
+ precision=precision,
537
+ text_encoder_model_name_or_path=text_encoder_model_name_or_path,
538
+ sampler=sampler,
539
+ device=kwargs.get("device", get_device()),
540
+ enhance_prompt=enhance_prompt,
541
+ prompt_enhancer_image_caption_model_name_or_path=prompt_enhancer_image_caption_model_name_or_path,
542
+ prompt_enhancer_llm_model_name_or_path=prompt_enhancer_llm_model_name_or_path,
543
+ )
544
+
545
+ if pipeline_config.get("pipeline_type", None) == "multi-scale":
546
+ if not spatial_upscaler_model_path:
547
+ raise ValueError(
548
+ "spatial upscaler model path is missing from pipeline config file and is required for multi-scale rendering"
549
+ )
550
+ latent_upsampler = create_latent_upsampler(
551
+ spatial_upscaler_model_path, pipeline.device
552
+ )
553
+ pipeline = LTXMultiScalePipeline(pipeline, latent_upsampler=latent_upsampler)
554
+
555
+ media_item = None
556
+ if input_media_path:
557
+ media_item = load_media_file(
558
+ media_path=input_media_path,
559
+ height=height,
560
+ width=width,
561
+ max_frames=num_frames_padded,
562
+ padding=padding,
563
+ )
564
+
565
+ conditioning_items = (
566
+ prepare_conditioning(
567
+ conditioning_media_paths=conditioning_media_paths,
568
+ conditioning_strengths=conditioning_strengths,
569
+ conditioning_start_frames=conditioning_start_frames,
570
+ height=height,
571
+ width=width,
572
+ num_frames=num_frames,
573
+ padding=padding,
574
+ pipeline=pipeline,
575
+ )
576
+ if conditioning_media_paths
577
+ else None
578
+ )
579
+
580
+ stg_mode = pipeline_config.get("stg_mode", "attention_values")
581
+ del pipeline_config["stg_mode"]
582
+ if stg_mode.lower() == "stg_av" or stg_mode.lower() == "attention_values":
583
+ skip_layer_strategy = SkipLayerStrategy.AttentionValues
584
+ elif stg_mode.lower() == "stg_as" or stg_mode.lower() == "attention_skip":
585
+ skip_layer_strategy = SkipLayerStrategy.AttentionSkip
586
+ elif stg_mode.lower() == "stg_r" or stg_mode.lower() == "residual":
587
+ skip_layer_strategy = SkipLayerStrategy.Residual
588
+ elif stg_mode.lower() == "stg_t" or stg_mode.lower() == "transformer_block":
589
+ skip_layer_strategy = SkipLayerStrategy.TransformerBlock
590
+ else:
591
+ raise ValueError(f"Invalid spatiotemporal guidance mode: {stg_mode}")
592
+
593
+ # Prepare input for the pipeline
594
+ sample = {
595
+ "prompt": prompt,
596
+ "prompt_attention_mask": None,
597
+ "negative_prompt": negative_prompt,
598
+ "negative_prompt_attention_mask": None,
599
+ }
600
+
601
+ device = device or get_device()
602
+ generator = torch.Generator(device=device).manual_seed(seed)
603
+
604
+ images = pipeline(
605
+ **pipeline_config,
606
+ skip_layer_strategy=skip_layer_strategy,
607
+ generator=generator,
608
+ output_type="pt",
609
+ callback_on_step_end=None,
610
+ height=height_padded,
611
+ width=width_padded,
612
+ num_frames=num_frames_padded,
613
+ frame_rate=frame_rate,
614
+ **sample,
615
+ media_items=media_item,
616
+ conditioning_items=conditioning_items,
617
+ is_video=True,
618
+ vae_per_channel_normalize=True,
619
+ image_cond_noise_scale=image_cond_noise_scale,
620
+ mixed_precision=(precision == "mixed_precision"),
621
+ offload_to_cpu=offload_to_cpu,
622
+ device=device,
623
+ enhance_prompt=enhance_prompt,
624
+ ).images
625
+
626
+ # Crop the padded images to the desired resolution and number of frames
627
+ (pad_left, pad_right, pad_top, pad_bottom) = padding
628
+ pad_bottom = -pad_bottom
629
+ pad_right = -pad_right
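+ # Negative values crop away the symmetric padding; a zero pad would make the slice empty,
+ # so it falls back to the full height/width below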
630
+ if pad_bottom == 0:
631
+ pad_bottom = images.shape[3]
632
+ if pad_right == 0:
633
+ pad_right = images.shape[4]
634
+ images = images[:, :, :num_frames, pad_top:pad_bottom, pad_left:pad_right]
635
+
636
+ for i in range(images.shape[0]):
637
+ # Gathering from B, C, F, H, W to C, F, H, W and then permuting to F, H, W, C
638
+ video_np = images[i].permute(1, 2, 3, 0).cpu().float().numpy()
639
+ # Unnormalizing images to [0, 255] range
640
+ video_np = (video_np * 255).astype(np.uint8)
641
+ fps = frame_rate
642
+ height, width = video_np.shape[1:3]
643
+ # In case a single image is generated
644
+ if video_np.shape[0] == 1:
645
+ output_filename = get_unique_filename(
646
+ f"image_output_{i}",
647
+ ".png",
648
+ prompt=prompt,
649
+ seed=seed,
650
+ resolution=(height, width, num_frames),
651
+ dir=output_dir,
652
+ )
653
+ imageio.imwrite(output_filename, video_np[0])
654
+ else:
655
+ output_filename = get_unique_filename(
656
+ f"video_output_{i}",
657
+ ".mp4",
658
+ prompt=prompt,
659
+ seed=seed,
660
+ resolution=(height, width, num_frames),
661
+ dir=output_dir,
662
+ )
663
+
664
+ # Write video
665
+ with imageio.get_writer(output_filename, fps=fps) as video:
666
+ for frame in video_np:
667
+ video.append_data(frame)
668
+
669
+ logger.warning(f"Output saved to {output_filename}")
670
+
671
+
672
+ def prepare_conditioning(
673
+ conditioning_media_paths: List[str],
674
+ conditioning_strengths: List[float],
675
+ conditioning_start_frames: List[int],
676
+ height: int,
677
+ width: int,
678
+ num_frames: int,
679
+ padding: tuple[int, int, int, int],
680
+ pipeline: LTXVideoPipeline,
681
+ ) -> Optional[List[ConditioningItem]]:
682
+ """Prepare conditioning items based on input media paths and their parameters.
683
+
684
+ Args:
685
+ conditioning_media_paths: List of paths to conditioning media (images or videos)
686
+ conditioning_strengths: List of conditioning strengths for each media item
687
+ conditioning_start_frames: List of frame indices where each item should be applied
688
+ height: Height of the output frames
689
+ width: Width of the output frames
690
+ num_frames: Number of frames in the output video
691
+ padding: Padding to apply to the frames
692
+ pipeline: LTXVideoPipeline object used for condition video trimming
693
+
694
+ Returns:
695
+ A list of ConditioningItem objects.
696
+ """
697
+ conditioning_items = []
698
+ for path, strength, start_frame in zip(
699
+ conditioning_media_paths, conditioning_strengths, conditioning_start_frames
700
+ ):
701
+ num_input_frames = orig_num_input_frames = get_media_num_frames(path)
702
+ if hasattr(pipeline, "trim_conditioning_sequence") and callable(
703
+ getattr(pipeline, "trim_conditioning_sequence")
704
+ ):
705
+ num_input_frames = pipeline.trim_conditioning_sequence(
706
+ start_frame, orig_num_input_frames, num_frames
707
+ )
708
+ if num_input_frames < orig_num_input_frames:
709
+ logger.warning(
710
+ f"Trimming conditioning video {path} from {orig_num_input_frames} to {num_input_frames} frames."
711
+ )
712
+
713
+ media_tensor = load_media_file(
714
+ media_path=path,
715
+ height=height,
716
+ width=width,
717
+ max_frames=num_input_frames,
718
+ padding=padding,
719
+ just_crop=True,
720
+ )
721
+ conditioning_items.append(ConditioningItem(media_tensor, start_frame, strength))
722
+ return conditioning_items
723
+
724
+
725
+ def get_media_num_frames(media_path: str) -> int:
726
+ is_video = any(
727
+ media_path.lower().endswith(ext) for ext in [".mp4", ".avi", ".mov", ".mkv"]
728
+ )
729
+ num_frames = 1
730
+ if is_video:
731
+ reader = imageio.get_reader(media_path)
732
+ num_frames = reader.count_frames()
733
+ reader.close()
734
+ return num_frames
735
+
736
+
737
+ def load_media_file(
738
+ media_path: str,
739
+ height: int,
740
+ width: int,
741
+ max_frames: int,
742
+ padding: tuple[int, int, int, int],
743
+ just_crop: bool = False,
744
+ ) -> torch.Tensor:
745
+ is_video = any(
746
+ media_path.lower().endswith(ext) for ext in [".mp4", ".avi", ".mov", ".mkv"]
747
+ )
748
+ if is_video:
749
+ reader = imageio.get_reader(media_path)
750
+ num_input_frames = min(reader.count_frames(), max_frames)
751
+
752
+ # Read and preprocess the relevant frames from the video file.
753
+ frames = []
754
+ for i in range(num_input_frames):
755
+ frame = Image.fromarray(reader.get_data(i))
756
+ frame_tensor = load_image_to_tensor_with_resize_and_crop(
757
+ frame, height, width, just_crop=just_crop
758
+ )
759
+ frame_tensor = torch.nn.functional.pad(frame_tensor, padding)
760
+ frames.append(frame_tensor)
761
+ reader.close()
762
+
763
+ # Stack frames along the temporal dimension
764
+ media_tensor = torch.cat(frames, dim=2)
765
+ else: # Input image
766
+ media_tensor = load_image_to_tensor_with_resize_and_crop(
767
+ media_path, height, width, just_crop=just_crop
768
+ )
769
+ media_tensor = torch.nn.functional.pad(media_tensor, padding)
770
+ return media_tensor
771
+
772
+
773
+ if __name__ == "__main__":
774
+ main()
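+
+ # Example invocation (a sketch; the config path is simply the parser default above):
+ # python inference.py --prompt "A koala plays a guitar" \
+ #     --pipeline_config configs/ltxv-13b-0.9.7-dev.yaml \
+ #     --height 512 --width 704 --num_frames 121 --seed 42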
requirements.txt CHANGED
@@ -1,48 +1,15 @@
1
- pydantic==2.10.6
2
- fastapi==0.115.8
3
- gradio_imageslider==0.0.20
4
- gradio_client==1.7.0
5
- numpy==1.26.4
6
- requests==2.32.3
7
- sentencepiece==0.2.0
8
- tokenizers==0.19.1
9
- torchvision==0.18.1
10
- uvicorn==0.30.1
11
- wandb==0.17.4
12
- httpx==0.27.0
13
- transformers==4.42.4
14
- accelerate==0.32.1
15
- scikit-learn==1.5.1
16
- einops==0.8.0
17
- einops-exts==0.0.4
18
- timm==1.0.7
19
- openai-clip==1.0.1
20
- fsspec==2024.6.1
21
- kornia==0.7.3
22
- matplotlib==3.9.1
23
- ninja==1.11.1.1
24
- omegaconf==2.3.0
25
- opencv-python==4.10.0.84
26
- pandas==2.2.2
27
- pillow==10.4.0
28
- pytorch-lightning==2.3.3
29
- PyYAML==6.0.1
30
- scipy==1.14.0
31
- tqdm==4.66.4
32
- triton==2.3.1
33
- urllib3==2.2.2
34
- webdataset==0.2.86
35
- xformers==0.0.27
36
- facexlib==0.3.0
37
- k-diffusion==0.1.1.post1
38
- diffusers==0.30.0
39
- pillow-heif==0.18.0
40
-
41
- open-clip-torch==2.24.0
42
-
43
- torchaudio
44
- easydict==1.13
45
- fairscale==0.4.13
46
- torchsde==0.2.6
47
- huggingface_hub==0.23.3
48
- gradio
 
1
+ accelerate
2
+ transformers
3
+ sentencepiece
4
+ pillow
5
+ numpy
6
+ torchvision
7
+ huggingface_hub
8
+ spaces
9
+ opencv-python
10
+ imageio
11
+ imageio-ffmpeg
12
+ einops
13
+ timm
14
+ av
15
+ git+https://github.com/huggingface/diffusers.git@main