englissi committed · verified
Commit 1bc9c1f · 1 Parent(s): b2d14ee

Update app.py

Files changed (1)
  1. app.py +150 -60
app.py CHANGED
@@ -1,39 +1,60 @@
- import os, torch, tempfile
  import gradio as gr
  from diffusers import LTXPipeline, AutoModel
  from diffusers.hooks import apply_group_offloading
  from diffusers.utils import export_to_video

- # --------- Model loading ---------
- def load_pipeline(device="cuda"):
-     dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float16
-
      transformer = AutoModel.from_pretrained(
          "Lightricks/LTX-Video",
          subfolder="transformer",
          torch_dtype=dtype,
-         trust_remote_code=True,  # important: prevents the Placeholder issue
-         variant="bf16" if dtype==torch.bfloat16 else None
      )

-     # fp8 layerwise casting (ignored if the environment does not support it)
      try:
          transformer.enable_layerwise_casting(
-             storage_dtype=torch.float8_e4m3fn, compute_dtype=dtype
          )
-         fp8 = True
      except Exception:
-         fp8 = False

      pipe = LTXPipeline.from_pretrained(
          "Lightricks/LTX-Video",
          transformer=transformer,
          torch_dtype=dtype,
          trust_remote_code=True,
-         variant="bf16" if dtype==torch.bfloat16 else None
      ).to(device)

-     # group offloading (ignored if not supported)
      try:
          onload_device = torch.device(device)
          offload_device = torch.device("cpu")
@@ -43,77 +64,146 @@ def load_pipeline(device="cuda"):
              offload_type="leaf_level",
              use_stream=True
          )
-         apply_group_offloading(pipe.text_encoder, onload_device=onload_device,
-                                offload_type="block_level", num_blocks_per_group=2)
-         apply_group_offloading(pipe.vae, onload_device=onload_device,
-                                offload_type="leaf_level")
-         offload = True
      except Exception:
-         offload = False

-     return pipe, fp8, offload

- PIPE, FP8_OK, OFFLOAD_OK = load_pipeline("cuda" if torch.cuda.is_available() else "cpu")

- # --------- Video generation ---------
- def generate(prompt, negative_prompt,
-              width, height, num_frames, fps,
-              decode_timestep, decode_noise_scale,
-              steps, seed):

-     g = None
-     if seed is not None and seed >= 0:
-         g = torch.Generator(device="cuda" if torch.cuda.is_available() else "cpu").manual_seed(int(seed))

      with torch.inference_mode():
-         result = PIPE(
-             prompt=prompt,
-             negative_prompt=negative_prompt or None,
-             width=width,
-             height=height,
-             num_frames=num_frames,
-             fps=fps,
-             decode_timestep=decode_timestep,
-             decode_noise_scale=decode_noise_scale,
-             num_inference_steps=steps,
              generator=g
          )
-     frames = result.frames[0]

      tmpdir = tempfile.mkdtemp()
      save_path = os.path.join(tmpdir, "output.mp4")
-     export_to_video(frames, save_path, fps=fps)
-     return save_path, f"FP8: {'ON' if FP8_OK else 'OFF'} | Offloading: {'ON' if OFFLOAD_OK else 'OFF'}"

- # --------- Gradio UI ---------
  with gr.Blocks(title="LTX-Video Gradio") as demo:
-     gr.Markdown("## 🎬 LTX-Video Gradio Demo")

      with gr.Row():
-         prompt_in = gr.Textbox(label="Prompt", lines=6, value="A cinematic close-up of a smiling woman under warm sunset light.")
-         neg_in = gr.Textbox(label="Negative Prompt", lines=4, value="worst quality, inconsistent motion, blurry, jittery, distorted")

      with gr.Row():
-         width_in = gr.Slider(256, 1024, step=8, value=768, label="Width")
-         height_in = gr.Slider(256, 1024, step=8, value=512, label="Height")

      with gr.Row():
-         frames_in = gr.Slider(17, 241, step=2, value=65, label="Frames (num_frames)")
-         fps_in = gr.Slider(8, 30, step=1, value=24, label="FPS")

      with gr.Row():
-         dt_in = gr.Slider(0.0, 0.2, step=0.001, value=0.03, label="decode_timestep")
-         dns_in = gr.Slider(0.0, 0.2, step=0.001, value=0.025, label="decode_noise_scale")
-         steps_in = gr.Slider(10, 75, step=1, value=40, label="Inference Steps")
-         seed_in = gr.Number(value=-1, label="Seed (>=0 to fix)")

-     btn = gr.Button("🎥 Generate Video", variant="primary")
      video_out = gr.Video(label="Output", autoplay=True)
-     info_out = gr.Markdown()

-     btn.click(fn=generate,
-               inputs=[prompt_in, neg_in, width_in, height_in,
-                       frames_in, fps_in, dt_in, dns_in, steps_in, seed_in],
-               outputs=[video_out, info_out])

  demo.queue().launch()
 
@@ -1,39 +1,60 @@
+ import os, tempfile
+ import numpy as np
+ import torch
  import gradio as gr
+
  from diffusers import LTXPipeline, AutoModel
  from diffusers.hooks import apply_group_offloading
  from diffusers.utils import export_to_video

+ # -------------------------------------------------------------------
+ # Environment dependencies (for reference):
+ # pip install -U torch torchvision accelerate transformers diffusers safetensors sentencepiece gradio imageio imageio-ffmpeg
+ # (On Spaces/Docker the ffmpeg binary may also be needed: apt-get update && apt-get install -y ffmpeg)
+ # -------------------------------------------------------------------
+
+ def load_pipeline(device: str = "cuda"):
+     """
+     Load the LTX-Video pipeline:
+     - requires sentencepiece (T5 tokenizer)
+     - trust_remote_code=True (avoids the Placeholder issue)
+     - bf16 / FP8 / offloading are enabled only when supported
+     """
+     use_cuda = torch.cuda.is_available()
+     device = "cuda" if use_cuda else "cpu"
+     dtype = torch.bfloat16 if use_cuda else torch.float16  # bf16 only makes sense on CUDA
+
+     # 1) Load the transformer
      transformer = AutoModel.from_pretrained(
          "Lightricks/LTX-Video",
          subfolder="transformer",
          torch_dtype=dtype,
+         trust_remote_code=True,
+         variant="bf16" if dtype == torch.bfloat16 else None
      )

+     # 2) FP8 layerwise casting (only where available)
+     fp8_ok = False
      try:
          transformer.enable_layerwise_casting(
+             storage_dtype=torch.float8_e4m3fn,
+             compute_dtype=dtype
          )
+         fp8_ok = True
      except Exception:
+         fp8_ok = False  # silently skip if the environment does not support it

+     # 3) Load the pipeline
      pipe = LTXPipeline.from_pretrained(
          "Lightricks/LTX-Video",
          transformer=transformer,
          torch_dtype=dtype,
          trust_remote_code=True,
+         variant="bf16" if dtype == torch.bfloat16 else None
      ).to(device)

+     # 4) Group offloading (only where available)
+     offload_ok = False
      try:
          onload_device = torch.device(device)
          offload_device = torch.device("cpu")
 
@@ -43,77 +64,146 @@ def load_pipeline(device="cuda"):
              offload_type="leaf_level",
              use_stream=True
          )
+         apply_group_offloading(
+             pipe.text_encoder,
+             onload_device=onload_device,
+             offload_type="block_level",
+             num_blocks_per_group=2
+         )
+         apply_group_offloading(
+             pipe.vae,
+             onload_device=onload_device,
+             offload_type="leaf_level"
+         )
+         offload_ok = True
      except Exception:
+         offload_ok = False

+     return pipe, fp8_ok, offload_ok, device


+ PIPE, FP8_OK, OFFLOAD_OK, DEVICE = load_pipeline()

+ def _to_uint8_frames(frames):
+     """
+     Safely convert (T, H, W, C) float / torch frames to a uint8 numpy array.
+     """
+     import numpy as np
+
+     if isinstance(frames, torch.Tensor):
+         frames = frames.detach().to("cpu").numpy()
+
+     if frames.ndim == 3:
+         # (T, H, W) -> (T, H, W, 1)
+         frames = frames[..., None]
+
+     assert frames.ndim == 4, f"Unexpected frames shape: {frames.shape}"
+
+     if frames.dtype != np.uint8:
+         # Scale depending on whether values are in the 0-1 or 0-255 range
+         mx = frames.max()
+         if mx <= 1.0:
+             frames = (np.clip(frames, 0, 1) * 255).astype(np.uint8)
+         else:
+             frames = np.clip(frames, 0, 255).astype(np.uint8)

+     return frames
+
+
+ def generate_video(
+     prompt, negative_prompt,
+     width, height, num_frames, fps,
+     decode_timestep, decode_noise_scale,
+     steps, seed
+ ):
+     # Seed
+     g = None
+     if seed is not None:
+         try:
+             s = int(seed)
+             if s >= 0:
+                 g = torch.Generator(device=DEVICE).manual_seed(s)
+         except Exception:
+             pass
+
+     # Inference
      with torch.inference_mode():
+         out = PIPE(
+             prompt=(prompt or "").strip(),
+             negative_prompt=(negative_prompt or "").strip() or None,
+             width=int(width),
+             height=int(height),
+             num_frames=int(num_frames),
+             fps=int(fps),
+             decode_timestep=float(decode_timestep),
+             decode_noise_scale=float(decode_noise_scale),
+             num_inference_steps=int(steps),
              generator=g
          )
+     frames = out.frames[0]  # expected: (T, H, W, C), float / torch

+     # Convert the frames to a safe format
+     frames = _to_uint8_frames(frames)
+
+     # Output path
      tmpdir = tempfile.mkdtemp()
      save_path = os.path.join(tmpdir, "output.mp4")

+     # First choice: the diffusers built-in saver
+     try:
+         export_to_video(frames, save_path, fps=int(fps))
+     except Exception:
+         # Fallback: imageio-ffmpeg
+         import imageio.v3 as iio
+         iio.imwrite(save_path, frames, fps=int(fps), codec="libx264")
+
+     info = (
+         f"FP8: {'ON' if FP8_OK else 'OFF'} | "
+         f"Offloading: {'ON' if OFFLOAD_OK else 'OFF'} | "
+         f"Device: {DEVICE} | "
+         f"Frames: {frames.shape} | FPS: {int(fps)}"
+     )
+     return save_path, info
+
+
+ # --------------------------- Gradio UI ---------------------------
  with gr.Blocks(title="LTX-Video Gradio") as demo:
+     gr.Markdown("## 🎬 LTX-Video — Prompt to Short Video")

      with gr.Row():
+         prompt_in = gr.Textbox(
+             label="Prompt",
+             lines=6,
+             value="A cinematic close-up of a smiling woman under warm sunset light."
+         )
+         neg_in = gr.Textbox(
+             label="Negative Prompt",
+             lines=4,
+             value="worst quality, inconsistent motion, blurry, jittery, distorted"
+         )

      with gr.Row():
+         width_in = gr.Slider(256, 1024, value=768, step=8, label="Width")
+         height_in = gr.Slider(256, 1024, value=512, step=8, label="Height")

      with gr.Row():
+         frames_in = gr.Slider(17, 241, value=65, step=2, label="num_frames")
+         fps_in = gr.Slider(8, 30, value=24, step=1, label="FPS")

      with gr.Row():
+         dt_in = gr.Slider(0.0, 0.2, value=0.03, step=0.001, label="decode_timestep")
+         dns_in = gr.Slider(0.0, 0.2, value=0.025, step=0.001, label="decode_noise_scale")
+         steps_in = gr.Slider(10, 75, value=40, step=1, label="num_inference_steps")
+         seed_in = gr.Number(value=-1, label="Seed (>=0 to fix)")

+     gen_btn = gr.Button("🎥 Generate", variant="primary")
      video_out = gr.Video(label="Output", autoplay=True)
+     info_out = gr.Markdown()

+     gen_btn.click(
+         fn=generate_video,
+         inputs=[prompt_in, neg_in, width_in, height_in, frames_in, fps_in, dt_in, dns_in, steps_in, seed_in],
+         outputs=[video_out, info_out]
+     )

  demo.queue().launch()
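
A note on the new `_to_uint8_frames` helper: it normalizes whatever the pipeline returns (a torch tensor or numpy array, valued in 0..1 or 0..255, with or without a channel axis) into a uint8 (T, H, W, C) array before the video is written. Below is a minimal standalone sketch of the same conversion, runnable without the LTX-Video weights; the random input is purely illustrative and not part of the commit.

```python
import numpy as np
import torch

def to_uint8_frames(frames):
    # Sketch mirroring the commit's _to_uint8_frames: accept torch tensors or numpy arrays.
    if isinstance(frames, torch.Tensor):
        frames = frames.detach().cpu().numpy()
    if frames.ndim == 3:
        frames = frames[..., None]  # (T, H, W) -> (T, H, W, 1)
    assert frames.ndim == 4, f"Unexpected frames shape: {frames.shape}"
    if frames.dtype != np.uint8:
        # Values in 0..1 are rescaled to 0..255; anything else is just clipped.
        if frames.max() <= 1.0:
            frames = (np.clip(frames, 0, 1) * 255).astype(np.uint8)
        else:
            frames = np.clip(frames, 0, 255).astype(np.uint8)
    return frames

# Illustrative check with 8 random 64x64 RGB frames in 0..1:
dummy = torch.rand(8, 64, 64, 3)
out = to_uint8_frames(dummy)
print(out.dtype, out.shape)  # uint8 (8, 64, 64, 3)
```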
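The seed handling also changed: instead of assuming the Gradio Number input is a valid non-negative integer, `generate_video` now coerces it and only builds a seeded `torch.Generator` when the value is >= 0. A sketch of that pattern in isolation (the helper name and the CPU default are illustrative, not part of the commit):

```python
import torch

def make_generator(seed, device="cpu"):
    """Return a seeded torch.Generator for seed >= 0, otherwise None (random run)."""
    try:
        s = int(seed)
    except (TypeError, ValueError):
        return None
    if s < 0:
        return None
    return torch.Generator(device=device).manual_seed(s)

print(make_generator(-1))                 # None -> non-deterministic generation
print(make_generator(42).initial_seed())  # 42   -> reproducible generation
```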