akhaliq (HF Staff) committed
Commit 0fa6dba · verified · 1 Parent(s): e35db11

Update app.py

Files changed (1):
  1. app.py +146 -255

app.py CHANGED
@@ -1,3 +1,7 @@
+"""
+VibeVoice Simple Chat Interface - Streamlined Audio Generation Demo
+"""
+
 import argparse
 import os
 import tempfile
@@ -88,21 +92,6 @@ logging.set_verbosity_info()
 logger = logging.get_logger(__name__)


-def convert_to_16_bit_wav(data):
-    """Convert audio data to 16-bit WAV format."""
-    if torch.is_tensor(data):
-        data = data.detach().cpu().numpy()
-
-    data = np.array(data, dtype=np.float32)
-
-    # Normalize to -1 to 1 if necessary
-    if np.max(np.abs(data)) > 1.0:
-        data = data / np.max(np.abs(data))
-
-    data = (data * 32767).astype(np.int16)
-    return data
-
-
 class VibeVoiceChat:
     def __init__(self, model_path: str, device: str = "cuda", inference_steps: int = 5):
         """Initialize the VibeVoice chat model."""
@@ -186,8 +175,7 @@ class VibeVoiceChat:

     def setup_voice_presets(self):
         """Setup voice presets from the voices directory."""
-        # This assumes 'voices' directory is in the same location as the script
-        voices_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "voices")
+        voices_dir = os.path.join(os.path.dirname(__file__), "voices")

         # Create voices directory if it doesn't exist
         if not os.path.exists(voices_dir):
@@ -218,14 +206,15 @@ class VibeVoiceChat:
     def read_audio(self, audio_path: str, target_sr: int = 24000) -> np.ndarray:
         """Read and preprocess audio file."""
         try:
-            # Librosa is more robust for various audio formats
-            wav, sr = librosa.load(audio_path, sr=None, mono=True)
+            wav, sr = sf.read(audio_path)
+            if len(wav.shape) > 1:
+                wav = np.mean(wav, axis=1)
             if sr != target_sr:
-                wav = librosa.resample(y=wav, orig_sr=sr, target_sr=target_sr)
+                wav = librosa.resample(wav, orig_sr=sr, target_sr=target_sr)
             return wav
         except Exception as e:
             print(f"Error reading audio {audio_path}: {e}")
-            return np.zeros(target_sr)  # Return 1 second of silence as fallback
+            return np.zeros(24000)  # Return 1 second of silence as fallback

     def format_script(self, message: str, num_speakers: int = 2) -> str:
         """Format input message into a script with speaker assignments."""
@@ -237,7 +226,7 @@ class VibeVoiceChat:
             if not line:
                 continue

-            # Check if already formatted (e.g., "Speaker 0: Hello")
+            # Check if already formatted
            if line.startswith('Speaker ') and ':' in line:
                formatted_lines.append(line)
            else:
@@ -256,18 +245,14 @@ class VibeVoiceChat:
         num_speakers: int,
         cfg_scale: float
     ) -> Iterator[tuple]:
-        """
-        Generate audio stream from text input.
-        Yields (sample_rate, audio_chunk_numpy) tuples as audio becomes available.
-        """
+        """Generate audio stream from text input."""
         try:
             self.stop_generation = False
             self.is_generating = True

             # Validate inputs
             if not message.strip():
-                self.is_generating = False
-                yield None, "❌ Error: Please provide a message."
+                yield None
                 return

             # Format the script
@@ -275,6 +260,7 @@ class VibeVoiceChat:
             print(f"Formatted script:\n{formatted_script}")
             print(f"Using device: {self.device}")

+            # Start timing
             start_time = time.time()

             # Select voices based on number of speakers
@@ -286,19 +272,20 @@ class VibeVoiceChat:

             # Load voice samples
             voice_samples = []
-            target_sr = 24000
             for i in range(num_speakers):
+                # Use the appropriate voice for each speaker
                 if i < len(selected_voices):
                     voice_name = selected_voices[i]
                     if voice_name in self.available_voices and self.available_voices[voice_name]:
-                        audio_data = self.read_audio(self.available_voices[voice_name], target_sr=target_sr)
+                        audio_data = self.read_audio(self.available_voices[voice_name])
                     else:
-                        audio_data = np.zeros(target_sr, dtype=np.float32)
+                        audio_data = np.zeros(24000)  # Default silence
                 else:
+                    # Use first voice or default if not enough voices selected
                     if selected_voices and selected_voices[0] in self.available_voices and self.available_voices[selected_voices[0]]:
-                        audio_data = self.read_audio(self.available_voices[selected_voices[0]], target_sr=target_sr)
+                        audio_data = self.read_audio(self.available_voices[selected_voices[0]])
                     else:
-                        audio_data = np.zeros(target_sr, dtype=np.float32)
+                        audio_data = np.zeros(24000)  # Default silence

                 voice_samples.append(audio_data)

@@ -317,6 +304,9 @@ class VibeVoiceChat:
             if self.device == "cuda":
                 inputs = {k: v.to(self.device) if torch.is_tensor(v) else v for k, v in inputs.items()}
                 print(f"✓ Inputs moved to GPU")
+                # Check GPU memory
+                if torch.cuda.is_available():
+                    print(f"GPU memory allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")

             # Create audio streamer
             audio_streamer = AudioStreamer(
@@ -327,45 +317,31 @@ class VibeVoiceChat:

             self.current_streamer = audio_streamer

-            # Start generation in a separate thread
+            # Start generation in separate thread
             generation_thread = threading.Thread(
                 target=self._generate_with_streamer,
                 args=(inputs, cfg_scale, audio_streamer)
             )
             generation_thread.start()

-            # Wait for generation to start
+            # Wait briefly for generation to start
             time.sleep(1)

-            # Check for stop signal
-            if self.stop_generation:
-                audio_streamer.end()
-                generation_thread.join(timeout=5.0)
-                self.is_generating = False
-                yield None, "🛑 Generation stopped by user"
-                return
-
-            # Get the audio stream
-            audio_output_stream = audio_streamer.get_stream(0)
+            # Stream audio chunks
+            sample_rate = 24000
+            audio_stream = audio_streamer.get_stream(0)

             all_audio_chunks = []
-            pending_chunks = []
             chunk_count = 0
-            last_yield_time = time.time()
-            min_yield_interval = 15
-            min_chunk_size = target_sr * 30
-
-            has_yielded_audio = False
-            has_received_chunks = False

-            for audio_chunk in audio_output_stream:
+            for audio_chunk in audio_stream:
                 if self.stop_generation:
                     audio_streamer.end()
                     break

                 chunk_count += 1
-                has_received_chunks = True

+                # Convert to numpy
                 if torch.is_tensor(audio_chunk):
                     if audio_chunk.dtype == torch.bfloat16:
                         audio_chunk = audio_chunk.float()
@@ -373,87 +349,43 @@ class VibeVoiceChat:
                 else:
                     audio_np = np.array(audio_chunk, dtype=np.float32)

+                # Ensure 1D
                 if len(audio_np.shape) > 1:
                     audio_np = audio_np.squeeze()

-                audio_16bit = convert_to_16_bit_wav(audio_np)
+                # Convert to 16-bit
+                audio_16bit = self.convert_to_16_bit_wav(audio_np)
                 all_audio_chunks.append(audio_16bit)
-                pending_chunks.append(audio_16bit)
-
-                pending_audio_size = sum(len(chunk) for chunk in pending_chunks)
-                current_time = time.time()
-                time_since_last_yield = current_time - last_yield_time

-                should_yield = False
-                if not has_yielded_audio and pending_audio_size >= min_chunk_size:
-                    should_yield = True
-                    has_yielded_audio = True
-                elif has_yielded_audio and (pending_audio_size >= min_chunk_size or time_since_last_yield >= min_yield_interval):
-                    should_yield = True
-
-                if should_yield and pending_chunks:
-                    new_audio = np.concatenate(pending_chunks)
-                    total_duration = sum(len(chunk) for chunk in all_audio_chunks) / target_sr
-
-                    log_update = f"🎵 Streaming: {total_duration:.1f}s generated (chunk {chunk_count})"
-                    yield (target_sr, new_audio), log_update
-
-                    pending_chunks = []
-                    last_yield_time = current_time
-
-            # Yield any remaining chunks
-            if pending_chunks:
-                final_new_audio = np.concatenate(pending_chunks)
-                total_duration = sum(len(chunk) for chunk in all_audio_chunks) / target_sr
-                log_update = f"🎵 Streaming final chunk: {total_duration:.1f}s total"
-                yield (target_sr, final_new_audio), log_update
-                has_yielded_audio = True
+                # Yield accumulated audio
+                if all_audio_chunks:
+                    complete_audio = np.concatenate(all_audio_chunks)
+                    yield (sample_rate, complete_audio)

             # Wait for generation to complete
             generation_thread.join(timeout=5.0)

-            if generation_thread.is_alive():
-                print("Warning: Generation thread did not complete within timeout")
-                audio_streamer.end()
-                generation_thread.join(timeout=5.0)
-
+            # Final yield with complete audio
+            if all_audio_chunks:
+                complete_audio = np.concatenate(all_audio_chunks)
+                generation_time = time.time() - start_time
+                audio_duration = len(complete_audio) / sample_rate
+                print(f"✓ Generation complete:")
+                print(f"  Time taken: {generation_time:.2f} seconds")
+                print(f"  Audio duration: {audio_duration:.2f} seconds")
+                print(f"  Real-time factor: {audio_duration/generation_time:.2f}x")
+                yield (sample_rate, complete_audio)
+
             self.current_streamer = None
             self.is_generating = False

-            generation_time = time.time() - start_time
-
-            if self.stop_generation:
-                yield None, "🛑 Generation stopped by user"
-                return
-
-            if not has_received_chunks:
-                yield None, f"❌ Error: No audio chunks were received. Generation time: {generation_time:.2f}s"
-                return
-
-            if not has_yielded_audio:
-                yield None, f"❌ Error: Audio was generated but not streamed. Chunk count: {chunk_count}"
-                return
-
-            if all_audio_chunks:
-                complete_audio = np.concatenate(all_audio_chunks)
-                final_duration = len(complete_audio) / target_sr
-
-                final_log = f"⏱️ Generation completed in {generation_time:.2f} seconds\n"
-                final_log += f"🎵 Final audio duration: {final_duration:.2f} seconds\n"
-                final_log += f"📊 Total chunks: {chunk_count}\n"
-                final_log += "✨ Generation successful!"
-
-                yield None, final_log
-            else:
-                yield None, "❌ No audio was generated."
-
         except Exception as e:
             print(f"Error in generation: {e}")
             import traceback
             traceback.print_exc()
             self.is_generating = False
             self.current_streamer = None
-            yield None, f"❌ An unexpected error occurred: {str(e)}"
+            yield None

     def _generate_with_streamer(self, inputs, cfg_scale, audio_streamer):
         """Helper method to run generation with streamer."""
@@ -461,9 +393,10 @@ class VibeVoiceChat:
         def check_stop():
             return self.stop_generation

+        # Use torch.cuda.amp for mixed precision if available
         if self.device == "cuda" and torch.cuda.is_available():
             with torch.cuda.amp.autocast(dtype=torch.bfloat16):
-                self.model.generate(
+                outputs = self.model.generate(
                     **inputs,
                     max_new_tokens=None,
                     cfg_scale=cfg_scale,
@@ -475,7 +408,7 @@ class VibeVoiceChat:
                     refresh_negative=True,
                 )
         else:
-            self.model.generate(
+            outputs = self.model.generate(
                 **inputs,
                 max_new_tokens=None,
                 cfg_scale=cfg_scale,
@@ -490,25 +423,35 @@ class VibeVoiceChat:
             print(f"Error in generation thread: {e}")
             import traceback
             traceback.print_exc()
-        finally:
             audio_streamer.end()

+    def convert_to_16_bit_wav(self, data):
+        """Convert audio data to 16-bit WAV format."""
+        if torch.is_tensor(data):
+            data = data.detach().cpu().numpy()
+
+        data = np.array(data)
+
+        if np.max(np.abs(data)) > 1.0:
+            data = data / np.max(np.abs(data))
+
+        data = (data * 32767).astype(np.int16)
+        return data
+
     def stop_audio_generation(self):
-        """Signal to stop the current audio generation."""
-        if self.is_generating:
-            print("🛑 Stop signal received.")
-            self.stop_generation = True
-            if self.current_streamer:
-                try:
-                    time.sleep(0.1)
-                    self.current_streamer.end()
-                except Exception as e:
-                    print(f"Error ending streamer: {e}")
+        """Stop the current audio generation."""
+        self.stop_generation = True
+        if self.current_streamer:
+            try:
+                self.current_streamer.end()
+            except:
+                pass


 def create_chat_interface(chat_instance: VibeVoiceChat):
-    """Create a Gradio ChatInterface for VibeVoice with improved streaming."""
+    """Create a simplified Gradio ChatInterface for VibeVoice."""

+    # Get available voices
     voice_options = list(chat_instance.available_voices.keys())
     if not voice_options:
         voice_options = ["Default"]
@@ -516,40 +459,51 @@ def create_chat_interface(chat_instance: VibeVoiceChat):
     default_voice_1 = voice_options[0] if len(voice_options) > 0 else "Default"
     default_voice_2 = voice_options[1] if len(voice_options) > 1 else voice_options[0]

-    # Custom CSS for streaming UI
-    custom_css = """
-    .streaming-indicator {
-        display: inline-block;
-        width: 10px;
-        height: 10px;
-        background: #22c55e;
-        border-radius: 50%;
-        margin-right: 8px;
-        animation: pulse 1.5s infinite;
-    }
-
-    @keyframes pulse {
-        0% { opacity: 1; transform: scale(1); }
-        50% { opacity: 0.5; transform: scale(1.1); }
-        100% { opacity: 1; transform: scale(1); }
-    }
-
-    .streaming-status {
-        background: linear-gradient(135deg, #dcfce7 0%, #bbf7d0 100%);
-        border: 1px solid rgba(34, 197, 94, 0.3);
-        border-radius: 8px;
-        padding: 0.75rem;
-        margin: 0.5rem 0;
-        text-align: center;
-        font-size: 0.9rem;
-        color: #166534;
-    }
-    """
-
-    with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="purple"),
-                   css=custom_css, fill_height=True) as interface:
-
-        gr.Markdown("# 🎙️ VibeVoice Chat - Streamed Audio\nGenerate natural dialogue audio with AI voices")
+    # Define the chat function that returns audio
+    def chat_fn(message: str, history: list, voice_1: str, voice_2: str, num_speakers: int, cfg_scale: float):
+        """Process chat message and generate audio response."""
+
+        # Extract text from message
+        if isinstance(message, dict):
+            text = message.get("text", "")
+        else:
+            text = message
+
+        if not text.strip():
+            return ""
+
+        try:
+            # Generate audio stream
+            audio_generator = chat_instance.generate_audio_stream(
+                text, history, voice_1, voice_2, num_speakers, cfg_scale
+            )
+
+            # Collect all audio data
+            audio_data = None
+            for audio_chunk in audio_generator:
+                if audio_chunk is not None:
+                    audio_data = audio_chunk
+
+            # Return audio file path or error message
+            if audio_data:
+                # Save audio to temporary file
+                with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
+                    sample_rate, audio_array = audio_data
+                    sf.write(tmp_file.name, audio_array, sample_rate)
+                    # Return the file path directly
+                    return tmp_file.name
+            else:
+                return "Failed to generate audio"
+
+        except Exception as e:
+            print(f"Error in chat_fn: {e}")
+            import traceback
+            traceback.print_exc()
+            return f"Error: {str(e)}"
+
+    # Create the interface using Blocks for more control
+    with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="purple"), fill_height=True) as interface:
+        gr.Markdown("# 🎙️ VibeVoice Chat\nGenerate natural dialogue audio with AI voices")

         with gr.Row():
             with gr.Column(scale=1):
@@ -601,33 +555,18 @@ def create_chat_interface(chat_instance: VibeVoiceChat):
                     lines=3
                 )

-                # Streaming status indicator
-                streaming_status = gr.HTML(
-                    value="""
-                    <div class="streaming-status">
-                        <span class="streaming-indicator"></span>
-                        <strong>LIVE STREAMING</strong> - Audio is being generated in real-time
-                    </div>
-                    """,
-                    visible=False,
-                    elem_id="streaming-status"
-                )
-
-                # Audio output with streaming enabled
                 audio_output = gr.Audio(
                     label="Generated Audio",
-                    type="numpy",
-                    streaming=True,
+                    type="filepath",
                     autoplay=True,
-                    show_download_button=False,
                     visible=False
                 )

                 with gr.Row():
-                    submit_btn = gr.Button("🎵 Generate Audio", variant="primary")
-                    stop_btn = gr.Button("🛑 Stop Generation", variant="secondary")
-                    clear_btn = gr.Button("🗑️ Clear")
+                    submit = gr.Button("🎵 Generate Audio", variant="primary")
+                    clear = gr.Button("🗑️ Clear")

+                # Example messages
                 gr.Examples(
                     examples=[
                         "Hello! How are you doing today?",
@@ -639,90 +578,39 @@ def create_chat_interface(chat_instance: VibeVoiceChat):
                     label="Example Messages"
                 )

-        def process_and_display_stream(message_text: str, history: list, voice_1: str, voice_2: str, num_speakers: int, cfg_scale: float):
-            """Process input and stream audio with status updates."""
+        # Set up event handlers
+        def process_and_display(message, history, voice_1, voice_2, num_speakers, cfg_scale):
+            """Process message and update both chatbot and audio."""
+            # Add user message to history
             history = history or []
-            history.append({"role": "user", "content": message_text})
+            history.append({"role": "user", "content": message})

-            # Initial state: clear audio, show status, disable input
-            yield history, gr.update(value="", interactive=False), None, gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)
-
-            # Generate audio stream
-            for audio_chunk, log_message in chat_instance.generate_audio_stream(
-                message_text, history, voice_1, voice_2, num_speakers, cfg_scale
-            ):
-                if audio_chunk is not None:
-                    # Streaming audio chunk received
-                    yield history, gr.update(interactive=False), audio_chunk, gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)
-                else:
-                    # Final status message or error
-                    if log_message and "❌" in log_message:
-                        # Error case
-                        history.append({"role": "assistant", "content": log_message})
-                        yield history, gr.update(interactive=True), None, gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
-                    elif log_message and "🛑" in log_message:
-                        # Stopped case
-                        history.append({"role": "assistant", "content": log_message})
-                        yield history, gr.update(interactive=True), None, gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
-                    elif log_message:
-                        # Final success message
-                        history.append({"role": "assistant", "content": "🎵 Audio generated successfully!"})
-                        yield history, gr.update(interactive=True), None, gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
+            # Generate audio
+            audio_path = chat_fn(message, history, voice_1, voice_2, num_speakers, cfg_scale)

-            # Final cleanup
-            yield history, gr.update(interactive=True), None, gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
-
-        def stop_generation_handler():
-            """Handle stop generation."""
-            chat_instance.stop_audio_generation()
-            return gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)
+            # Add assistant response with audio
+            if audio_path and audio_path.endswith('.wav'):
+                history.append({"role": "assistant", "content": f"🎵 Audio generated successfully"})
+                return history, audio_path, gr.update(visible=True), ""
+            else:
+                history.append({"role": "assistant", "content": audio_path or "Failed to generate audio"})
+                return history, None, gr.update(visible=False), ""

-        def clear_audio_outputs():
-            """Clear audio outputs."""
-            return None
-
-        # Event handlers
-        submit_btn.click(
-            fn=clear_audio_outputs,
-            inputs=[],
-            outputs=[audio_output],
-            queue=False
-        ).then(
-            fn=process_and_display_stream,
+        submit.click(
+            fn=process_and_display,
             inputs=[msg, chatbot, voice_1, voice_2, num_speakers, cfg_scale],
-            outputs=[chatbot, msg, audio_output, streaming_status, submit_btn, stop_btn],
+            outputs=[chatbot, audio_output, audio_output, msg],
             queue=True
         )

         msg.submit(
-            fn=clear_audio_outputs,
-            inputs=[],
-            outputs=[audio_output],
-            queue=False
-        ).then(
-            fn=process_and_display_stream,
+            fn=process_and_display,
             inputs=[msg, chatbot, voice_1, voice_2, num_speakers, cfg_scale],
-            outputs=[chatbot, msg, audio_output, streaming_status, submit_btn, stop_btn],
+            outputs=[chatbot, audio_output, audio_output, msg],
             queue=True
         )

-        stop_btn.click(
-            fn=stop_generation_handler,
-            inputs=[],
-            outputs=[streaming_status, submit_btn, stop_btn],
-            queue=False
-        ).then(
-            fn=lambda: None,
-            inputs=[],
-            outputs=[audio_output],
-            queue=False
-        )
-
-        clear_btn.click(
-            lambda: ([], gr.update(value="", interactive=True), gr.update(visible=False, value=None),
-                     gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)),
-            outputs=[chatbot, msg, audio_output, streaming_status, submit_btn, stop_btn]
-        )
+        clear.click(lambda: ([], None, gr.update(visible=False)), outputs=[chatbot, audio_output, audio_output])

     return interface

@@ -777,6 +665,9 @@ def main():

     if chat_instance.device == "cpu":
         print("\n⚠️ WARNING: Running on CPU - generation will be VERY slow!")
+        print("   For faster generation, ensure you have:")
+        print("   1. NVIDIA GPU with CUDA support")
+        print("   2. PyTorch with CUDA installed: pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118")

     # Launch the interface
     interface.queue(max_size=10).launch(
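
Note: the central refactor in this commit replaces the streamed gr.Audio(type="numpy", streaming=True) output with a single gr.Audio(type="filepath") component fed from a temporary WAV file, as seen in chat_fn above. A minimal standalone sketch of that write-then-return pattern, assuming only numpy and soundfile are installed (the helper name audio_to_wav_path is illustrative, not part of the commit):

import tempfile

import numpy as np
import soundfile as sf

def audio_to_wav_path(audio: np.ndarray, sample_rate: int = 24000) -> str:
    """Write a mono waveform to a temporary WAV file and return its path."""
    # Mirror convert_to_16_bit_wav: normalize only if the signal clips,
    # then scale to int16 PCM.
    if audio.dtype != np.int16:
        peak = np.max(np.abs(audio))
        if peak > 1.0:
            audio = audio / peak
        audio = (audio * 32767).astype(np.int16)
    # delete=False so Gradio can still read the file after the handler returns
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        sf.write(tmp.name, audio, sample_rate)
        return tmp.name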
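
A behavioral difference worth noting in the read_audio change: librosa.load decoded, downmixed, and resampled in one call, whereas sf.read returns the file's native channel layout and sample rate, so the new code downmixes and resamples explicitly. A standalone sketch of the equivalent logic, assuming librosa and soundfile are installed (the function name load_mono is illustrative):

import librosa
import numpy as np
import soundfile as sf

def load_mono(path: str, target_sr: int = 24000) -> np.ndarray:
    """Load an audio file as a mono float waveform at target_sr."""
    wav, sr = sf.read(path)      # shape (frames,) or (frames, channels)
    if wav.ndim > 1:
        wav = wav.mean(axis=1)   # average channels down to mono
    if sr != target_sr:
        # orig_sr/target_sr are keyword-only in recent librosa releases
        wav = librosa.resample(wav, orig_sr=sr, target_sr=target_sr)
    return wav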