akhaliq (HF Staff) committed c2e9ecc · verified · 1 parent: 8dbf894

Update app.py

Files changed (1):
  1. app.py +277 -254

app.py CHANGED
@@ -10,7 +10,7 @@ import librosa
 import soundfile as sf
 import torch
 from pathlib import Path
-from typing import Iterator, Dict, Any
+from typing import Iterator, Dict, Any, List

 # Clone and setup VibeVoice if not already present
 vibevoice_dir = Path('./VibeVoice')
@@ -87,6 +87,20 @@ from transformers import set_seed
 logging.set_verbosity_info()
 logger = logging.get_logger(__name__)

+# --- Helper function for audio conversion ---
+def convert_to_16_bit_wav(data: np.ndarray | torch.Tensor) -> np.ndarray:
+    """Convert audio data to 16-bit WAV format (numpy int16)."""
+    if torch.is_tensor(data):
+        data = data.detach().cpu().numpy()
+
+    data = np.array(data, dtype=np.float32)  # Ensure float32 before scaling
+
+    # Normalize to -1 to 1 if necessary
+    if np.max(np.abs(data)) > 1.0:
+        data = data / np.max(np.abs(data))
+
+    data = (data * 32767).astype(np.int16)
+    return data

 class VibeVoiceChat:
     def __init__(self, model_path: str, device: str = "cuda", inference_steps: int = 5):
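For reference, the new module-level helper produces exactly the `(sample_rate, int16_array)` payload that a numpy-typed Gradio audio component consumes. A minimal usage sketch (the sine-wave input is hypothetical; the scaling mirrors the helper above):

```python
import numpy as np

sr = 24000                                  # VibeVoice's output rate
t = np.linspace(0, 1, sr, endpoint=False)
wave = 0.8 * np.sin(2 * np.pi * 440 * t)    # one second of A4, floats in [-1, 1]

peak = np.max(np.abs(wave))
if peak > 1.0:                              # normalize only if the signal clips
    wave = wave / peak
pcm16 = (wave * 32767).astype(np.int16)     # same scaling as convert_to_16_bit_wav

# (sr, pcm16) is now a valid value for gr.Audio(type="numpy").
```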
@@ -96,7 +110,8 @@ class VibeVoiceChat:
         self.inference_steps = inference_steps
         self.is_generating = False
         self.stop_generation = False
-        self.current_streamer = None
+        self.current_streamer: AudioStreamer | None = None
+        self.complete_audio_buffer: List[np.ndarray] = []  # To store all generated audio for final download

         # Check GPU availability and CUDA version
         if torch.cuda.is_available():
@@ -104,10 +119,8 @@ class VibeVoiceChat:
             print(f" Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
             print(f" CUDA Version: {torch.version.cuda}")
             print(f" PyTorch CUDA: {torch.cuda.is_available()}")
-            # Set memory fraction to avoid OOM
-            torch.cuda.set_per_process_memory_fraction(0.95)
-            # Enable TF32 for faster computation on Ampere GPUs
-            torch.backends.cuda.matmul.allow_tf32 = True
+            torch.cuda.set_per_process_memory_fraction(0.95)  # Set memory fraction to avoid OOM
+            torch.backends.cuda.matmul.allow_tf32 = True  # Enable TF32 for faster computation on Ampere GPUs
             torch.backends.cudnn.allow_tf32 = True
         else:
             print("✗ No GPU detected, using CPU (generation will be VERY slow)")
@@ -165,16 +178,13 @@ class VibeVoiceChat:
         load_time = time.time() - start_time
         print(f"✓ Model loaded in {load_time:.2f} seconds")

-        # Print model device
         if hasattr(self.model, 'device'):
             print(f"Model device: {self.model.device}")

     def setup_voice_presets(self):
         """Setup voice presets from the voices directory."""
-        # This assumes 'voices' directory is in the same location as the script
         voices_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "voices")

-        # Create voices directory if it doesn't exist
         if not os.path.exists(voices_dir):
             os.makedirs(voices_dir)
             print(f"Created voices directory at {voices_dir}")
@@ -183,19 +193,16 @@ class VibeVoiceChat:
         self.available_voices = {}
         audio_extensions = ('.wav', '.mp3', '.flac', '.ogg', '.m4a', '.aac')

-        # Scan for audio files
         for file in os.listdir(voices_dir):
             if file.lower().endswith(audio_extensions):
                 name = os.path.splitext(file)[0]
                 self.available_voices[name] = os.path.join(voices_dir, file)

-        # Sort voices alphabetically
         self.available_voices = dict(sorted(self.available_voices.items()))

         if not self.available_voices:
             print(f"Warning: No voice files found in {voices_dir}")
             print("Using default (zero) voice samples. Add audio files to the voices directory for better results.")
-            # Add a default "None" option
             self.available_voices = {"Default": None}
         else:
             print(f"Found {len(self.available_voices)} voice presets: {', '.join(self.available_voices.keys())}")
@@ -203,14 +210,13 @@ class VibeVoiceChat:
     def read_audio(self, audio_path: str, target_sr: int = 24000) -> np.ndarray:
         """Read and preprocess audio file."""
         try:
-            # Librosa is more robust for various audio formats
             wav, sr = librosa.load(audio_path, sr=None, mono=True)
             if sr != target_sr:
                 wav = librosa.resample(y=wav, orig_sr=sr, target_sr=target_sr)
             return wav
         except Exception as e:
             print(f"Error reading audio {audio_path}: {e}")
-            return np.zeros(target_sr)  # Return 1 second of silence as fallback
+            return np.zeros(target_sr, dtype=np.float32)

     def format_script(self, message: str, num_speakers: int = 2) -> str:
         """Format input message into a script with speaker assignments."""
@@ -222,11 +228,9 @@ class VibeVoiceChat:
             if not line:
                 continue

-            # Check if already formatted (e.g., "Speaker 0: Hello")
             if line.startswith('Speaker ') and ':' in line:
                 formatted_lines.append(line)
             else:
-                # Auto-assign speakers in rotation
                 speaker_id = i % num_speakers
                 formatted_lines.append(f"Speaker {speaker_id}: {line}")

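`format_script` keeps pre-tagged lines and assigns the rest round-robin. A standalone sketch of that rotation (assuming the loop index enumerates all input lines, which the `i % num_speakers` in the hunk suggests):

```python
def assign_speakers(message: str, num_speakers: int = 2) -> str:
    formatted = []
    for i, line in enumerate(l.strip() for l in message.split('\n')):
        if not line:
            continue
        if line.startswith('Speaker ') and ':' in line:
            formatted.append(line)  # already tagged: keep as-is
        else:
            formatted.append(f"Speaker {i % num_speakers}: {line}")
    return '\n'.join(formatted)

print(assign_speakers("Hi there!\nHello, welcome back."))
# Speaker 0: Hi there!
# Speaker 1: Hello, welcome back.
```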
@@ -235,63 +239,50 @@ class VibeVoiceChat:
     def generate_audio_stream(
         self,
         message: str,
-        history: list,  # Keep history parameter for consistency, though not directly used for generation here
         voice_1: str,
         voice_2: str,
         num_speakers: int,
         cfg_scale: float
-    ) -> Iterator[tuple]:
+    ) -> Iterator[tuple]:  # This generator yields (sample_rate, audio_chunk_numpy_int16)
         """
-        Generate audio stream from text input.
-        Yields (sample_rate, audio_chunk_numpy) tuples as audio becomes available.
+        Generate audio stream from text input, implementing buffering for smoother streaming.
+        Yields (sample_rate, audio_chunk_numpy_int16) tuples as audio becomes available.
         """
         try:
             self.stop_generation = False
             self.is_generating = True
+            self.complete_audio_buffer = []  # Reset buffer for new generation

-            # Validate inputs
             if not message.strip():
                 self.is_generating = False
                 yield None
                 return

-            # Format the script
             formatted_script = self.format_script(message, num_speakers)
-            print(f"Formatted script:\n{formatted_script}")
-            print(f"Using device: {self.device}")
-
-            start_time = time.time()  # Start timing for the overall generation

-            # Select voices based on number of speakers
             selected_voices = []
             if voice_1 and voice_1 != "Default":
                 selected_voices.append(voice_1)
             if num_speakers > 1 and voice_2 and voice_2 != "Default":
                 selected_voices.append(voice_2)

-            # Load voice samples
             voice_samples = []
-            target_sr = 24000  # VibeVoice expects 24kHz
+            target_sr = 24000
             for i in range(num_speakers):
-                # Use the appropriate voice for each speaker
                 if i < len(selected_voices):
                     voice_name = selected_voices[i]
                     if voice_name in self.available_voices and self.available_voices[voice_name]:
                         audio_data = self.read_audio(self.available_voices[voice_name], target_sr=target_sr)
                     else:
-                        audio_data = np.zeros(target_sr, dtype=np.float32)  # Default silence
+                        audio_data = np.zeros(target_sr, dtype=np.float32)
                 else:
-                    # Fallback: use first voice or default if not enough unique voices selected
                     if selected_voices and selected_voices[0] in self.available_voices and self.available_voices[selected_voices[0]]:
                         audio_data = self.read_audio(self.available_voices[selected_voices[0]], target_sr=target_sr)
                     else:
-                        audio_data = np.zeros(target_sr, dtype=np.float32)  # Default silence
+                        audio_data = np.zeros(target_sr, dtype=np.float32)

                 voice_samples.append(audio_data)

-            print(f"Loaded {len(voice_samples)} voice samples")
-
-            # Process inputs
             inputs = self.processor(
                 text=[formatted_script],
                 voice_samples=[voice_samples],
@@ -300,14 +291,9 @@ class VibeVoiceChat:
                 return_attention_mask=True,
             )

-            # Move to device and ensure correct dtype
             if self.device == "cuda":
                 inputs = {k: v.to(self.device) if torch.is_tensor(v) else v for k, v in inputs.items()}
-                print(f"✓ Inputs moved to GPU")
-                if torch.cuda.is_available():
-                    print(f"GPU memory allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")

-            # Create audio streamer
             audio_streamer = AudioStreamer(
                 batch_size=1,
                 stop_signal=None,
@@ -316,7 +302,6 @@

             self.current_streamer = audio_streamer

-            # Start generation in a separate thread
             generation_thread = threading.Thread(
                 target=self._generate_with_streamer,
                 args=(inputs, cfg_scale, audio_streamer)
@@ -324,64 +309,79 @@ class VibeVoiceChat:
             generation_thread.start()

             # Give the generation thread a moment to start producing output
-            time.sleep(0.5)
+            time.sleep(1.0)  # Increased from 0.5s for stability

             audio_output_stream = audio_streamer.get_stream(0)

-            total_generated_samples = 0
+            # Buffering logic for smoother Gradio streaming
+            pending_chunks: List[np.ndarray] = []
+            min_yield_interval_seconds = 1.0  # Yield at least every 1 second
+            min_chunk_size_samples = target_sr * 0.5  # At least 0.5 seconds of audio per chunk yielded to Gradio
+            last_yield_time = time.time()

-            # Stream audio chunks
-            for audio_chunk in audio_output_stream:
+            for audio_chunk_raw in audio_output_stream:
                 if self.stop_generation:
-                    audio_streamer.end()  # Signal streamer to stop
-                    break  # Exit the loop
+                    audio_streamer.end()
+                    break

-                # Convert to numpy array (float32 is preferred by Gradio's Audio component)
-                if torch.is_tensor(audio_chunk):
-                    if audio_chunk.dtype == torch.bfloat16:
-                        audio_chunk = audio_chunk.float()
-                    audio_np = audio_chunk.cpu().numpy().astype(np.float32)
+                # Convert raw chunk to numpy float32
+                if torch.is_tensor(audio_chunk_raw):
+                    if audio_chunk_raw.dtype == torch.bfloat16:
+                        audio_chunk_raw = audio_chunk_raw.float()
+                    audio_np = audio_chunk_raw.cpu().numpy().astype(np.float32)
                 else:
-                    audio_np = np.array(audio_chunk, dtype=np.float32)
+                    audio_np = np.array(audio_chunk_raw, dtype=np.float32)

-                # Ensure 1D audio array
                 if len(audio_np.shape) > 1:
                     audio_np = audio_np.squeeze()

-                total_generated_samples += len(audio_np)
-                # Yield the audio chunk directly for Gradio's streaming audio component
-                yield (target_sr, audio_np)
-
-            # Ensure generation thread completes/cleans up
-            generation_thread.join(timeout=10.0)
-
-            generation_time = time.time() - start_time
-            audio_duration = total_generated_samples / target_sr
-
-            print(f"✓ Streaming complete:")
-            print(f" Total time: {generation_time:.2f} seconds")
-            print(f" Total audio duration: {audio_duration:.2f} seconds")
-            if generation_time > 0:
-                print(f" Real-time factor: {audio_duration/generation_time:.2f}x")
+                # Append to complete buffer (for final download)
+                self.complete_audio_buffer.append(audio_np)
+
+                # Append to pending chunks for streaming to Gradio
+                pending_chunks.append(audio_np)
+                current_pending_size = sum(len(c) for c in pending_chunks)
+
+                current_time = time.time()
+
+                should_yield = False
+                if current_pending_size >= min_chunk_size_samples:
+                    should_yield = True
+                elif (current_time - last_yield_time) >= min_yield_interval_seconds and pending_chunks:
+                    should_yield = True
+
+                if should_yield:
+                    combined_chunk = np.concatenate(pending_chunks)
+                    yield (target_sr, convert_to_16_bit_wav(combined_chunk))  # Convert to int16 before yielding
+                    pending_chunks = []
+                    last_yield_time = current_time
+
+            # Yield any remaining chunks after the loop finishes
+            if pending_chunks and not self.stop_generation:
+                combined_chunk = np.concatenate(pending_chunks)
+                yield (target_sr, convert_to_16_bit_wav(combined_chunk))
+
+            generation_thread.join(timeout=10.0)  # Ensure generation thread completes

+            # Clean up
             self.current_streamer = None
             self.is_generating = False

         except Exception as e:
-            print(f"Error in generation: {e}")
+            print(f"Error in generate_audio_stream: {e}")
             import traceback
             traceback.print_exc()
             self.is_generating = False
             self.current_streamer = None
-            yield None  # Yield None to indicate an error or end of stream
+            self.complete_audio_buffer = []  # Clear buffer on error
+            yield None

-    def _generate_with_streamer(self, inputs, cfg_scale, audio_streamer):
+    def _generate_with_streamer(self, inputs: Dict[str, Any], cfg_scale: float, audio_streamer: AudioStreamer):
         """Helper method to run generation with streamer."""
         try:
             def check_stop():
                 return self.stop_generation

-            # Use torch.cuda.amp for mixed precision if available
             if self.device == "cuda" and torch.cuda.is_available():
                 with torch.cuda.amp.autocast(dtype=torch.bfloat16):
                     self.model.generate(
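The rewritten loop above no longer forwards every raw streamer chunk; it re-batches by size (at least 0.5 s of audio) or elapsed wall-clock time (at least 1 s), so the browser receives fewer, larger audio updates. The same policy distilled into a self-contained generator (names are illustrative, not from the app):

```python
import time
from typing import Iterator, List

import numpy as np

def batch_chunks(chunks: Iterator[np.ndarray],
                 min_samples: int = 12000,      # 0.5 s at 24 kHz
                 max_interval: float = 1.0) -> Iterator[np.ndarray]:
    pending: List[np.ndarray] = []
    last_yield = time.time()
    for chunk in chunks:
        pending.append(chunk)
        buffered = sum(len(c) for c in pending)
        now = time.time()
        if buffered >= min_samples or (now - last_yield) >= max_interval:
            yield np.concatenate(pending)       # emit one combined block
            pending, last_yield = [], now
    if pending:                                 # flush whatever remains
        yield np.concatenate(pending)
```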
@@ -412,23 +412,8 @@ class VibeVoiceChat:
             import traceback
             traceback.print_exc()
         finally:
-            # Ensure the streamer is always ended, even if generation fails
             audio_streamer.end()

-    def convert_to_16_bit_wav(self, data):
-        """Convert audio data to 16-bit WAV format."""
-        if torch.is_tensor(data):
-            data = data.detach().cpu().numpy()
-
-        data = np.array(data, dtype=np.float32)  # Ensure float32 before scaling
-
-        # Normalize to -1 to 1 if necessary
-        if np.max(np.abs(data)) > 1.0:
-            data = data / np.max(np.abs(data))
-
-        data = (data * 32767).astype(np.int16)
-        return data
-
     def stop_audio_generation(self):
         """Signal to stop the current audio generation."""
         if self.is_generating:
@@ -436,12 +421,11 @@ class VibeVoiceChat:
             self.stop_generation = True
             if self.current_streamer:
                 try:
-                    # Give a brief moment for the streamer to process remaining buffers,
-                    # then force end it if needed.
-                    time.sleep(0.1)
                     self.current_streamer.end()
                 except Exception as e:
                     print(f"Error ending streamer: {e}")
+            self.is_generating = False
+            self.complete_audio_buffer = []
         else:
             print("No active generation to stop.")

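`stop_audio_generation` only raises a flag; the generation thread and the streaming loop poll it and exit on their own. A generic sketch of this cooperative-cancellation pattern, using `threading.Event` where the app uses a plain boolean attribute:

```python
import threading
import time

stop_flag = threading.Event()

def worker():
    while not stop_flag.is_set():   # poll the flag, as the streaming loop does
        time.sleep(0.05)            # stand-in for producing one audio chunk
    print("worker exited cleanly")

t = threading.Thread(target=worker)
t.start()
time.sleep(0.2)
stop_flag.set()                     # analogous to self.stop_generation = True
t.join()
```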
@@ -449,69 +433,210 @@
 def create_chat_interface(chat_instance: VibeVoiceChat):
     """Create a simplified Gradio ChatInterface for VibeVoice with audio streaming."""

-    # Get available voices
     voice_options = list(chat_instance.available_voices.keys())
     if not voice_options:
         voice_options = ["Default"]

     default_voice_1 = voice_options[0] if len(voice_options) > 0 else "Default"
     default_voice_2 = voice_options[1] if len(voice_options) > 1 else voice_options[0]
+
+    # Custom CSS for modern aesthetics
+    custom_css = """
+    .gradio-container {
+        font-family: 'Inter', sans-serif;
+        background: linear-gradient(135deg, #f0f2f5 0%, #e0e6ed 100%);
+        color: #333;
+    }
+    .main-header {
+        background: linear-gradient(45deg, #4A00E0 0%, #8E2DE2 100%);
+        padding: 20px 30px;
+        border-radius: 15px;
+        margin-bottom: 25px;
+        text-align: center;
+        box-shadow: 0 8px 25px rgba(0, 0, 0, 0.2);
+    }
+    .main-header h1 {
+        color: white;
+        font-size: 2.8em;
+        font-weight: 800;
+        margin: 0;
+        letter-spacing: -1px;
+        text-shadow: 0 3px 5px rgba(0,0,0,0.2);
+    }
+    .main-header p {
+        color: rgba(255,255,255,0.85);
+        font-size: 1.1em;
+        margin-top: 10px;
+    }
+    .settings-card, .generation-card {
+        background: rgba(255, 255, 255, 0.9);
+        border: 1px solid #dcdfe6;
+        border-radius: 12px;
+        padding: 20px;
+        box-shadow: 0 4px 15px rgba(0, 0, 0, 0.08);
+        transition: all 0.3s ease;
+    }
+    .settings-card:hover, .generation-card:hover {
+        box-shadow: 0 6px 20px rgba(0, 0, 0, 0.12);
+        transform: translateY(-2px);
+    }
+    .gradio-output {
+        border-radius: 10px;
+        background-color: #fcfcfc;
+    }
+    .gradio-button {
+        border-radius: 8px !important;
+        font-weight: 600;
+        padding: 10px 20px;
+        transition: all 0.2s ease-in-out;
+    }
+    .gradio-button.primary {
+        background: linear-gradient(45deg, #4CAF50 0%, #8BC34A 100%) !important;
+        color: white !important;
+        border: none !important;
+    }
+    .gradio-button.primary:hover {
+        opacity: 0.9;
+        transform: translateY(-1px);
+    }
+    .gradio-button.secondary {
+        background: linear-gradient(45deg, #FF5722 0%, #FFC107 100%) !important;
+        color: white !important;
+        border: none !important;
+    }
+    .gradio-button.secondary:hover {
+        opacity: 0.9;
+        transform: translateY(-1px);
+    }
+    .gradio-button.clear {
+        background: #90A4AE !important;
+        color: white !important;
+        border: none !important;
+    }
+    .gradio-button.clear:hover {
+        opacity: 0.9;
+        transform: translateY(-1px);
+    }
+    .gradio-input {
+        border-radius: 8px !important;
+        border: 1px solid #ced4da !important;
+    }
+    .gradio-label {
+        font-weight: 700;
+        color: #495057;
+        margin-bottom: 5px;
+    }
+    .chatbot {
+        border: 1px solid #e0e0e0 !important;
+        border-radius: 10px !important;
+        box-shadow: 0 2px 10px rgba(0,0,0,0.05);
+    }
+    .log-output {
+        font-family: 'JetBrains Mono', monospace;
+        background-color: #f8f9fa !important;
+        border-radius: 8px !important;
+        border: 1px solid #e9ecef !important;
+        color: #495057 !important;
+        min-height: 80px;
+    }
+    .audio-output {
+        border-radius: 10px !important;
+        border: 1px solid #e0e0e0 !important;
+        background-color: #f8f9fa !important;
+    }
+    """

-    # Generator function to handle both chatbot updates and audio streaming
-    def process_and_display_stream(message_text: str, history: list, voice_1: str, voice_2: str, num_speakers: int, cfg_scale: float):
-        """
-        Processes the input message, updates the chatbot, and streams audio output.
-        This function is a generator that yields updates for multiple Gradio components.
-        """
-        # Add user message to history immediately
+    # Gradio handler function that coordinates UI updates
+    def process_and_display_stream(message_text: str, history: List[Dict[str, str]], voice_1: str, voice_2: str, num_speakers: int, cfg_scale: float):
         history = history or []
-        history.append({"role": "user", "content": message_text})
-
-        # Yield initial state: updated chatbot, clear text input, disable input,
-        # make audio component visible but initially empty.
-        yield history, gr.update(value="", interactive=False), gr.update(visible=True, value=None)
+        user_message_entry = {"role": "user", "content": message_text}
+
+        # Initial state: user message added, text input disabled, buttons updated, audio cleared/hidden
+        # This yield ensures immediate UI feedback
+        yield (
+            history + [user_message_entry],  # Add user message to chatbot immediately
+            gr.update(value="", interactive=False),  # Clear text input and disable
+            gr.update(value=None, visible=True),  # Clear streaming audio, make visible
+            gr.update(value=None, visible=False),  # Clear complete audio, hide
+            "🎙️ Starting audio generation...",  # Initial log message
+            gr.update(interactive=False, value="Generating..."),  # Disable submit button
+            gr.update(visible=True)  # Show stop button
+        )

-        # Generate audio stream using the VibeVoiceChat instance
+        log_message = ""
+        generated_any_audio = False
+
+        # Call the chat_instance's audio generator
         audio_stream_generator = chat_instance.generate_audio_stream(
-            message_text, history, voice_1, voice_2, num_speakers, cfg_scale
+            message_text, voice_1, voice_2, num_speakers, cfg_scale
         )

-        generated_any_audio = False  # Flag to track if any audio chunks were yielded
-
-        # Iterate through audio chunks and yield for the audio component
+        # Loop through the streaming audio chunks
         for chunk_data in audio_stream_generator:
             if chat_instance.stop_generation:
-                break  # Break if stop button was pressed
+                log_message = "🛑 Audio generation stopped."
+                break
+
             if chunk_data is not None:
                 generated_any_audio = True
-                # Yield the current history (remains static during audio streaming),
-                # keep msg input disabled, and pass the audio chunk for gr.Audio.
-                yield history, gr.update(interactive=False), chunk_data
+                log_message = "🎵 Streaming audio..."
+                # Yield current chunk to streaming audio component
+                # Other components remain static during streaming
+                yield (
+                    history + [user_message_entry],  # Chatbot state
+                    gr.update(interactive=False),  # Text input disabled
+                    chunk_data,  # Streaming audio chunk
+                    gr.update(visible=False),  # Complete audio hidden
+                    log_message,  # Log update
+                    gr.update(interactive=False, value="Generating..."),  # Submit button still disabled
+                    gr.update(visible=True)  # Stop button still visible
+                )
             else:
-                # If chunk_data is None, it indicates an error or end of stream
+                # None indicates an error or unexpected end from the generator
+                log_message = "❌ Error during audio generation."
                 break
-
-        # After audio generation is complete (or stopped/failed)
-        # Add assistant message to chatbot and re-enable text input.
-        if generated_any_audio and not chat_instance.stop_generation:
-            history.append({"role": "assistant", "content": f"🎵 Audio generated successfully"})
-        elif chat_instance.stop_generation:
-            history.append({"role": "assistant", "content": f"🚫 Audio generation stopped"})
-            chat_instance.stop_generation = False  # Reset stop flag for the next generation
+
+        # After generation (or stop/error), prepare final updates
+        final_chatbot_history = history + [user_message_entry]
+        final_streaming_audio_update = gr.update(value=None, visible=False)  # Hide streaming audio
+        final_complete_audio_update = gr.update(value=None, visible=False)  # Default to hidden
+
+        if chat_instance.stop_generation:
+            final_chatbot_history.append({"role": "assistant", "content": "🚫 Audio generation stopped."})
+            log_message = "🛑 Generation stopped by user."
+            chat_instance.stop_generation = False  # Reset flag for next run
+        elif generated_any_audio and chat_instance.complete_audio_buffer:
+            # Concatenate all collected audio chunks for the final downloadable audio
+            complete_audio_data_np = np.concatenate(chat_instance.complete_audio_buffer)
+            final_complete_audio_update = gr.update(value=(24000, convert_to_16_bit_wav(complete_audio_data_np)), visible=True)
+            final_chatbot_history.append({"role": "assistant", "content": "✅ Audio generated successfully! Listen below and download."})
+            log_message = "✨ Generation complete! See 'Complete Audio' below."
         else:
-            history.append({"role": "assistant", "content": "Failed to generate audio"})
+            final_chatbot_history.append({"role": "assistant", "content": "❌ Failed to generate audio."})
+            log_message = "❌ Generation failed or no audio produced."

-        # Final yield: updated chatbot, re-enabled input, and keep audio output visible
-        # The gr.Audio component will retain the last streamed content.
-        yield history, gr.update(value="", interactive=True), gr.update(visible=True)
-
+        # Final yield to update all components after streaming
+        yield (
+            final_chatbot_history,
+            gr.update(value="", interactive=True),  # Re-enable text input
+            final_streaming_audio_update,
+            final_complete_audio_update,
+            log_message,
+            gr.update(interactive=True, value="🎵 Generate Audio"),  # Re-enable submit button
+            gr.update(visible=False)  # Hide stop button
+        )

-    with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="purple"), fill_height=True) as interface:
-        gr.Markdown("# 🎙️ VibeVoice Chat - Streamed Audio\nGenerate natural dialogue audio with AI voices")
+    with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="purple"), fill_height=True, css=custom_css) as interface:
+        gr.HTML("""
+            <div class="main-header">
+                <h1>🎙️ VibeVoice Chat - Streamed Audio</h1>
+                <p>Generate natural dialogue audio with AI voices</p>
+            </div>
+        """)

         with gr.Row():
-            with gr.Column(scale=1):
-                gr.Markdown("### Voice & Generation Settings")
+            with gr.Column(scale=1, elem_classes="settings-card"):
+                gr.Markdown("### 🎛️ **Voice & Generation Settings**")

                 voice_1 = gr.Dropdown(
                     choices=voice_options,
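The rewritten `process_and_display_stream` is a generator whose every `yield` is a tuple with one slot per output component; Gradio applies each tuple as a partial UI refresh mid-stream. A toy sketch of the same pattern (the components and handler here are hypothetical, not part of the app):

```python
import gradio as gr

def stream_handler(text):
    # Tuple slots correspond 1:1 to the `outputs` list below: (textbox, log).
    yield gr.update(interactive=False), "starting..."
    for word in text.split():
        yield gr.update(interactive=False), f"processing {word}"
    yield gr.update(value="", interactive=True), "done"

with gr.Blocks() as demo:
    box = gr.Textbox(label="Input")
    log = gr.Textbox(label="Log", interactive=False)
    box.submit(stream_handler, inputs=box, outputs=[box, log])

# demo.launch()  # uncomment to run locally
```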
@@ -545,139 +670,37 @@ def create_chat_interface(chat_instance: VibeVoiceChat):
                     info="Guidance strength (higher = more adherence to text)"
                 )

-            with gr.Column(scale=2):
+            with gr.Column(scale=2, elem_classes="generation-card"):
                 chatbot = gr.Chatbot(
                     label="Conversation",
-                    height=400,
+                    height=300,  # Adjusted height
                     type="messages",
-                    elem_id="chatbot"
+                    elem_id="chatbot",
+                    elem_classes="chatbot"
                 )

                 msg = gr.Textbox(
                     label="Message",
                     placeholder="Type your message or paste a script...",
-                    lines=3
+                    lines=3,
+                    elem_classes="gradio-input"
                 )

-                # Gradio's gr.Audio component automatically handles streaming when a generator
-                # function yields (sample_rate, numpy_array) tuples.
-                audio_output = gr.Audio(
-                    label="Generated Audio",
-                    autoplay=True,
-                    streaming=True,  # Explicitly setting streaming=True, though often inferred.
-                    visible=False  # Initially hide the audio player
+                # Log output for generation status
+                log_output = gr.Textbox(
+                    label="Generation Log",
+                    lines=2,
+                    max_lines=5,
+                    interactive=False,
+                    value="Ready to generate audio.",
+                    elem_classes="log-output"
                 )

-                with gr.Row():
-                    submit_btn = gr.Button("🎵 Generate Audio", variant="primary")
-                    stop_btn = gr.Button("🛑 Stop Generation", variant="secondary")
-                    clear_btn = gr.Button("🗑️ Clear")
-
-                # Example messages
-                gr.Examples(
-                    examples=[
-                        "Hello! How are you doing today?",
-                        "Speaker 0: Welcome to our podcast!\nSpeaker 1: Thanks for having me!",
-                        "Tell me an interesting fact about space.",
-                        "What's your favorite type of music and why?",
-                    ],
-                    inputs=msg,
-                    label="Example Messages"
-                )
-
-        # Set up event handlers for the buttons and text input
-        submit_btn.click(
-            fn=process_and_display_stream,
-            inputs=[msg, chatbot, voice_1, voice_2, num_speakers, cfg_scale],
-            outputs=[chatbot, msg, audio_output],
-            queue=True  # Queue allows processing requests sequentially
-        )
-
-        # Allow submitting message by pressing Enter in the textbox
-        msg.submit(
-            fn=process_and_display_stream,
-            inputs=[msg, chatbot, voice_1, voice_2, num_speakers, cfg_scale],
-            outputs=[chatbot, msg, audio_output],
-            queue=True
-        )
-
-        # Clear button functionality
-        clear_btn.click(
-            lambda: ([], gr.update(value="", interactive=True), gr.update(visible=False, value=None)),
-            outputs=[chatbot, msg, audio_output]
-        )
-
-        # Stop button functionality - calls the VibeVoiceChat instance's stop method
-        stop_btn.click(
-            fn=chat_instance.stop_audio_generation,
-            inputs=[],
-            outputs=[],  # Does not update any Gradio components directly
-            queue=False  # Important: A stop button should generally not be queued.
-        )
-
-    return interface
-
-
-def parse_args():
-    parser = argparse.ArgumentParser(description="VibeVoice Chat Interface")
-    parser.add_argument(
-        "--model_path",
-        type=str,
-        default="microsoft/VibeVoice-1.5B",
-        help="Path to the VibeVoice model",
-    )
-    parser.add_argument(
-        "--device",
-        type=str,
-        default="cuda" if torch.cuda.is_available() else "cpu",
-        help="Device for inference",
-    )
-    parser.add_argument(
-        "--inference_steps",
-        type=int,
-        default=5,
-        help="Number of DDPM inference steps (lower = faster, higher = better quality)",
-    )
-
-    return parser.parse_args()
-
-
-def main():
-    """Main function to run the chat interface."""
-    args = parse_args()
-
-    set_seed(42)
-
-    print("🎙️ Initializing VibeVoice Chat Interface...")
-
-    # Initialize chat instance
-    chat_instance = VibeVoiceChat(
-        model_path=args.model_path,
-        device=args.device,
-        inference_steps=args.inference_steps
-    )
-
-    # Create interface
-    interface = create_chat_interface(chat_instance)
-
-    print(f"🚀 Launching chat interface")
-    print(f"📁 Model: {args.model_path}")
-    print(f"💻 Device: {chat_instance.device}")
-    print(f"🔢 Inference steps: {args.inference_steps}")
-    print(f"🎭 Available voices: {len(chat_instance.available_voices)}")
-
-    if chat_instance.device == "cpu":
-        print("\n⚠️ WARNING: Running on CPU - generation will be VERY slow!")
-        print(" For faster generation, ensure you have:")
-        print(" 1. NVIDIA GPU with CUDA support")
-        print(" 2. PyTorch with CUDA installed: pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118")
-
-    # Launch the interface
-    interface.queue(max_size=10).launch(
-        show_error=True,
-        quiet=False,
-    )
-
-
-if __name__ == "__main__":
-    main()
+
+                # Streaming audio component
+                audio_output = gr.Audio(
+                    label="Streaming Audio (Real-time Playback)",
+                    type="numpy",  # Expects (sr, np_array)
+                    streaming=True,
+                    autoplay=True,
+                    visible=True,  # Start visible but empty
+                    show_download_but