tsi-org commited on
Commit
13cd217
Β·
verified Β·
1 Parent(s): a12d04c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -70
app.py CHANGED
@@ -148,7 +148,7 @@ APP_STATE = {
148
  "current_vae_decoder": None,
149
  }
150
 
151
- # Global variable to store frames for download
152
  DOWNLOAD_FRAMES = []
153
 
154
  def frames_to_ts_file(frames, filepath, fps = 15):
@@ -201,59 +201,22 @@ def frames_to_ts_file(frames, filepath, fps = 15):
201
 
202
  return filepath
203
 
204
- def create_hls_playlist(ts_files, playlist_dir, fps=15):
205
- """
206
- Create HLS playlist (.m3u8) file for streaming.
207
- """
208
- playlist_path = os.path.join(playlist_dir, "playlist.m3u8")
209
- segment_duration = 2.0 # Each segment duration in seconds
210
-
211
- playlist_content = [
212
- "#EXTM3U",
213
- "#EXT-X-VERSION:3",
214
- f"#EXT-X-TARGETDURATION:{int(segment_duration) + 1}",
215
- "#EXT-X-MEDIA-SEQUENCE:0",
216
- "#EXT-X-PLAYLIST-TYPE:VOD"
217
- ]
218
-
219
- for ts_file in ts_files:
220
- ts_filename = os.path.basename(ts_file)
221
- playlist_content.extend([
222
- f"#EXTINF:{segment_duration:.1f},",
223
- ts_filename
224
- ])
225
-
226
- playlist_content.append("#EXT-X-ENDLIST")
227
-
228
- with open(playlist_path, 'w') as f:
229
- f.write('\n'.join(playlist_content))
230
-
231
- return playlist_path
232
-
233
- def create_mp4_download():
234
- """Create MP4 file from stored frames for download."""
235
  global DOWNLOAD_FRAMES
236
-
237
  if not DOWNLOAD_FRAMES:
238
  return None
239
-
240
  try:
241
  os.makedirs("downloads", exist_ok=True)
242
-
243
  timestamp = int(time.time())
244
- mp4_filename = f"pixio_video_{timestamp}.mp4"
245
- mp4_path = os.path.join("downloads", mp4_filename)
246
-
247
- # Use imageio to create MP4
248
  with imageio.get_writer(mp4_path, fps=args.fps, codec='libx264', quality=8) as writer:
249
  for frame in DOWNLOAD_FRAMES:
250
  writer.append_data(frame)
251
-
252
- print(f"βœ… MP4 created for download: {mp4_path}")
253
  return mp4_path
254
-
255
  except Exception as e:
256
- print(f"❌ Error creating MP4: {e}")
257
  return None
258
 
259
  def initialize_vae_decoder(use_taehv=False, use_trt=False):
@@ -317,20 +280,16 @@ pipeline.to(dtype=torch.float16).to(gpu)
317
  @spaces.GPU
318
  def video_generation_handler_streaming(prompt, seed=42, fps=15):
319
  """
320
- Generator function that creates HLS stream and stores frames for download.
 
321
  """
322
  global DOWNLOAD_FRAMES
323
- DOWNLOAD_FRAMES = [] # Reset frames for new generation
324
 
325
  if seed == -1:
326
  seed = random.randint(0, 2**32 - 1)
327
 
328
- print(f"🎬 Starting HLS streaming: '{prompt}', seed: {seed}")
329
-
330
- # Create unique session directory for HLS files
331
- session_id = str(uuid.uuid4())[:8]
332
- session_dir = os.path.join("gradio_tmp", f"session_{session_id}")
333
- os.makedirs(session_dir, exist_ok=True)
334
 
335
  # Setup
336
  conditional_dict = text_encoder(text_prompts=[prompt])
@@ -351,7 +310,9 @@ def video_generation_handler_streaming(prompt, seed=42, fps=15):
351
  all_num_frames = [pipeline.num_frame_per_block] * num_blocks
352
 
353
  total_frames_yielded = 0
354
- ts_files = [] # Store TS file paths for HLS playlist
 
 
355
 
356
  # Generation loop
357
  for idx, current_num_frames in enumerate(all_num_frames):
@@ -415,7 +376,7 @@ def video_generation_handler_streaming(prompt, seed=42, fps=15):
415
  frame_np = np.transpose(frame_np, (1, 2, 0)) # CHW -> HWC
416
 
417
  all_frames_from_block.append(frame_np)
418
- DOWNLOAD_FRAMES.append(frame_np) # Store for download
419
  total_frames_yielded += 1
420
 
421
  # Yield status update for each frame (cute tracking!)
@@ -441,25 +402,25 @@ def video_generation_handler_streaming(prompt, seed=42, fps=15):
441
  # Yield None for video but update status (frame-by-frame tracking)
442
  yield None, frame_status_html
443
 
444
- # Create TS segment for this block
445
  if all_frames_from_block:
446
  print(f"πŸ“Ή Encoding block {idx} with {len(all_frames_from_block)} frames")
447
 
448
  try:
449
- ts_filename = f"segment_{idx:04d}.ts"
450
- ts_path = os.path.join(session_dir, ts_filename)
 
451
 
452
  frames_to_ts_file(all_frames_from_block, ts_path, fps)
453
- ts_files.append(ts_path)
454
 
455
- # Create/update HLS playlist
456
- playlist_path = create_hls_playlist(ts_files, session_dir, fps)
457
 
458
- # Yield the HLS playlist for streaming
459
- yield playlist_path, gr.update()
460
 
461
  except Exception as e:
462
- print(f"⚠️ Error creating HLS segment {idx}: {e}")
463
  import traceback
464
  traceback.print_exc()
465
 
@@ -477,13 +438,13 @@ def video_generation_handler_streaming(prompt, seed=42, fps=15):
477
  f" πŸ“Š Generated {total_frames_yielded} frames across {num_blocks} blocks"
478
  f" </p>"
479
  f" <p style='margin: 4px 0 0 0; color: #0f5132; font-size: 14px;'>"
480
- f" 🎬 Playback: {fps} FPS β€’ πŸ“ Format: HLS/H.264 β€’ πŸ“₯ Download ready!"
481
  f" </p>"
482
  f" </div>"
483
  f"</div>"
484
  )
485
  yield None, final_status_html
486
- print(f"βœ… HLS streaming complete! {total_frames_yielded} frames across {num_blocks} blocks")
487
 
488
  # --- Gradio UI Layout ---
489
  with gr.Blocks(title="Self-Forcing Streaming Demo") as demo:
@@ -553,10 +514,10 @@ with gr.Blocks(title="Self-Forcing Streaming Demo") as demo:
553
  label="Generation Status"
554
  )
555
 
556
- # Download button
557
  download_btn = gr.DownloadButton(
558
  label="πŸ“₯ Download MP4",
559
- value=create_mp4_download,
560
  variant="secondary"
561
  )
562
 
@@ -579,13 +540,13 @@ if __name__ == "__main__":
579
  import shutil
580
  shutil.rmtree("gradio_tmp")
581
  os.makedirs("gradio_tmp", exist_ok=True)
582
- os.makedirs("downloads", exist_ok=True)
583
 
584
  print("πŸš€ Starting Self-Forcing Streaming Demo")
585
  print(f"πŸ“ Temporary files will be stored in: gradio_tmp/")
586
- print(f"πŸ“₯ Download files will be stored in: downloads/")
587
- print(f"🎯 Streaming: HLS (.m3u8 + .ts segments)")
588
- print(f"πŸ“± Download: MP4 (imageio)")
589
 
590
  demo.queue().launch(
591
  server_name=args.host,
 
148
  "current_vae_decoder": None,
149
  }
150
 
151
+ # ONLY ADDITION: Store frames for download
152
  DOWNLOAD_FRAMES = []
153
 
154
  def frames_to_ts_file(frames, filepath, fps = 15):
 
201
 
202
  return filepath
203
 
204
+ # ONLY ADDITION: Download function
205
+ def create_download_mp4():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
206
  global DOWNLOAD_FRAMES
 
207
  if not DOWNLOAD_FRAMES:
208
  return None
 
209
  try:
210
  os.makedirs("downloads", exist_ok=True)
 
211
  timestamp = int(time.time())
212
+ mp4_path = f"downloads/video_{timestamp}.mp4"
 
 
 
213
  with imageio.get_writer(mp4_path, fps=args.fps, codec='libx264', quality=8) as writer:
214
  for frame in DOWNLOAD_FRAMES:
215
  writer.append_data(frame)
216
+ print(f"βœ… Download MP4 created: {mp4_path}")
 
217
  return mp4_path
 
218
  except Exception as e:
219
+ print(f"❌ Download error: {e}")
220
  return None
221
 
222
  def initialize_vae_decoder(use_taehv=False, use_trt=False):
 
280
  @spaces.GPU
281
  def video_generation_handler_streaming(prompt, seed=42, fps=15):
282
  """
283
+ Generator function that yields .ts video chunks using PyAV for streaming.
284
+ Now optimized for block-based processing.
285
  """
286
  global DOWNLOAD_FRAMES
287
+ DOWNLOAD_FRAMES = [] # ONLY ADDITION: Reset frames
288
 
289
  if seed == -1:
290
  seed = random.randint(0, 2**32 - 1)
291
 
292
+ print(f"🎬 Starting PyAV streaming: '{prompt}', seed: {seed}")
 
 
 
 
 
293
 
294
  # Setup
295
  conditional_dict = text_encoder(text_prompts=[prompt])
 
310
  all_num_frames = [pipeline.num_frame_per_block] * num_blocks
311
 
312
  total_frames_yielded = 0
313
+
314
+ # Ensure temp directory exists
315
+ os.makedirs("gradio_tmp", exist_ok=True)
316
 
317
  # Generation loop
318
  for idx, current_num_frames in enumerate(all_num_frames):
 
376
  frame_np = np.transpose(frame_np, (1, 2, 0)) # CHW -> HWC
377
 
378
  all_frames_from_block.append(frame_np)
379
+ DOWNLOAD_FRAMES.append(frame_np) # ONLY ADDITION: Store for download
380
  total_frames_yielded += 1
381
 
382
  # Yield status update for each frame (cute tracking!)
 
402
  # Yield None for video but update status (frame-by-frame tracking)
403
  yield None, frame_status_html
404
 
405
+ # Encode entire block as one chunk immediately
406
  if all_frames_from_block:
407
  print(f"πŸ“Ή Encoding block {idx} with {len(all_frames_from_block)} frames")
408
 
409
  try:
410
+ chunk_uuid = str(uuid.uuid4())[:8]
411
+ ts_filename = f"block_{idx:04d}_{chunk_uuid}.ts"
412
+ ts_path = os.path.join("gradio_tmp", ts_filename)
413
 
414
  frames_to_ts_file(all_frames_from_block, ts_path, fps)
 
415
 
416
+ # Calculate final progress for this block
417
+ total_progress = (idx + 1) / num_blocks * 100
418
 
419
+ # Yield the actual video chunk
420
+ yield ts_path, gr.update()
421
 
422
  except Exception as e:
423
+ print(f"⚠️ Error encoding block {idx}: {e}")
424
  import traceback
425
  traceback.print_exc()
426
 
 
438
  f" πŸ“Š Generated {total_frames_yielded} frames across {num_blocks} blocks"
439
  f" </p>"
440
  f" <p style='margin: 4px 0 0 0; color: #0f5132; font-size: 14px;'>"
441
+ f" 🎬 Playback: {fps} FPS β€’ πŸ“ Format: MPEG-TS/H.264 β€’ πŸ“₯ Download ready!"
442
  f" </p>"
443
  f" </div>"
444
  f"</div>"
445
  )
446
  yield None, final_status_html
447
+ print(f"βœ… PyAV streaming complete! {total_frames_yielded} frames across {num_blocks} blocks")
448
 
449
  # --- Gradio UI Layout ---
450
  with gr.Blocks(title="Self-Forcing Streaming Demo") as demo:
 
514
  label="Generation Status"
515
  )
516
 
517
+ # ONLY ADDITION: Download button
518
  download_btn = gr.DownloadButton(
519
  label="πŸ“₯ Download MP4",
520
+ value=create_download_mp4,
521
  variant="secondary"
522
  )
523
 
 
540
  import shutil
541
  shutil.rmtree("gradio_tmp")
542
  os.makedirs("gradio_tmp", exist_ok=True)
543
+ os.makedirs("downloads", exist_ok=True) # ONLY ADDITION
544
 
545
  print("πŸš€ Starting Self-Forcing Streaming Demo")
546
  print(f"πŸ“ Temporary files will be stored in: gradio_tmp/")
547
+ print(f"πŸ“₯ Download files will be stored in: downloads/") # ONLY ADDITION
548
+ print(f"🎯 Chunk encoding: PyAV (MPEG-TS/H.264)")
549
+ print(f"⚑ GPU acceleration: {gpu}")
550
 
551
  demo.queue().launch(
552
  server_name=args.host,