multimodalart HF Staff committed on
Commit
1b86783
·
verified ·
1 Parent(s): d7a915c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -101
app.py CHANGED
@@ -24,6 +24,7 @@ import random
24
  import argparse
25
  import hashlib
26
  import urllib.request
 
27
  from PIL import Image
28
  import spaces
29
  import numpy as np
@@ -31,27 +32,23 @@ import torch
31
  import gradio as gr
32
  from omegaconf import OmegaConf
33
  from tqdm import tqdm
34
- import imageio # Added for final video rendering
35
-
36
- # FastRTC imports
37
- from fastrtc import WebRTC, get_cloudflare_turn_credentials
38
- from fastrtc.utils import AdditionalOutputs #, CloseStream
39
 
40
  # Original project imports
41
  from pipeline import CausalInferencePipeline
42
  from demo_utils.constant import ZERO_VAE_CACHE
43
  from demo_utils.vae_block3 import VAEDecoderWrapper
44
  from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder
45
- # from demo_utils.memory import gpu, get_cuda_free_memory_gb, DynamicSwapInstaller
46
 
47
  # --- Argument Parsing ---
48
- parser = argparse.ArgumentParser(description="Gradio Demo for Self-Forcing with FastRTC")
49
  parser.add_argument('--port', type=int, default=7860, help="Port to run the Gradio app on.")
50
  parser.add_argument('--host', type=str, default='0.0.0.0', help="Host to bind the Gradio app to.")
51
  parser.add_argument("--checkpoint_path", type=str, default='./checkpoints/self_forcing_dmd.pt', help="Path to the model checkpoint.")
52
  parser.add_argument("--config_path", type=str, default='./configs/self_forcing_dmd.yaml', help="Path to the model config.")
53
  parser.add_argument('--share', action='store_true', help="Create a public Gradio link.")
54
  parser.add_argument('--trt', action='store_true', help="Use TensorRT optimized VAE decoder.")
 
55
  args = parser.parse_args()
56
 
57
  gpu = "cuda"
@@ -146,24 +143,22 @@ pipeline = CausalInferencePipeline(
146
 
147
  pipeline.to(dtype=torch.float16).to(gpu)
148
 
149
- # --- Additional Outputs Handler ---
150
- def handle_additional_outputs(status_html_update, video_update, webrtc_output):
151
- return status_html_update, video_update, webrtc_output
152
-
153
- # --- FastRTC Video Generation Handler ---
154
  @torch.no_grad()
155
  @spaces.GPU
156
- def video_generation_handler(prompt, seed, progress=gr.Progress()):
157
  """
158
- Generator function that yields BGR NumPy frames for real-time streaming.
159
- Returns cleanly when done - no infinite loops.
160
  """
161
-
162
  if seed == -1:
163
  seed = random.randint(0, 2**32 - 1)
164
 
165
  print(f"🎬 Starting video generation with prompt: '{prompt}' and seed: {seed}")
166
-
 
 
 
167
  print("🔤 Encoding text prompt...")
168
  conditional_dict = text_encoder(text_prompts=[prompt])
169
  for key, value in conditional_dict.items():
@@ -184,7 +179,7 @@ def video_generation_handler(prompt, seed, progress=gr.Progress()):
184
  all_num_frames = [pipeline.num_frame_per_block] * num_blocks
185
 
186
  total_frames_yielded = 0
187
- all_frames_for_video = [] # To collect frames for final video
188
 
189
  for idx, current_num_frames in enumerate(all_num_frames):
190
  print(f"📦 Processing block {idx+1}/{num_blocks} with {current_num_frames} frames")
@@ -235,7 +230,7 @@ def video_generation_handler(prompt, seed, progress=gr.Progress()):
235
 
236
  print(f"📹 Decoded pixels shape: {pixels.shape}")
237
 
238
- # Yield individual frames WITH status updates
239
  for frame_idx in range(pixels.shape[1]):
240
  frame_tensor = pixels[0, frame_idx] # Get single frame [C, H, W]
241
 
@@ -243,73 +238,47 @@ def video_generation_handler(prompt, seed, progress=gr.Progress()):
243
  frame_np = torch.clamp(frame_tensor.float(), -1., 1.) * 127.5 + 127.5
244
  frame_np = frame_np.to(torch.uint8).cpu().numpy()
245
 
246
- # Convert from CHW to HWC format
247
  frame_np = np.transpose(frame_np, (1, 2, 0)) # CHW -> HWC
248
 
249
  all_frames_for_video.append(frame_np)
250
-
251
- # Convert RGB to BGR for FastRTC (OpenCV format)
252
- frame_bgr = frame_np[:, :, ::-1] # RGB -> BGR
253
-
254
  total_frames_yielded += 1
255
- print(f"📺 Yielding frame {total_frames_yielded}: shape {frame_bgr.shape}, dtype {frame_bgr.dtype}")
256
 
257
  # Calculate progress
258
  total_expected_frames = num_blocks * pipeline.num_frame_per_block
259
  current_frame_count = (idx * pipeline.num_frame_per_block) + frame_idx + 1
260
- frame_progress = 100 * (current_frame_count / total_expected_frames)
261
-
262
- # --- REVISED HTML START ---
263
- if frame_idx == pixels.shape[1] - 1 and idx + 1 == num_blocks: # last frame
264
- status_html = (
265
- f"<div style='padding: 16px; border: 1px solid #198754; background-color: #d1e7dd; border-radius: 8px; font-family: sans-serif; text-align: center;'>"
266
- f" <h4 style='margin: 0 0 8px 0; color: #0f5132; font-size: 18px;'>🎉 Generation Complete!</h4>"
267
- f" <p style='margin: 0; color: #0f5132;'>"
268
- f" Total frames: {total_frames_yielded}. The final video is now available."
269
- f" </p>"
270
- f"</div>"
271
- )
272
-
273
- print("💾 Saving final rendered video...")
274
- video_update = gr.update() # Default to no-op
275
- try:
276
- video_path = f"gradio_tmp/{seed}_{hashlib.md5(prompt.encode()).hexdigest()}.mp4"
277
- imageio.mimwrite(video_path, all_frames_for_video, fps=15, quality=8)
278
- print(f"✅ Video saved to {video_path}")
279
- video_update = gr.update(value=video_path, visible=True)
280
- except Exception as e:
281
- print(f"⚠️ Could not save final video: {e}")
282
-
283
- yield frame_bgr, AdditionalOutputs(status_html, video_update, gr.update(visible=False))
284
- # yield CloseStream("🎉 Video generation completed successfully!")
285
- return
286
- else: # Regular frames - simpler status
287
- status_html = (
288
- f"<div style='padding: 10px; border: 1px solid #ddd; border-radius: 8px; font-family: sans-serif;'>"
289
- f" <p style='margin: 0 0 8px 0; font-size: 16px; font-weight: bold;'>Generating Video...</p>"
290
- f" <div style='background: #e9ecef; border-radius: 4px; width: 100%; overflow: hidden;'>"
291
- f" <div style='width: {frame_progress:.1f}%; height: 20px; background-color: #0d6efd; transition: width 0.2s;'></div>"
292
- f" </div>"
293
- f" <p style='margin: 8px 0 0 0; color: #555; font-size: 14px; text-align: right;'>"
294
- f" Block {idx+1}/{num_blocks}   |   Frame {total_frames_yielded}   |   {frame_progress:.1f}%"
295
- f" </p>"
296
- f"</div>"
297
- )
298
- # --- REVISED HTML END ---
299
-
300
- yield frame_bgr, AdditionalOutputs(status_html, gr.update(visible=False), gr.update(visible=True))
301
 
302
  current_start_frame += current_num_frames
303
 
304
  print(f"✅ Video generation completed! Total frames yielded: {total_frames_yielded}")
305
 
306
- # Signal completion
307
- # yield CloseStream("🎉 Video generation completed successfully!")
 
 
 
 
 
 
 
308
 
309
  # --- Gradio UI Layout ---
310
- with gr.Blocks(theme=gr.themes.Soft(), title="Self-Forcing FastRTC Demo") as demo:
311
- gr.Markdown("# 🚀 Self-Forcing Video Generation with FastRTC Streaming")
312
- gr.Markdown("*Real-time video generation streaming via WebRTC*")
313
 
314
  with gr.Row():
315
  with gr.Column(scale=2):
@@ -332,47 +301,42 @@ with gr.Blocks(theme=gr.themes.Soft(), title="Self-Forcing FastRTC Demo") as dem
332
 
333
  with gr.Row():
334
  seed = gr.Number(label="Seed", value=-1, info="Use -1 for a random seed.")
335
-
336
- with gr.Accordion("⚙️ Performance Options", open=False):
337
- gr.Markdown("*These optimizations are applied once per session*")
 
 
 
 
 
338
 
339
  start_btn = gr.Button("🎬 Start Generation", variant="primary", size="lg")
340
 
341
  with gr.Column(scale=3):
342
- gr.Markdown("### 📺 Live Video Stream")
343
- gr.Markdown("*Click 'Start Generation' to begin streaming*")
344
 
345
- webrtc_output = WebRTC(
346
- label="Generated Video Stream",
347
- modality="video",
348
- mode="receive", # Server sends video to client
349
  height=480,
350
  width=832,
351
- rtc_configuration=get_cloudflare_turn_credentials(),
352
- elem_id="video_stream"
353
  )
354
-
355
- final_video = gr.Video(label="Final Rendered Video", visible=False, interactive=False)
356
 
357
- status_html = gr.HTML(
358
- value="<div style='text-align: center; padding: 20px; color: #666;'>Ready to start generation...</div>",
359
- label="Generation Status"
 
 
360
  )
361
 
362
-
363
-
364
- # Connect the generator to the WebRTC stream
365
- webrtc_output.stream(
366
- fn=video_generation_handler,
367
- inputs=[prompt, seed],
368
- outputs=[webrtc_output],
369
- time_limit=300, # 5 minutes max
370
- trigger=start_btn.click,
371
- )
372
- # MODIFIED: Handle additional outputs (status updates AND final video)
373
- webrtc_output.on_additional_outputs(
374
- fn=handle_additional_outputs,
375
- outputs=[status_html, final_video, webrtc_output]
376
  )
377
 
378
  # --- Launch App ---
 
24
  import argparse
25
  import hashlib
26
  import urllib.request
27
+ import time
28
  from PIL import Image
29
  import spaces
30
  import numpy as np
 
32
  import gradio as gr
33
  from omegaconf import OmegaConf
34
  from tqdm import tqdm
35
+ import imageio
 
 
 
 
36
 
37
  # Original project imports
38
  from pipeline import CausalInferencePipeline
39
  from demo_utils.constant import ZERO_VAE_CACHE
40
  from demo_utils.vae_block3 import VAEDecoderWrapper
41
  from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder
 
42
 
43
  # --- Argument Parsing ---
44
+ parser = argparse.ArgumentParser(description="Gradio Demo for Self-Forcing with Frame Streaming")
45
  parser.add_argument('--port', type=int, default=7860, help="Port to run the Gradio app on.")
46
  parser.add_argument('--host', type=str, default='0.0.0.0', help="Host to bind the Gradio app to.")
47
  parser.add_argument("--checkpoint_path", type=str, default='./checkpoints/self_forcing_dmd.pt', help="Path to the model checkpoint.")
48
  parser.add_argument("--config_path", type=str, default='./configs/self_forcing_dmd.yaml', help="Path to the model config.")
49
  parser.add_argument('--share', action='store_true', help="Create a public Gradio link.")
50
  parser.add_argument('--trt', action='store_true', help="Use TensorRT optimized VAE decoder.")
51
+ parser.add_argument('--fps', type=float, default=15.0, help="Playback FPS for frame streaming.")
52
  args = parser.parse_args()
53
 
54
  gpu = "cuda"
 
143
 
144
  pipeline.to(dtype=torch.float16).to(gpu)
145
 
146
+ # --- Frame Streaming Video Generation Handler ---
 
 
 
 
147
  @torch.no_grad()
148
  @spaces.GPU
149
+ def video_generation_handler(prompt, seed, fps, progress=gr.Progress()):
150
  """
151
+ Generator function that yields RGB frames for display in gr.Image.
152
+ Includes timing delays for smooth playback.
153
  """
 
154
  if seed == -1:
155
  seed = random.randint(0, 2**32 - 1)
156
 
157
  print(f"🎬 Starting video generation with prompt: '{prompt}' and seed: {seed}")
158
+
159
+ # Calculate frame delay based on FPS
160
+ frame_delay = 1.0 / fps if fps > 0 else 1.0 / 15.0
161
+
162
  print("🔤 Encoding text prompt...")
163
  conditional_dict = text_encoder(text_prompts=[prompt])
164
  for key, value in conditional_dict.items():
 
179
  all_num_frames = [pipeline.num_frame_per_block] * num_blocks
180
 
181
  total_frames_yielded = 0
182
+ all_frames_for_video = []
183
 
184
  for idx, current_num_frames in enumerate(all_num_frames):
185
  print(f"📦 Processing block {idx+1}/{num_blocks} with {current_num_frames} frames")
 
230
 
231
  print(f"📹 Decoded pixels shape: {pixels.shape}")
232
 
233
+ # Yield individual frames with timing delays
234
  for frame_idx in range(pixels.shape[1]):
235
  frame_tensor = pixels[0, frame_idx] # Get single frame [C, H, W]
236
 
 
238
  frame_np = torch.clamp(frame_tensor.float(), -1., 1.) * 127.5 + 127.5
239
  frame_np = frame_np.to(torch.uint8).cpu().numpy()
240
 
241
+ # Convert from CHW to HWC format (RGB)
242
  frame_np = np.transpose(frame_np, (1, 2, 0)) # CHW -> HWC
243
 
244
  all_frames_for_video.append(frame_np)
 
 
 
 
245
  total_frames_yielded += 1
 
246
 
247
  # Calculate progress
248
  total_expected_frames = num_blocks * pipeline.num_frame_per_block
249
  current_frame_count = (idx * pipeline.num_frame_per_block) + frame_idx + 1
250
+ frame_progress = current_frame_count / total_expected_frames
251
+
252
+ # Update progress
253
+ progress(frame_progress, desc=f"Frame {total_frames_yielded} | Block {idx+1}/{num_blocks}")
254
+
255
+ print(f"📺 Yielding frame {total_frames_yielded}: shape {frame_np.shape}")
256
+
257
+ # Yield frame with timing delay
258
+ yield gr.update(value=frame_np, visible=True), gr.update(visible=False)
259
+
260
+ # Sleep between frames for smooth playback (except for the last frame)
261
+ if not (frame_idx == pixels.shape[1] - 1 and idx + 1 == num_blocks):
262
+ time.sleep(frame_delay)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
263
 
264
  current_start_frame += current_num_frames
265
 
266
  print(f"✅ Video generation completed! Total frames yielded: {total_frames_yielded}")
267
 
268
+ # Save final video
269
+ try:
270
+ video_path = f"gradio_tmp/{seed}_{hashlib.md5(prompt.encode()).hexdigest()}.mp4"
271
+ imageio.mimwrite(video_path, all_frames_for_video, fps=fps, quality=8)
272
+ print(f"✅ Video saved to {video_path}")
273
+ yield gr.update(visible=False), gr.update(value=video_path, visible=True)
274
+ except Exception as e:
275
+ print(f"⚠️ Could not save final video: {e}")
276
+ return None, None
277
 
278
  # --- Gradio UI Layout ---
279
+ with gr.Blocks(theme=gr.themes.Soft(), title="Self-Forcing Frame Streaming Demo") as demo:
280
+ gr.Markdown("# 🚀 Self-Forcing Video Generation with Frame Streaming")
281
+ gr.Markdown("*Real-time video generation with frame-by-frame display*")
282
 
283
  with gr.Row():
284
  with gr.Column(scale=2):
 
301
 
302
  with gr.Row():
303
  seed = gr.Number(label="Seed", value=-1, info="Use -1 for a random seed.")
304
+ fps = gr.Slider(
305
+ label="Playback FPS",
306
+ minimum=1,
307
+ maximum=30,
308
+ value=args.fps,
309
+ step=1,
310
+ info="Frames per second for playback"
311
+ )
312
 
313
  start_btn = gr.Button("🎬 Start Generation", variant="primary", size="lg")
314
 
315
  with gr.Column(scale=3):
316
+ gr.Markdown("### 📺 Live Frame Stream")
317
+ gr.Markdown("*Click 'Start Generation' to begin frame streaming*")
318
 
319
+ frame_display = gr.Image(
320
+ label="Generated Frames",
 
 
321
  height=480,
322
  width=832,
323
+ show_label=True,
324
+ container=True
325
  )
 
 
326
 
327
+ final_video = gr.Video(
328
+ label="Final Rendered Video",
329
+ visible=True,
330
+ interactive=False,
331
+ height=400
332
  )
333
 
334
+ # Connect the generator to the image display
335
+ start_btn.click(
336
+ fn=video_generation_handler,
337
+ inputs=[prompt, seed, fps],
338
+ outputs=[frame_display, final_video],
339
+ show_progress="full"
 
 
 
 
 
 
 
 
340
  )
341
 
342
  # --- Launch App ---