tsi-org commited on
Commit
a12d04c
Β·
verified Β·
1 Parent(s): 52505ce

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -19
app.py CHANGED
@@ -201,6 +201,35 @@ def frames_to_ts_file(frames, filepath, fps = 15):
201
 
202
  return filepath
203
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
204
  def create_mp4_download():
205
  """Create MP4 file from stored frames for download."""
206
  global DOWNLOAD_FRAMES
@@ -288,8 +317,7 @@ pipeline.to(dtype=torch.float16).to(gpu)
288
  @spaces.GPU
289
  def video_generation_handler_streaming(prompt, seed=42, fps=15):
290
  """
291
- Generator function that yields .ts video chunks using PyAV for streaming.
292
- Now optimized for block-based processing.
293
  """
294
  global DOWNLOAD_FRAMES
295
  DOWNLOAD_FRAMES = [] # Reset frames for new generation
@@ -297,7 +325,12 @@ def video_generation_handler_streaming(prompt, seed=42, fps=15):
297
  if seed == -1:
298
  seed = random.randint(0, 2**32 - 1)
299
 
300
- print(f"🎬 Starting PyAV streaming: '{prompt}', seed: {seed}")
 
 
 
 
 
301
 
302
  # Setup
303
  conditional_dict = text_encoder(text_prompts=[prompt])
@@ -318,9 +351,7 @@ def video_generation_handler_streaming(prompt, seed=42, fps=15):
318
  all_num_frames = [pipeline.num_frame_per_block] * num_blocks
319
 
320
  total_frames_yielded = 0
321
-
322
- # Ensure temp directory exists
323
- os.makedirs("gradio_tmp", exist_ok=True)
324
 
325
  # Generation loop
326
  for idx, current_num_frames in enumerate(all_num_frames):
@@ -410,25 +441,25 @@ def video_generation_handler_streaming(prompt, seed=42, fps=15):
410
  # Yield None for video but update status (frame-by-frame tracking)
411
  yield None, frame_status_html
412
 
413
- # Encode entire block as one chunk immediately
414
  if all_frames_from_block:
415
  print(f"πŸ“Ή Encoding block {idx} with {len(all_frames_from_block)} frames")
416
 
417
  try:
418
- chunk_uuid = str(uuid.uuid4())[:8]
419
- ts_filename = f"block_{idx:04d}_{chunk_uuid}.ts"
420
- ts_path = os.path.join("gradio_tmp", ts_filename)
421
 
422
  frames_to_ts_file(all_frames_from_block, ts_path, fps)
 
423
 
424
- # Calculate final progress for this block
425
- total_progress = (idx + 1) / num_blocks * 100
426
 
427
- # Yield the actual video chunk
428
- yield ts_path, gr.update()
429
 
430
  except Exception as e:
431
- print(f"⚠️ Error encoding block {idx}: {e}")
432
  import traceback
433
  traceback.print_exc()
434
 
@@ -446,13 +477,13 @@ def video_generation_handler_streaming(prompt, seed=42, fps=15):
446
  f" πŸ“Š Generated {total_frames_yielded} frames across {num_blocks} blocks"
447
  f" </p>"
448
  f" <p style='margin: 4px 0 0 0; color: #0f5132; font-size: 14px;'>"
449
- f" 🎬 Playback: {fps} FPS β€’ πŸ“ Format: MPEG-TS/H.264 β€’ πŸ“₯ Download ready!"
450
  f" </p>"
451
  f" </div>"
452
  f"</div>"
453
  )
454
  yield None, final_status_html
455
- print(f"βœ… PyAV streaming complete! {total_frames_yielded} frames across {num_blocks} blocks")
456
 
457
  # --- Gradio UI Layout ---
458
  with gr.Blocks(title="Self-Forcing Streaming Demo") as demo:
@@ -553,8 +584,8 @@ if __name__ == "__main__":
553
  print("πŸš€ Starting Self-Forcing Streaming Demo")
554
  print(f"πŸ“ Temporary files will be stored in: gradio_tmp/")
555
  print(f"πŸ“₯ Download files will be stored in: downloads/")
556
- print(f"🎯 Chunk encoding: PyAV (MPEG-TS/H.264)")
557
- print(f"⚑ GPU acceleration: {gpu}")
558
 
559
  demo.queue().launch(
560
  server_name=args.host,
@@ -565,6 +596,7 @@ if __name__ == "__main__":
565
  mcp_server=True
566
  )
567
 
 
568
  # import subprocess
569
  # subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
570
 
 
201
 
202
  return filepath
203
 
204
+ def create_hls_playlist(ts_files, playlist_dir, fps=15):
205
+ """
206
+ Create HLS playlist (.m3u8) file for streaming.
207
+ """
208
+ playlist_path = os.path.join(playlist_dir, "playlist.m3u8")
209
+ segment_duration = 2.0 # Each segment duration in seconds
210
+
211
+ playlist_content = [
212
+ "#EXTM3U",
213
+ "#EXT-X-VERSION:3",
214
+ f"#EXT-X-TARGETDURATION:{int(segment_duration) + 1}",
215
+ "#EXT-X-MEDIA-SEQUENCE:0",
216
+ "#EXT-X-PLAYLIST-TYPE:VOD"
217
+ ]
218
+
219
+ for ts_file in ts_files:
220
+ ts_filename = os.path.basename(ts_file)
221
+ playlist_content.extend([
222
+ f"#EXTINF:{segment_duration:.1f},",
223
+ ts_filename
224
+ ])
225
+
226
+ playlist_content.append("#EXT-X-ENDLIST")
227
+
228
+ with open(playlist_path, 'w') as f:
229
+ f.write('\n'.join(playlist_content))
230
+
231
+ return playlist_path
232
+
233
  def create_mp4_download():
234
  """Create MP4 file from stored frames for download."""
235
  global DOWNLOAD_FRAMES
 
317
  @spaces.GPU
318
  def video_generation_handler_streaming(prompt, seed=42, fps=15):
319
  """
320
+ Generator function that creates HLS stream and stores frames for download.
 
321
  """
322
  global DOWNLOAD_FRAMES
323
  DOWNLOAD_FRAMES = [] # Reset frames for new generation
 
325
  if seed == -1:
326
  seed = random.randint(0, 2**32 - 1)
327
 
328
+ print(f"🎬 Starting HLS streaming: '{prompt}', seed: {seed}")
329
+
330
+ # Create unique session directory for HLS files
331
+ session_id = str(uuid.uuid4())[:8]
332
+ session_dir = os.path.join("gradio_tmp", f"session_{session_id}")
333
+ os.makedirs(session_dir, exist_ok=True)
334
 
335
  # Setup
336
  conditional_dict = text_encoder(text_prompts=[prompt])
 
351
  all_num_frames = [pipeline.num_frame_per_block] * num_blocks
352
 
353
  total_frames_yielded = 0
354
+ ts_files = [] # Store TS file paths for HLS playlist
 
 
355
 
356
  # Generation loop
357
  for idx, current_num_frames in enumerate(all_num_frames):
 
441
  # Yield None for video but update status (frame-by-frame tracking)
442
  yield None, frame_status_html
443
 
444
+ # Create TS segment for this block
445
  if all_frames_from_block:
446
  print(f"πŸ“Ή Encoding block {idx} with {len(all_frames_from_block)} frames")
447
 
448
  try:
449
+ ts_filename = f"segment_{idx:04d}.ts"
450
+ ts_path = os.path.join(session_dir, ts_filename)
 
451
 
452
  frames_to_ts_file(all_frames_from_block, ts_path, fps)
453
+ ts_files.append(ts_path)
454
 
455
+ # Create/update HLS playlist
456
+ playlist_path = create_hls_playlist(ts_files, session_dir, fps)
457
 
458
+ # Yield the HLS playlist for streaming
459
+ yield playlist_path, gr.update()
460
 
461
  except Exception as e:
462
+ print(f"⚠️ Error creating HLS segment {idx}: {e}")
463
  import traceback
464
  traceback.print_exc()
465
 
 
477
  f" πŸ“Š Generated {total_frames_yielded} frames across {num_blocks} blocks"
478
  f" </p>"
479
  f" <p style='margin: 4px 0 0 0; color: #0f5132; font-size: 14px;'>"
480
+ f" 🎬 Playback: {fps} FPS β€’ πŸ“ Format: HLS/H.264 β€’ πŸ“₯ Download ready!"
481
  f" </p>"
482
  f" </div>"
483
  f"</div>"
484
  )
485
  yield None, final_status_html
486
+ print(f"βœ… HLS streaming complete! {total_frames_yielded} frames across {num_blocks} blocks")
487
 
488
  # --- Gradio UI Layout ---
489
  with gr.Blocks(title="Self-Forcing Streaming Demo") as demo:
 
584
  print("πŸš€ Starting Self-Forcing Streaming Demo")
585
  print(f"πŸ“ Temporary files will be stored in: gradio_tmp/")
586
  print(f"πŸ“₯ Download files will be stored in: downloads/")
587
+ print(f"🎯 Streaming: HLS (.m3u8 + .ts segments)")
588
+ print(f"πŸ“± Download: MP4 (imageio)")
589
 
590
  demo.queue().launch(
591
  server_name=args.host,
 
596
  mcp_server=True
597
  )
598
 
599
+
600
  # import subprocess
601
  # subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
602