tsi-org commited on
Commit
94ff503
Β·
verified Β·
1 Parent(s): 607da1d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +116 -135
app.py CHANGED
@@ -141,26 +141,12 @@ transformer.eval().to(dtype=torch.float16).requires_grad_(False)
141
  text_encoder.to(gpu)
142
  transformer.to(gpu)
143
 
144
- APP_STATE = {
145
- "torch_compile_applied": False,
146
- "fp8_applied": False,
147
- "current_use_taehv": False,
148
- "current_vae_decoder": None,
149
- "last_generated_frames": [], # Store frames for download
150
- "last_generation_info": {} # Store metadata
151
- }
152
 
153
  def frames_to_ts_file(frames, filepath, fps = 15):
154
  """
155
  Convert frames directly to .ts file using PyAV.
156
-
157
- Args:
158
- frames: List of numpy arrays (HWC, RGB, uint8)
159
- filepath: Output file path
160
- fps: Frames per second
161
-
162
- Returns:
163
- The filepath of the created file
164
  """
165
  if not frames:
166
  return filepath
@@ -200,83 +186,52 @@ def frames_to_ts_file(frames, filepath, fps = 15):
200
 
201
  return filepath
202
 
203
- def frames_to_mp4_file(frames, filepath, fps=15):
204
  """
205
- Convert frames to MP4 file using PyAV for download.
206
-
207
- Args:
208
- frames: List of numpy arrays (HWC, RGB, uint8)
209
- filepath: Output file path
210
- fps: Frames per second
211
-
212
- Returns:
213
- The filepath of the created file
214
  """
215
- if not frames:
216
- return filepath
217
 
218
- height, width = frames[0].shape[:2]
 
 
 
 
 
 
219
 
220
- # Create container for MP4 format
221
- container = av.open(filepath, mode='w', format='mp4')
 
 
 
 
222
 
223
- # Add video stream with high quality settings for download
224
- stream = container.add_stream('h264', rate=fps)
225
- stream.width = width
226
- stream.height = height
227
- stream.pix_fmt = 'yuv420p'
228
 
229
- # High quality settings for download
230
- stream.options = {
231
- 'preset': 'medium',
232
- 'crf': '18', # Higher quality
233
- 'profile': 'high',
234
- 'level': '4.0'
235
- }
236
 
237
- try:
238
- for frame_np in frames:
239
- frame = av.VideoFrame.from_ndarray(frame_np, format='rgb24')
240
- frame = frame.reformat(format=stream.pix_fmt)
241
- for packet in stream.encode(frame):
242
- container.mux(packet)
243
-
244
- for packet in stream.encode():
245
- container.mux(packet)
246
-
247
- finally:
248
- container.close()
249
-
250
- return filepath
251
 
252
- def create_download_video():
253
  """
254
- Create a downloadable MP4 file from the last generated frames.
255
  """
256
- if not APP_STATE["last_generated_frames"]:
257
  return None
258
 
259
  try:
260
- # Create downloads directory if it doesn't exist
261
- os.makedirs("downloads", exist_ok=True)
 
 
262
 
263
- # Generate filename with timestamp and prompt hash
264
- timestamp = int(time.time())
265
- prompt_hash = hashlib.md5(APP_STATE["last_generation_info"].get("prompt", "").encode()).hexdigest()[:8]
266
- filename = f"pixio_video_{timestamp}_{prompt_hash}.mp4"
267
- filepath = os.path.join("downloads", filename)
268
-
269
- # Create MP4 file
270
- fps = APP_STATE["last_generation_info"].get("fps", 15)
271
- frames_to_mp4_file(APP_STATE["last_generated_frames"], filepath, fps)
272
-
273
- print(f"βœ… Download video created: {filepath}")
274
  return filepath
275
 
276
  except Exception as e:
277
- print(f"❌ Error creating download video: {e}")
278
- import traceback
279
- traceback.print_exc()
280
  return None
281
 
282
  def initialize_vae_decoder(use_taehv=False, use_trt=False):
@@ -326,6 +281,13 @@ def initialize_vae_decoder(use_taehv=False, use_trt=False):
326
  APP_STATE["current_vae_decoder"] = vae_decoder
327
  print(f"βœ… VAE decoder initialized: {'TAEHV' if use_taehv else 'Default VAE'}")
328
 
 
 
 
 
 
 
 
329
  # Initialize with default VAE
330
  initialize_vae_decoder(use_taehv=False, use_trt=args.trt)
331
 
@@ -340,17 +302,14 @@ pipeline.to(dtype=torch.float16).to(gpu)
340
  @spaces.GPU
341
  def video_generation_handler_streaming(prompt, seed=42, fps=15):
342
  """
343
- Generator function that yields .ts video chunks using PyAV for streaming.
344
- Now optimized for block-based processing and stores frames for download.
345
  """
 
 
346
  if seed == -1:
347
  seed = random.randint(0, 2**32 - 1)
348
 
349
- print(f"🎬 Starting PyAV streaming: '{prompt}', seed: {seed}")
350
-
351
- # Clear previous generation data
352
- APP_STATE["last_generated_frames"] = []
353
- APP_STATE["last_generation_info"] = {"prompt": prompt, "seed": seed, "fps": fps}
354
 
355
  # Setup
356
  conditional_dict = text_encoder(text_prompts=[prompt])
@@ -371,9 +330,14 @@ def video_generation_handler_streaming(prompt, seed=42, fps=15):
371
  all_num_frames = [pipeline.num_frame_per_block] * num_blocks
372
 
373
  total_frames_yielded = 0
 
 
374
 
375
- # Ensure temp directory exists
376
- os.makedirs("gradio_tmp", exist_ok=True)
 
 
 
377
 
378
  # Generation loop
379
  for idx, current_num_frames in enumerate(all_num_frames):
@@ -424,10 +388,8 @@ def video_generation_handler_streaming(prompt, seed=42, fps=15):
424
  elif APP_STATE["current_use_taehv"] and idx > 0:
425
  pixels = pixels[:, 12:]
426
 
427
- print(f"πŸ” DEBUG Block {idx}: Pixels shape after skipping: {pixels.shape}")
428
-
429
- # Process all frames from this block at once
430
- all_frames_from_block = []
431
  for frame_idx in range(pixels.shape[1]):
432
  frame_tensor = pixels[0, frame_idx]
433
 
@@ -436,17 +398,14 @@ def video_generation_handler_streaming(prompt, seed=42, fps=15):
436
  frame_np = frame_np.to(torch.uint8).cpu().numpy()
437
  frame_np = np.transpose(frame_np, (1, 2, 0)) # CHW -> HWC
438
 
439
- all_frames_from_block.append(frame_np)
440
- # Store frame for download
441
- APP_STATE["last_generated_frames"].append(frame_np)
442
  total_frames_yielded += 1
443
 
444
- # Yield status update for each frame (cute tracking!)
445
  blocks_completed = idx
446
  current_block_progress = (frame_idx + 1) / pixels.shape[1]
447
  total_progress = (blocks_completed + current_block_progress) / num_blocks * 100
448
-
449
- # Cap at 100% to avoid going over
450
  total_progress = min(total_progress, 100.0)
451
 
452
  frame_status_html = (
@@ -461,34 +420,51 @@ def video_generation_handler_streaming(prompt, seed=42, fps=15):
461
  f"</div>"
462
  )
463
 
464
- # Yield None for video but update status (frame-by-frame tracking)
465
- yield None, frame_status_html, gr.update(visible=False)
466
 
467
- # Encode entire block as one chunk immediately
468
- if all_frames_from_block:
469
- print(f"πŸ“Ή Encoding block {idx} with {len(all_frames_from_block)} frames")
470
-
471
  try:
472
- chunk_uuid = str(uuid.uuid4())[:8]
473
- ts_filename = f"block_{idx:04d}_{chunk_uuid}.ts"
474
- ts_path = os.path.join("gradio_tmp", ts_filename)
475
 
476
- frames_to_ts_file(all_frames_from_block, ts_path, fps)
 
477
 
478
- # Calculate final progress for this block
479
- total_progress = (idx + 1) / num_blocks * 100
 
480
 
481
- # Yield the actual video chunk
482
- yield ts_path, gr.update(), gr.update(visible=False)
483
 
484
  except Exception as e:
485
- print(f"⚠️ Error encoding block {idx}: {e}")
486
- import traceback
487
- traceback.print_exc()
488
 
489
  current_start_frame += current_num_frames
490
 
491
- # Final completion status
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
492
  final_status_html = (
493
  f"<div style='padding: 16px; border: 1px solid #198754; background: linear-gradient(135deg, #d1e7dd, #f8f9fa); border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.1);'>"
494
  f" <div style='display: flex; align-items: center; margin-bottom: 8px;'>"
@@ -500,27 +476,33 @@ def video_generation_handler_streaming(prompt, seed=42, fps=15):
500
  f" πŸ“Š Generated {total_frames_yielded} frames across {num_blocks} blocks"
501
  f" </p>"
502
  f" <p style='margin: 4px 0 0 0; color: #0f5132; font-size: 14px;'>"
503
- f" 🎬 Playback: {fps} FPS β€’ πŸ“ Format: MPEG-TS/H.264"
504
  f" </p>"
505
  f" </div>"
506
  f"</div>"
507
  )
508
- yield None, final_status_html, gr.update(visible=True)
509
- print(f"βœ… PyAV streaming complete! {total_frames_yielded} frames across {num_blocks} blocks")
 
 
 
 
 
 
510
 
511
  # --- Gradio UI Layout ---
512
  with gr.Blocks(title="Self-Forcing Streaming Demo") as demo:
513
  gr.Markdown("# πŸš€ Pixio Streaming Video Generation")
514
- gr.Markdown("Real-time video generation with Pixio), [[Project page]](https://pixio.myapps.ai) )")
515
 
516
  with gr.Row():
517
  with gr.Column(scale=2):
518
  with gr.Group():
519
  prompt = gr.Textbox(
520
  label="Prompt",
521
- placeholder="A stylish woman walks down a Tokyo street...",
522
  lines=4,
523
- value=""
524
  )
525
  enhance_button = gr.Button("✨ Enhance Prompt", variant="secondary")
526
 
@@ -576,20 +558,20 @@ with gr.Blocks(title="Self-Forcing Streaming Demo") as demo:
576
  label="Generation Status"
577
  )
578
 
579
- # Download button (initially hidden)
580
- download_btn = gr.DownloadButton(
581
- label="πŸ“₯ Download MP4",
582
- value=create_download_video,
583
- variant="secondary",
584
- size="lg",
585
- visible=False
586
- )
587
 
588
- # Connect the generator to the streaming video
589
  start_btn.click(
590
  fn=video_generation_handler_streaming,
591
  inputs=[prompt, seed, fps],
592
- outputs=[streaming_video, status_display, download_btn]
593
  )
594
 
595
  enhance_button.click(
@@ -607,18 +589,17 @@ if __name__ == "__main__":
607
  os.makedirs("downloads", exist_ok=True)
608
 
609
  print("πŸš€ Starting Self-Forcing Streaming Demo")
610
- print(f"πŸ“ Temporary files will be stored in: gradio_tmp/")
611
- print(f"πŸ“₯ Download files will be stored in: downloads/")
612
- print(f"🎯 Chunk encoding: PyAV (MPEG-TS/H.264)")
613
- print(f"⚑ GPU acceleration: {gpu}")
614
 
615
  demo.queue().launch(
616
  server_name=args.host,
617
  server_port=args.port,
618
  share=args.share,
619
  show_error=True,
620
- max_threads=40,
621
- mcp_server=True
622
  )
623
 
624
  # import subprocess
 
141
  text_encoder.to(gpu)
142
  transformer.to(gpu)
143
 
144
+ # Global state for download
145
+ CURRENT_DOWNLOAD_PATH = None
 
 
 
 
 
 
146
 
147
  def frames_to_ts_file(frames, filepath, fps = 15):
148
  """
149
  Convert frames directly to .ts file using PyAV.
 
 
 
 
 
 
 
 
150
  """
151
  if not frames:
152
  return filepath
 
186
 
187
  return filepath
188
 
189
+ def create_hls_playlist(ts_files, playlist_path, fps=15):
190
  """
191
+ Create HLS playlist (.m3u8) file for streaming.
 
 
 
 
 
 
 
 
192
  """
193
+ segment_duration = 1.0 # Each segment duration in seconds
 
194
 
195
+ playlist_content = [
196
+ "#EXTM3U",
197
+ "#EXT-X-VERSION:3",
198
+ f"#EXT-X-TARGETDURATION:{int(segment_duration) + 1}",
199
+ "#EXT-X-MEDIA-SEQUENCE:0",
200
+ "#EXT-X-PLAYLIST-TYPE:VOD"
201
+ ]
202
 
203
+ for ts_file in ts_files:
204
+ ts_filename = os.path.basename(ts_file)
205
+ playlist_content.extend([
206
+ f"#EXTINF:{segment_duration:.1f},",
207
+ ts_filename
208
+ ])
209
 
210
+ playlist_content.append("#EXT-X-ENDLIST")
 
 
 
 
211
 
212
+ with open(playlist_path, 'w') as f:
213
+ f.write('\n'.join(playlist_content))
 
 
 
 
 
214
 
215
+ return playlist_path
 
 
 
 
 
 
 
 
 
 
 
 
 
216
 
217
+ def frames_to_mp4_file(frames, filepath, fps=15):
218
  """
219
+ Convert frames to MP4 file using imageio.
220
  """
221
+ if not frames:
222
  return None
223
 
224
  try:
225
+ # Use imageio for reliable MP4 creation
226
+ with imageio.get_writer(filepath, fps=fps, codec='libx264', quality=8) as writer:
227
+ for frame in frames:
228
+ writer.append_data(frame)
229
 
230
+ print(f"βœ… MP4 created successfully: {filepath}")
 
 
 
 
 
 
 
 
 
 
231
  return filepath
232
 
233
  except Exception as e:
234
+ print(f"❌ Error creating MP4: {e}")
 
 
235
  return None
236
 
237
  def initialize_vae_decoder(use_taehv=False, use_trt=False):
 
281
  APP_STATE["current_vae_decoder"] = vae_decoder
282
  print(f"βœ… VAE decoder initialized: {'TAEHV' if use_taehv else 'Default VAE'}")
283
 
284
+ APP_STATE = {
285
+ "torch_compile_applied": False,
286
+ "fp8_applied": False,
287
+ "current_use_taehv": False,
288
+ "current_vae_decoder": None,
289
+ }
290
+
291
  # Initialize with default VAE
292
  initialize_vae_decoder(use_taehv=False, use_trt=args.trt)
293
 
 
302
  @spaces.GPU
303
  def video_generation_handler_streaming(prompt, seed=42, fps=15):
304
  """
305
+ Generator function that creates HLS stream and final MP4.
 
306
  """
307
+ global CURRENT_DOWNLOAD_PATH
308
+
309
  if seed == -1:
310
  seed = random.randint(0, 2**32 - 1)
311
 
312
+ print(f"🎬 Starting HLS streaming: '{prompt}', seed: {seed}")
 
 
 
 
313
 
314
  # Setup
315
  conditional_dict = text_encoder(text_prompts=[prompt])
 
330
  all_num_frames = [pipeline.num_frame_per_block] * num_blocks
331
 
332
  total_frames_yielded = 0
333
+ all_frames_for_download = [] # Store frames for final MP4
334
+ ts_files = [] # Store TS files for HLS playlist
335
 
336
+ # Create unique session directory
337
+ session_id = str(uuid.uuid4())[:8]
338
+ session_dir = os.path.join("gradio_tmp", f"session_{session_id}")
339
+ os.makedirs(session_dir, exist_ok=True)
340
+ os.makedirs("downloads", exist_ok=True)
341
 
342
  # Generation loop
343
  for idx, current_num_frames in enumerate(all_num_frames):
 
388
  elif APP_STATE["current_use_taehv"] and idx > 0:
389
  pixels = pixels[:, 12:]
390
 
391
+ # Process frames from this block
392
+ block_frames = []
 
 
393
  for frame_idx in range(pixels.shape[1]):
394
  frame_tensor = pixels[0, frame_idx]
395
 
 
398
  frame_np = frame_np.to(torch.uint8).cpu().numpy()
399
  frame_np = np.transpose(frame_np, (1, 2, 0)) # CHW -> HWC
400
 
401
+ block_frames.append(frame_np)
402
+ all_frames_for_download.append(frame_np) # Store for final MP4
 
403
  total_frames_yielded += 1
404
 
405
+ # Progress tracking
406
  blocks_completed = idx
407
  current_block_progress = (frame_idx + 1) / pixels.shape[1]
408
  total_progress = (blocks_completed + current_block_progress) / num_blocks * 100
 
 
409
  total_progress = min(total_progress, 100.0)
410
 
411
  frame_status_html = (
 
420
  f"</div>"
421
  )
422
 
423
+ yield None, frame_status_html
 
424
 
425
+ # Create TS segment for this block
426
+ if block_frames:
 
 
427
  try:
428
+ ts_filename = f"segment_{idx:04d}.ts"
429
+ ts_path = os.path.join(session_dir, ts_filename)
 
430
 
431
+ frames_to_ts_file(block_frames, ts_path, fps)
432
+ ts_files.append(ts_path)
433
 
434
+ # Create/update HLS playlist
435
+ playlist_path = os.path.join(session_dir, "playlist.m3u8")
436
+ create_hls_playlist(ts_files, playlist_path, fps)
437
 
438
+ # Yield the HLS playlist for streaming
439
+ yield playlist_path, gr.update()
440
 
441
  except Exception as e:
442
+ print(f"⚠️ Error creating HLS segment {idx}: {e}")
 
 
443
 
444
  current_start_frame += current_num_frames
445
 
446
+ # Create final MP4 for download
447
+ print("🎬 Creating final MP4 for download...")
448
+ try:
449
+ timestamp = int(time.time())
450
+ prompt_hash = hashlib.md5(prompt.encode()).hexdigest()[:8]
451
+ mp4_filename = f"pixio_video_{timestamp}_{prompt_hash}.mp4"
452
+ mp4_path = os.path.join("downloads", mp4_filename)
453
+
454
+ final_mp4 = frames_to_mp4_file(all_frames_for_download, mp4_path, fps)
455
+ if final_mp4:
456
+ CURRENT_DOWNLOAD_PATH = final_mp4
457
+ print(f"βœ… Final MP4 created: {final_mp4}")
458
+ else:
459
+ print("❌ Failed to create final MP4")
460
+ CURRENT_DOWNLOAD_PATH = None
461
+
462
+ except Exception as e:
463
+ print(f"❌ Error creating final MP4: {e}")
464
+ CURRENT_DOWNLOAD_PATH = None
465
+
466
+ # Final completion status with download info
467
+ download_info = "πŸ“₯ Download ready!" if CURRENT_DOWNLOAD_PATH else "❌ Download failed"
468
  final_status_html = (
469
  f"<div style='padding: 16px; border: 1px solid #198754; background: linear-gradient(135deg, #d1e7dd, #f8f9fa); border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.1);'>"
470
  f" <div style='display: flex; align-items: center; margin-bottom: 8px;'>"
 
476
  f" πŸ“Š Generated {total_frames_yielded} frames across {num_blocks} blocks"
477
  f" </p>"
478
  f" <p style='margin: 4px 0 0 0; color: #0f5132; font-size: 14px;'>"
479
+ f" 🎬 Playback: {fps} FPS β€’ πŸ“ Format: HLS/H.264 β€’ {download_info}"
480
  f" </p>"
481
  f" </div>"
482
  f"</div>"
483
  )
484
+ yield None, final_status_html
485
+ print(f"βœ… HLS streaming complete! {total_frames_yielded} frames")
486
+
487
+ def download_video():
488
+ """Return the current download file path."""
489
+ if CURRENT_DOWNLOAD_PATH and os.path.exists(CURRENT_DOWNLOAD_PATH):
490
+ return CURRENT_DOWNLOAD_PATH
491
+ return None
492
 
493
  # --- Gradio UI Layout ---
494
  with gr.Blocks(title="Self-Forcing Streaming Demo") as demo:
495
  gr.Markdown("# πŸš€ Pixio Streaming Video Generation")
496
+ gr.Markdown("Real-time video generation with distilled Wan2-1.3B [[Model]](https://huggingface.co/Wan-AI/Wan2.1-T2V-1.3B), [[Project page]](https://pixio.myapps.ai), [[Paper]](https://arxiv.org/abs/2412.09738)")
497
 
498
  with gr.Row():
499
  with gr.Column(scale=2):
500
  with gr.Group():
501
  prompt = gr.Textbox(
502
  label="Prompt",
503
+ placeholder="A close-up shot of a ceramic teacup slowly pouring water into a glass mug.",
504
  lines=4,
505
+ value="A close-up shot of a ceramic teacup slowly pouring water into a glass mug."
506
  )
507
  enhance_button = gr.Button("✨ Enhance Prompt", variant="secondary")
508
 
 
558
  label="Generation Status"
559
  )
560
 
561
+ # Download button that appears after completion
562
+ with gr.Row():
563
+ download_btn = gr.DownloadButton(
564
+ label="πŸ“₯ Download MP4 Video",
565
+ value=download_video,
566
+ variant="secondary",
567
+ size="lg"
568
+ )
569
 
570
+ # Connect the streaming function
571
  start_btn.click(
572
  fn=video_generation_handler_streaming,
573
  inputs=[prompt, seed, fps],
574
+ outputs=[streaming_video, status_display]
575
  )
576
 
577
  enhance_button.click(
 
589
  os.makedirs("downloads", exist_ok=True)
590
 
591
  print("πŸš€ Starting Self-Forcing Streaming Demo")
592
+ print(f"πŸ“ Temporary files: gradio_tmp/")
593
+ print(f"πŸ“₯ Download files: downloads/")
594
+ print(f"🎯 Streaming: HLS (.m3u8 + .ts segments)")
595
+ print(f"πŸ“± Download: MP4 (imageio)")
596
 
597
  demo.queue().launch(
598
  server_name=args.host,
599
  server_port=args.port,
600
  share=args.share,
601
  show_error=True,
602
+ max_threads=40
 
603
  )
604
 
605
  # import subprocess