akhaliq (HF Staff) committed
Commit b7fc0b0 · verified · 1 Parent(s): c2e9ecc

Update app.py

Files changed (1)
  1. app.py +282 -308
app.py CHANGED
@@ -1,3 +1,7 @@
 import argparse
 import os
 import tempfile
@@ -10,7 +14,7 @@ import librosa
 import soundfile as sf
 import torch
 from pathlib import Path
-from typing import Iterator, Dict, Any, List

 # Clone and setup VibeVoice if not already present
 vibevoice_dir = Path('./VibeVoice')
@@ -87,20 +91,6 @@ from transformers import set_seed
 logging.set_verbosity_info()
 logger = logging.get_logger(__name__)

-# --- Helper function for audio conversion ---
-def convert_to_16_bit_wav(data: np.ndarray | torch.Tensor) -> np.ndarray:
-    """Convert audio data to 16-bit WAV format (numpy int16)."""
-    if torch.is_tensor(data):
-        data = data.detach().cpu().numpy()
-
-    data = np.array(data, dtype=np.float32)  # Ensure float32 before scaling
-
-    # Normalize to -1 to 1 if necessary
-    if np.max(np.abs(data)) > 1.0:
-        data = data / np.max(np.abs(data))
-
-    data = (data * 32767).astype(np.int16)
-    return data

 class VibeVoiceChat:
     def __init__(self, model_path: str, device: str = "cuda", inference_steps: int = 5):
@@ -110,8 +100,7 @@ class VibeVoiceChat:
         self.inference_steps = inference_steps
         self.is_generating = False
         self.stop_generation = False
-        self.current_streamer: AudioStreamer | None = None
-        self.complete_audio_buffer: List[np.ndarray] = []  # To store all generated audio for final download

         # Check GPU availability and CUDA version
         if torch.cuda.is_available():
@@ -119,8 +108,10 @@ class VibeVoiceChat:
             print(f" Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
             print(f" CUDA Version: {torch.version.cuda}")
             print(f" PyTorch CUDA: {torch.cuda.is_available()}")
-            torch.cuda.set_per_process_memory_fraction(0.95)  # Set memory fraction to avoid OOM
-            torch.backends.cuda.matmul.allow_tf32 = True  # Enable TF32 for faster computation on Ampere GPUs
             torch.backends.cudnn.allow_tf32 = True
         else:
             print("✗ No GPU detected, using CPU (generation will be VERY slow)")
@@ -178,13 +169,15 @@ class VibeVoiceChat:
         load_time = time.time() - start_time
         print(f"✓ Model loaded in {load_time:.2f} seconds")

         if hasattr(self.model, 'device'):
             print(f"Model device: {self.model.device}")

     def setup_voice_presets(self):
         """Setup voice presets from the voices directory."""
-        voices_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "voices")

         if not os.path.exists(voices_dir):
             os.makedirs(voices_dir)
             print(f"Created voices directory at {voices_dir}")
@@ -193,16 +186,19 @@ class VibeVoiceChat:
         self.available_voices = {}
         audio_extensions = ('.wav', '.mp3', '.flac', '.ogg', '.m4a', '.aac')

         for file in os.listdir(voices_dir):
             if file.lower().endswith(audio_extensions):
                 name = os.path.splitext(file)[0]
                 self.available_voices[name] = os.path.join(voices_dir, file)

         self.available_voices = dict(sorted(self.available_voices.items()))

         if not self.available_voices:
             print(f"Warning: No voice files found in {voices_dir}")
             print("Using default (zero) voice samples. Add audio files to the voices directory for better results.")
             self.available_voices = {"Default": None}
         else:
             print(f"Found {len(self.available_voices)} voice presets: {', '.join(self.available_voices.keys())}")
@@ -210,13 +206,15 @@ class VibeVoiceChat:
     def read_audio(self, audio_path: str, target_sr: int = 24000) -> np.ndarray:
         """Read and preprocess audio file."""
         try:
-            wav, sr = librosa.load(audio_path, sr=None, mono=True)
             if sr != target_sr:
-                wav = librosa.resample(y=wav, orig_sr=sr, target_sr=target_sr)
             return wav
         except Exception as e:
             print(f"Error reading audio {audio_path}: {e}")
-            return np.zeros(target_sr, dtype=np.float32)

     def format_script(self, message: str, num_speakers: int = 2) -> str:
         """Format input message into a script with speaker assignments."""
@@ -228,9 +226,11 @@ class VibeVoiceChat:
             if not line:
                 continue

             if line.startswith('Speaker ') and ':' in line:
                 formatted_lines.append(line)
             else:
                 speaker_id = i % num_speakers
                 formatted_lines.append(f"Speaker {speaker_id}: {line}")

@@ -239,50 +239,59 @@ class VibeVoiceChat:
     def generate_audio_stream(
         self,
         message: str,
         voice_1: str,
         voice_2: str,
         num_speakers: int,
         cfg_scale: float
-    ) -> Iterator[tuple]:  # This generator yields (sample_rate, audio_chunk_numpy_int16)
-        """
-        Generate audio stream from text input, implementing buffering for smoother streaming.
-        Yields (sample_rate, audio_chunk_numpy_int16) tuples as audio becomes available.
-        """
         try:
             self.stop_generation = False
             self.is_generating = True
-            self.complete_audio_buffer = []  # Reset buffer for new generation

             if not message.strip():
-                self.is_generating = False
                 yield None
                 return

             formatted_script = self.format_script(message, num_speakers)

             selected_voices = []
             if voice_1 and voice_1 != "Default":
                 selected_voices.append(voice_1)
             if num_speakers > 1 and voice_2 and voice_2 != "Default":
                 selected_voices.append(voice_2)

             voice_samples = []
-            target_sr = 24000
             for i in range(num_speakers):
                 if i < len(selected_voices):
                     voice_name = selected_voices[i]
                     if voice_name in self.available_voices and self.available_voices[voice_name]:
-                        audio_data = self.read_audio(self.available_voices[voice_name], target_sr=target_sr)
                     else:
-                        audio_data = np.zeros(target_sr, dtype=np.float32)
                 else:
                     if selected_voices and selected_voices[0] in self.available_voices and self.available_voices[selected_voices[0]]:
-                        audio_data = self.read_audio(self.available_voices[selected_voices[0]], target_sr=target_sr)
                     else:
-                        audio_data = np.zeros(target_sr, dtype=np.float32)

                 voice_samples.append(audio_data)

             inputs = self.processor(
                 text=[formatted_script],
                 voice_samples=[voice_samples],
@@ -291,9 +300,15 @@ class VibeVoiceChat:
                 return_attention_mask=True,
             )

             if self.device == "cuda":
                 inputs = {k: v.to(self.device) if torch.is_tensor(v) else v for k, v in inputs.items()}

             audio_streamer = AudioStreamer(
                 batch_size=1,
                 stop_signal=None,
@@ -302,89 +317,86 @@ class VibeVoiceChat:

             self.current_streamer = audio_streamer

             generation_thread = threading.Thread(
                 target=self._generate_with_streamer,
                 args=(inputs, cfg_scale, audio_streamer)
             )
             generation_thread.start()

-            # Give the generation thread a moment to start producing output
-            time.sleep(1.0)  # Increased from 0.5s for stability

-            audio_output_stream = audio_streamer.get_stream(0)

-            # Buffering logic for smoother Gradio streaming
-            pending_chunks: List[np.ndarray] = []
-            min_yield_interval_seconds = 1.0  # Yield at least every 1 second
-            min_chunk_size_samples = target_sr * 0.5  # At least 0.5 seconds of audio per chunk yielded to Gradio
-            last_yield_time = time.time()

-            for audio_chunk_raw in audio_output_stream:
                 if self.stop_generation:
                     audio_streamer.end()
                     break

-                # Convert raw chunk to numpy float32
-                if torch.is_tensor(audio_chunk_raw):
-                    if audio_chunk_raw.dtype == torch.bfloat16:
-                        audio_chunk_raw = audio_chunk_raw.float()
-                    audio_np = audio_chunk_raw.cpu().numpy().astype(np.float32)
                 else:
-                    audio_np = np.array(audio_chunk_raw, dtype=np.float32)

                 if len(audio_np.shape) > 1:
                     audio_np = audio_np.squeeze()

-                # Append to complete buffer (for final download)
-                self.complete_audio_buffer.append(audio_np)
-
-                # Append to pending chunks for streaming to Gradio
-                pending_chunks.append(audio_np)
-                current_pending_size = sum(len(c) for c in pending_chunks)

-                current_time = time.time()
-
-                should_yield = False
-                if current_pending_size >= min_chunk_size_samples:
-                    should_yield = True
-                elif (current_time - last_yield_time) >= min_yield_interval_seconds and pending_chunks:
-                    should_yield = True
-
-                if should_yield:
-                    combined_chunk = np.concatenate(pending_chunks)
-                    yield (target_sr, convert_to_16_bit_wav(combined_chunk))  # Convert to int16 before yielding
-                    pending_chunks = []
-                    last_yield_time = current_time

-            # Yield any remaining chunks after the loop finishes
-            if pending_chunks and not self.stop_generation:
-                combined_chunk = np.concatenate(pending_chunks)
-                yield (target_sr, convert_to_16_bit_wav(combined_chunk))

-            generation_thread.join(timeout=10.0)  # Ensure generation thread completes

-            # Clean up
             self.current_streamer = None
             self.is_generating = False

         except Exception as e:
-            print(f"Error in generate_audio_stream: {e}")
             import traceback
             traceback.print_exc()
             self.is_generating = False
             self.current_streamer = None
-            self.complete_audio_buffer = []  # Clear buffer on error
             yield None

-    def _generate_with_streamer(self, inputs: Dict[str, Any], cfg_scale: float, audio_streamer: AudioStreamer):
         """Helper method to run generation with streamer."""
         try:
             def check_stop():
                 return self.stop_generation

             if self.device == "cuda" and torch.cuda.is_available():
                 with torch.cuda.amp.autocast(dtype=torch.bfloat16):
-                    self.model.generate(
                         **inputs,
                         max_new_tokens=None,
                         cfg_scale=cfg_scale,
@@ -396,7 +408,7 @@ class VibeVoiceChat:
                     refresh_negative=True,
                 )
             else:
-                self.model.generate(
                     **inputs,
                     max_new_tokens=None,
                     cfg_scale=cfg_scale,
@@ -411,232 +423,91 @@ class VibeVoiceChat:
             print(f"Error in generation thread: {e}")
             import traceback
             traceback.print_exc()
-        finally:
-            audio_streamer.end()

     def stop_audio_generation(self):
-        """Signal to stop the current audio generation."""
-        if self.is_generating:
-            print("🛑 Stop signal received.")
-            self.stop_generation = True
-            if self.current_streamer:
-                try:
-                    self.current_streamer.end()
-                except Exception as e:
-                    print(f"Error ending streamer: {e}")
-            self.is_generating = False
-            self.complete_audio_buffer = []
-        else:
-            print("No active generation to stop.")


 def create_chat_interface(chat_instance: VibeVoiceChat):
-    """Create a simplified Gradio ChatInterface for VibeVoice with audio streaming."""

     voice_options = list(chat_instance.available_voices.keys())
     if not voice_options:
         voice_options = ["Default"]

     default_voice_1 = voice_options[0] if len(voice_options) > 0 else "Default"
     default_voice_2 = voice_options[1] if len(voice_options) > 1 else voice_options[0]
-
-    # Custom CSS for modern aesthetics
-    custom_css = """
-    .gradio-container {
-        font-family: 'Inter', sans-serif;
-        background: linear-gradient(135deg, #f0f2f5 0%, #e0e6ed 100%);
-        color: #333;
-    }
-    .main-header {
-        background: linear-gradient(45deg, #4A00E0 0%, #8E2DE2 100%);
-        padding: 20px 30px;
-        border-radius: 15px;
-        margin-bottom: 25px;
-        text-align: center;
-        box-shadow: 0 8px 25px rgba(0, 0, 0, 0.2);
-    }
-    .main-header h1 {
-        color: white;
-        font-size: 2.8em;
-        font-weight: 800;
-        margin: 0;
-        letter-spacing: -1px;
-        text-shadow: 0 3px 5px rgba(0,0,0,0.2);
-    }
-    .main-header p {
-        color: rgba(255,255,255,0.85);
-        font-size: 1.1em;
-        margin-top: 10px;
-    }
-    .settings-card, .generation-card {
-        background: rgba(255, 255, 255, 0.9);
-        border: 1px solid #dcdfe6;
-        border-radius: 12px;
-        padding: 20px;
-        box-shadow: 0 4px 15px rgba(0, 0, 0, 0.08);
-        transition: all 0.3s ease;
-    }
-    .settings-card:hover, .generation-card:hover {
-        box-shadow: 0 6px 20px rgba(0, 0, 0, 0.12);
-        transform: translateY(-2px);
-    }
-    .gradio-output {
-        border-radius: 10px;
-        background-color: #fcfcfc;
-    }
-    .gradio-button {
-        border-radius: 8px !important;
-        font-weight: 600;
-        padding: 10px 20px;
-        transition: all 0.2s ease-in-out;
-    }
-    .gradio-button.primary {
-        background: linear-gradient(45deg, #4CAF50 0%, #8BC34A 100%) !important;
-        color: white !important;
-        border: none !important;
-    }
-    .gradio-button.primary:hover {
-        opacity: 0.9;
-        transform: translateY(-1px);
-    }
-    .gradio-button.secondary {
-        background: linear-gradient(45deg, #FF5722 0%, #FFC107 100%) !important;
-        color: white !important;
-        border: none !important;
-    }
-    .gradio-button.secondary:hover {
-        opacity: 0.9;
-        transform: translateY(-1px);
-    }
-    .gradio-button.clear {
-        background: #90A4AE !important;
-        color: white !important;
-        border: none !important;
-    }
-    .gradio-button.clear:hover {
-        opacity: 0.9;
-        transform: translateY(-1px);
-    }
-    .gradio-input {
-        border-radius: 8px !important;
-        border: 1px solid #ced4da !important;
-    }
-    .gradio-label {
-        font-weight: 700;
-        color: #495057;
-        margin-bottom: 5px;
-    }
-    .chatbot {
-        border: 1px solid #e0e0e0 !important;
-        border-radius: 10px !important;
-        box-shadow: 0 2px 10px rgba(0,0,0,0.05);
-    }
-    .log-output {
-        font-family: 'JetBrains Mono', monospace;
-        background-color: #f8f9fa !important;
-        border-radius: 8px !important;
-        border: 1px solid #e9ecef !important;
-        color: #495057 !important;
-        min-height: 80px;
-    }
-    .audio-output {
-        border-radius: 10px !important;
-        border: 1px solid #e0e0e0 !important;
-        background-color: #f8f9fa !important;
-    }
-    """

-    # Gradio handler function that coordinates UI updates
-    def process_and_display_stream(message_text: str, history: List[Dict[str, str]], voice_1: str, voice_2: str, num_speakers: int, cfg_scale: float):
-        history = history or []
-        user_message_entry = {"role": "user", "content": message_text}
-
-        # Initial state: user message added, text input disabled, buttons updated, audio cleared/hidden
-        # This yield ensures immediate UI feedback
-        yield (
-            history + [user_message_entry],  # Add user message to chatbot immediately
-            gr.update(value="", interactive=False),  # Clear text input and disable
-            gr.update(value=None, visible=True),  # Clear streaming audio, make visible
-            gr.update(value=None, visible=False),  # Clear complete audio, hide
-            "🎙️ Starting audio generation...",  # Initial log message
-            gr.update(interactive=False, value="Generating..."),  # Disable submit button
-            gr.update(visible=True)  # Show stop button
-        )
-
-        log_message = ""
-        generated_any_audio = False

-        # Call the chat_instance's audio generator
-        audio_stream_generator = chat_instance.generate_audio_stream(
-            message_text, voice_1, voice_2, num_speakers, cfg_scale
-        )

-        # Loop through the streaming audio chunks
-        for chunk_data in audio_stream_generator:
-            if chat_instance.stop_generation:
-                log_message = "🛑 Audio generation stopped."
-                break

-            if chunk_data is not None:
-                generated_any_audio = True
-                log_message = "🎵 Streaming audio..."
-                # Yield current chunk to streaming audio component
-                # Other components remain static during streaming
-                yield (
-                    history + [user_message_entry],  # Chatbot state
-                    gr.update(interactive=False),  # Text input disabled
-                    chunk_data,  # Streaming audio chunk
-                    gr.update(visible=False),  # Complete audio hidden
-                    log_message,  # Log update
-                    gr.update(interactive=False, value="Generating..."),  # Submit button still disabled
-                    gr.update(visible=True)  # Stop button still visible
-                )
             else:
-                # None indicates an error or unexpected end from the generator
-                log_message = "❌ Error during audio generation."
-                break
-
-        # After generation (or stop/error), prepare final updates
-        final_chatbot_history = history + [user_message_entry]
-        final_streaming_audio_update = gr.update(value=None, visible=False)  # Hide streaming audio
-        final_complete_audio_update = gr.update(value=None, visible=False)  # Default to hidden
-
-        if chat_instance.stop_generation:
-            final_chatbot_history.append({"role": "assistant", "content": "🚫 Audio generation stopped."})
-            log_message = "🛑 Generation stopped by user."
-            chat_instance.stop_generation = False  # Reset flag for next run
-        elif generated_any_audio and chat_instance.complete_audio_buffer:
-            # Concatenate all collected audio chunks for the final downloadable audio
-            complete_audio_data_np = np.concatenate(chat_instance.complete_audio_buffer)
-            final_complete_audio_update = gr.update(value=(24000, convert_to_16_bit_wav(complete_audio_data_np)), visible=True)
-            final_chatbot_history.append({"role": "assistant", "content": "✅ Audio generated successfully! Listen below and download."})
-            log_message = "✨ Generation complete! See 'Complete Audio' below."
-        else:
-            final_chatbot_history.append({"role": "assistant", "content": "❌ Failed to generate audio."})
-            log_message = "❌ Generation failed or no audio produced."

-        # Final yield to update all components after streaming
-        yield (
-            final_chatbot_history,
-            gr.update(value="", interactive=True),  # Re-enable text input
-            final_streaming_audio_update,
-            final_complete_audio_update,
-            log_message,
-            gr.update(interactive=True, value="🎵 Generate Audio"),  # Re-enable submit button
-            gr.update(visible=False)  # Hide stop button
-        )
-
-    with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="purple"), fill_height=True, css=custom_css) as interface:
-        gr.HTML("""
-        <div class="main-header">
-            <h1>🎙️ VibeVoice Chat - Streamed Audio</h1>
-            <p>Generate natural dialogue audio with AI voices</p>
-        </div>
-        """)

         with gr.Row():
-            with gr.Column(scale=1, elem_classes="settings-card"):
-                gr.Markdown("### 🎛️ **Voice & Generation Settings**")

                 voice_1 = gr.Dropdown(
                     choices=voice_options,
@@ -670,37 +541,140 @@ def create_chat_interface(chat_instance: VibeVoiceChat):
                     info="Guidance strength (higher = more adherence to text)"
                 )

-            with gr.Column(scale=2, elem_classes="generation-card"):
                 chatbot = gr.Chatbot(
                     label="Conversation",
-                    height=300,  # Adjusted height
                     type="messages",
-                    elem_id="chatbot",
-                    elem_classes="chatbot"
                 )

                 msg = gr.Textbox(
                     label="Message",
                     placeholder="Type your message or paste a script...",
-                    lines=3,
-                    elem_classes="gradio-input"
                 )

-                # Log output for generation status
-                log_output = gr.Textbox(
-                    label="Generation Log",
-                    lines=2,
-                    max_lines=5,
-                    interactive=False,
-                    value="Ready to generate audio.",
-                    elem_classes="log-output"
-                )
-
-                # Streaming audio component
                 audio_output = gr.Audio(
-                    label="Streaming Audio (Real-time Playback)",
-                    type="numpy",  # Expects (sr, np_array)
-                    streaming=True,
                     autoplay=True,
-                    visible=True,  # Start visible but empty
-                    show_download_but
+"""
+VibeVoice Simple Chat Interface - Streamlined Audio Generation Demo
+"""
+
 import argparse
 import os
 import tempfile

 import soundfile as sf
 import torch
 from pathlib import Path
+from typing import Iterator, Dict, Any

 # Clone and setup VibeVoice if not already present
 vibevoice_dir = Path('./VibeVoice')
 logging.set_verbosity_info()
 logger = logging.get_logger(__name__)


 class VibeVoiceChat:
     def __init__(self, model_path: str, device: str = "cuda", inference_steps: int = 5):

         self.inference_steps = inference_steps
         self.is_generating = False
         self.stop_generation = False
+        self.current_streamer = None

         # Check GPU availability and CUDA version
         if torch.cuda.is_available():

             print(f" Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
             print(f" CUDA Version: {torch.version.cuda}")
             print(f" PyTorch CUDA: {torch.cuda.is_available()}")
+            # Set memory fraction to avoid OOM
+            torch.cuda.set_per_process_memory_fraction(0.95)
+            # Enable TF32 for faster computation on Ampere GPUs
+            torch.backends.cuda.matmul.allow_tf32 = True
             torch.backends.cudnn.allow_tf32 = True
         else:
             print("✗ No GPU detected, using CPU (generation will be VERY slow)")
         load_time = time.time() - start_time
         print(f"✓ Model loaded in {load_time:.2f} seconds")

+        # Print model device
         if hasattr(self.model, 'device'):
             print(f"Model device: {self.model.device}")

     def setup_voice_presets(self):
         """Setup voice presets from the voices directory."""
+        voices_dir = os.path.join(os.path.dirname(__file__), "voices")

+        # Create voices directory if it doesn't exist
         if not os.path.exists(voices_dir):
             os.makedirs(voices_dir)
             print(f"Created voices directory at {voices_dir}")

         self.available_voices = {}
         audio_extensions = ('.wav', '.mp3', '.flac', '.ogg', '.m4a', '.aac')

+        # Scan for audio files
         for file in os.listdir(voices_dir):
             if file.lower().endswith(audio_extensions):
                 name = os.path.splitext(file)[0]
                 self.available_voices[name] = os.path.join(voices_dir, file)

+        # Sort voices alphabetically
         self.available_voices = dict(sorted(self.available_voices.items()))

         if not self.available_voices:
             print(f"Warning: No voice files found in {voices_dir}")
             print("Using default (zero) voice samples. Add audio files to the voices directory for better results.")
+            # Add a default "None" option
             self.available_voices = {"Default": None}
         else:
             print(f"Found {len(self.available_voices)} voice presets: {', '.join(self.available_voices.keys())}")
     def read_audio(self, audio_path: str, target_sr: int = 24000) -> np.ndarray:
         """Read and preprocess audio file."""
         try:
+            wav, sr = sf.read(audio_path)
+            if len(wav.shape) > 1:
+                wav = np.mean(wav, axis=1)
             if sr != target_sr:
+                wav = librosa.resample(wav, orig_sr=sr, target_sr=target_sr)
             return wav
         except Exception as e:
             print(f"Error reading audio {audio_path}: {e}")
+            return np.zeros(24000)  # Return 1 second of silence as fallback

     def format_script(self, message: str, num_speakers: int = 2) -> str:
         """Format input message into a script with speaker assignments."""

             if not line:
                 continue

+            # Check if already formatted
             if line.startswith('Speaker ') and ':' in line:
                 formatted_lines.append(line)
             else:
+                # Auto-assign speakers in rotation
                 speaker_id = i % num_speakers
                 formatted_lines.append(f"Speaker {speaker_id}: {line}")

     def generate_audio_stream(
         self,
         message: str,
+        history: list,
         voice_1: str,
         voice_2: str,
         num_speakers: int,
         cfg_scale: float
+    ) -> Iterator[tuple]:
+        """Generate audio stream from text input."""
         try:
             self.stop_generation = False
             self.is_generating = True

+            # Validate inputs
             if not message.strip():
                 yield None
                 return

+            # Format the script
             formatted_script = self.format_script(message, num_speakers)
+            print(f"Formatted script:\n{formatted_script}")
+            print(f"Using device: {self.device}")

+            # Start timing
+            start_time = time.time()
+
+            # Select voices based on number of speakers
             selected_voices = []
             if voice_1 and voice_1 != "Default":
                 selected_voices.append(voice_1)
             if num_speakers > 1 and voice_2 and voice_2 != "Default":
                 selected_voices.append(voice_2)

+            # Load voice samples
             voice_samples = []
             for i in range(num_speakers):
+                # Use the appropriate voice for each speaker
                 if i < len(selected_voices):
                     voice_name = selected_voices[i]
                     if voice_name in self.available_voices and self.available_voices[voice_name]:
+                        audio_data = self.read_audio(self.available_voices[voice_name])
                     else:
+                        audio_data = np.zeros(24000)  # Default silence
                 else:
+                    # Use first voice or default if not enough voices selected
                     if selected_voices and selected_voices[0] in self.available_voices and self.available_voices[selected_voices[0]]:
+                        audio_data = self.read_audio(self.available_voices[selected_voices[0]])
                     else:
+                        audio_data = np.zeros(24000)  # Default silence

                 voice_samples.append(audio_data)

+            print(f"Loaded {len(voice_samples)} voice samples")
+
+            # Process inputs
             inputs = self.processor(
                 text=[formatted_script],
                 voice_samples=[voice_samples],

                 return_attention_mask=True,
             )

+            # Move to device and ensure correct dtype
             if self.device == "cuda":
                 inputs = {k: v.to(self.device) if torch.is_tensor(v) else v for k, v in inputs.items()}
+                print(f"✓ Inputs moved to GPU")
+                # Check GPU memory
+                if torch.cuda.is_available():
+                    print(f"GPU memory allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")

+            # Create audio streamer
             audio_streamer = AudioStreamer(
                 batch_size=1,
                 stop_signal=None,

             self.current_streamer = audio_streamer

+            # Start generation in separate thread
             generation_thread = threading.Thread(
                 target=self._generate_with_streamer,
                 args=(inputs, cfg_scale, audio_streamer)
             )
             generation_thread.start()

+            # Wait briefly for generation to start
+            time.sleep(1)

+            # Stream audio chunks
+            sample_rate = 24000
+            audio_stream = audio_streamer.get_stream(0)

+            all_audio_chunks = []
+            chunk_count = 0

+            for audio_chunk in audio_stream:
                 if self.stop_generation:
                     audio_streamer.end()
                     break

+                chunk_count += 1
+
+                # Convert to numpy
+                if torch.is_tensor(audio_chunk):
+                    if audio_chunk.dtype == torch.bfloat16:
+                        audio_chunk = audio_chunk.float()
+                    audio_np = audio_chunk.cpu().numpy().astype(np.float32)
                 else:
+                    audio_np = np.array(audio_chunk, dtype=np.float32)

+                # Ensure 1D
                 if len(audio_np.shape) > 1:
                     audio_np = audio_np.squeeze()

+                # Convert to 16-bit
+                audio_16bit = self.convert_to_16_bit_wav(audio_np)
+                all_audio_chunks.append(audio_16bit)

+                # Yield accumulated audio
+                if all_audio_chunks:
+                    complete_audio = np.concatenate(all_audio_chunks)
+                    yield (sample_rate, complete_audio)

+            # Wait for generation to complete
+            generation_thread.join(timeout=5.0)

+            # Final yield with complete audio
+            if all_audio_chunks:
+                complete_audio = np.concatenate(all_audio_chunks)
+                generation_time = time.time() - start_time
+                audio_duration = len(complete_audio) / sample_rate
+                print(f"✓ Generation complete:")
+                print(f" Time taken: {generation_time:.2f} seconds")
+                print(f" Audio duration: {audio_duration:.2f} seconds")
+                print(f" Real-time factor: {audio_duration/generation_time:.2f}x")
+                yield (sample_rate, complete_audio)

             self.current_streamer = None
             self.is_generating = False

         except Exception as e:
+            print(f"Error in generation: {e}")
             import traceback
             traceback.print_exc()
             self.is_generating = False
             self.current_streamer = None

             yield None

+    def _generate_with_streamer(self, inputs, cfg_scale, audio_streamer):
         """Helper method to run generation with streamer."""
         try:
             def check_stop():
                 return self.stop_generation

+            # Use torch.cuda.amp for mixed precision if available
             if self.device == "cuda" and torch.cuda.is_available():
                 with torch.cuda.amp.autocast(dtype=torch.bfloat16):
+                    outputs = self.model.generate(
                         **inputs,
                         max_new_tokens=None,
                         cfg_scale=cfg_scale,

                     refresh_negative=True,
                 )
             else:
+                outputs = self.model.generate(
                     **inputs,
                     max_new_tokens=None,
                     cfg_scale=cfg_scale,

             print(f"Error in generation thread: {e}")
             import traceback
             traceback.print_exc()
+            audio_streamer.end()
+
+    def convert_to_16_bit_wav(self, data):
+        """Convert audio data to 16-bit WAV format."""
+        if torch.is_tensor(data):
+            data = data.detach().cpu().numpy()
+
+        data = np.array(data)
+
+        if np.max(np.abs(data)) > 1.0:
+            data = data / np.max(np.abs(data))
+
+        data = (data * 32767).astype(np.int16)
+        return data

     def stop_audio_generation(self):
+        """Stop the current audio generation."""
+        self.stop_generation = True
+        if self.current_streamer:
+            try:
+                self.current_streamer.end()
+            except:
+                pass



 def create_chat_interface(chat_instance: VibeVoiceChat):
+    """Create a simplified Gradio ChatInterface for VibeVoice."""

+    # Get available voices
     voice_options = list(chat_instance.available_voices.keys())
     if not voice_options:
         voice_options = ["Default"]

     default_voice_1 = voice_options[0] if len(voice_options) > 0 else "Default"
     default_voice_2 = voice_options[1] if len(voice_options) > 1 else voice_options[0]

+    # Define the chat function that returns audio
+    def chat_fn(message: str, history: list, voice_1: str, voice_2: str, num_speakers: int, cfg_scale: float):
+        """Process chat message and generate audio response."""

+        # Extract text from message
+        if isinstance(message, dict):
+            text = message.get("text", "")
+        else:
+            text = message
+
+        if not text.strip():
+            return ""

+        try:
+            # Generate audio stream
+            audio_generator = chat_instance.generate_audio_stream(
+                text, history, voice_1, voice_2, num_speakers, cfg_scale
+            )

+            # Collect all audio data
+            audio_data = None
+            for audio_chunk in audio_generator:
+                if audio_chunk is not None:
+                    audio_data = audio_chunk
+
+            # Return audio file path or error message
+            if audio_data:
+                # Save audio to temporary file
+                with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
+                    sample_rate, audio_array = audio_data
+                    sf.write(tmp_file.name, audio_array, sample_rate)
+                    # Return the file path directly
+                    return tmp_file.name
             else:
+                return "Failed to generate audio"

+        except Exception as e:
+            print(f"Error in chat_fn: {e}")
+            import traceback
+            traceback.print_exc()
+            return f"Error: {str(e)}"
+
+    # Create the interface using Blocks for more control
+    with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="purple"), fill_height=True) as interface:
+        gr.Markdown("# 🎙️ VibeVoice Chat\nGenerate natural dialogue audio with AI voices")

         with gr.Row():
+            with gr.Column(scale=1):
+                gr.Markdown("### Voice & Generation Settings")

                 voice_1 = gr.Dropdown(
                     choices=voice_options,

                     info="Guidance strength (higher = more adherence to text)"
                 )

+            with gr.Column(scale=2):
                 chatbot = gr.Chatbot(
                     label="Conversation",
+                    height=400,
                     type="messages",
+                    elem_id="chatbot"
                 )

                 msg = gr.Textbox(
                     label="Message",
                     placeholder="Type your message or paste a script...",
+                    lines=3
                 )

                 audio_output = gr.Audio(
+                    label="Generated Audio",
+                    type="filepath",
                     autoplay=True,
+                    visible=False
+                )
+
+                with gr.Row():
+                    submit = gr.Button("🎵 Generate Audio", variant="primary")
+                    clear = gr.Button("🗑️ Clear")
+
+        # Example messages
+        gr.Examples(
+            examples=[
+                "Hello! How are you doing today?",
+                "Speaker 0: Welcome to our podcast!\nSpeaker 1: Thanks for having me!",
+                "Tell me an interesting fact about space.",
+                "What's your favorite type of music and why?",
+            ],
+            inputs=msg,
+            label="Example Messages"
+        )
+
+        # Set up event handlers
+        def process_and_display(message, history, voice_1, voice_2, num_speakers, cfg_scale):
+            """Process message and update both chatbot and audio."""
+            # Add user message to history
+            history = history or []
+            history.append({"role": "user", "content": message})
+
+            # Generate audio
+            audio_path = chat_fn(message, history, voice_1, voice_2, num_speakers, cfg_scale)
+
+            # Add assistant response with audio
+            if audio_path and audio_path.endswith('.wav'):
+                history.append({"role": "assistant", "content": f"🎵 Audio generated successfully"})
+                return history, audio_path, gr.update(visible=True), ""
+            else:
+                history.append({"role": "assistant", "content": audio_path or "Failed to generate audio"})
+                return history, None, gr.update(visible=False), ""
+
+        submit.click(
+            fn=process_and_display,
+            inputs=[msg, chatbot, voice_1, voice_2, num_speakers, cfg_scale],
+            outputs=[chatbot, audio_output, audio_output, msg],
+            queue=True
+        )
+
+        msg.submit(
+            fn=process_and_display,
+            inputs=[msg, chatbot, voice_1, voice_2, num_speakers, cfg_scale],
+            outputs=[chatbot, audio_output, audio_output, msg],
+            queue=True
+        )
+
+        clear.click(lambda: ([], None, gr.update(visible=False)), outputs=[chatbot, audio_output, audio_output])
+
+    return interface
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="VibeVoice Chat Interface")
+    parser.add_argument(
+        "--model_path",
+        type=str,
+        default="microsoft/VibeVoice-1.5B",
+        help="Path to the VibeVoice model",
+    )
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="cuda" if torch.cuda.is_available() else "cpu",
+        help="Device for inference",
+    )
+    parser.add_argument(
+        "--inference_steps",
+        type=int,
+        default=5,
+        help="Number of DDPM inference steps (lower = faster, higher = better quality)",
+    )
+
+    return parser.parse_args()
+
+
+def main():
+    """Main function to run the chat interface."""
+    args = parse_args()
+
+    set_seed(42)
+
+    print("🎙️ Initializing VibeVoice Chat Interface...")
+
+    # Initialize chat instance
+    chat_instance = VibeVoiceChat(
+        model_path=args.model_path,
+        device=args.device,
+        inference_steps=args.inference_steps
+    )
+
+    # Create interface
+    interface = create_chat_interface(chat_instance)
+
+    print(f"🚀 Launching chat interface")
+    print(f"📝 Model: {args.model_path}")
+    print(f"💻 Device: {chat_instance.device}")
+    print(f"🔢 Inference steps: {args.inference_steps}")
+    print(f"🎭 Available voices: {len(chat_instance.available_voices)}")
+
+    if chat_instance.device == "cpu":
+        print("\n⚠️ WARNING: Running on CPU - generation will be VERY slow!")
+        print(" For faster generation, ensure you have:")
+        print(" 1. NVIDIA GPU with CUDA support")
+        print(" 2. PyTorch with CUDA installed: pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118")
+
+    # Launch the interface
+    interface.queue(max_size=10).launch(
+        show_error=True,
+        quiet=False,
+    )
+
+
+if __name__ == "__main__":
+    main()
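
For reference, the normalize-and-scale conversion that both versions implement (module-level `convert_to_16_bit_wav` before this commit, a method after it) can be exercised on its own. The following is a minimal, self-contained sketch of the same logic; the 440 Hz test tone and the print at the end are illustrative only and are not part of the commit. The 24 kHz rate matches the app's sample rate.

import numpy as np

def convert_to_16_bit_wav(data):
    """Scale float audio to int16 PCM, normalizing only if it would clip."""
    data = np.asarray(data, dtype=np.float32)
    peak = np.max(np.abs(data))
    if peak > 1.0:  # normalize to -1..1 only when the signal would clip
        data = data / peak
    return (data * 32767).astype(np.int16)

# One second of a 440 Hz tone at the app's 24 kHz sample rate
sr = 24000
t = np.arange(sr) / sr
tone = 0.5 * np.sin(2 * np.pi * 440 * t)
pcm = convert_to_16_bit_wav(tone)
print(pcm.dtype, pcm.min(), pcm.max())  # int16, peaks near ±16383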
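
The generation path in the new version runs model.generate on a worker thread while the caller iterates chunks off an AudioStreamer. The sketch below reproduces that producer/consumer shape with a plain queue.Queue standing in for VibeVoice's AudioStreamer, whose real API differs; it only illustrates the threading pattern, not the library's interface.

import threading
import queue
import numpy as np

def producer(q, n_chunks=5, sr=24000):
    # Stand-in for the generation thread: push float32 chunks, then a sentinel.
    for _ in range(n_chunks):
        q.put(np.zeros(sr // 10, dtype=np.float32))
    q.put(None)

def stream(q):
    # Stand-in for audio_streamer.get_stream(0): yield chunks as they arrive.
    while True:
        chunk = q.get()
        if chunk is None:
            break
        yield chunk

q = queue.Queue()
threading.Thread(target=producer, args=(q,), daemon=True).start()
total = sum(len(c) for c in stream(q))
print(f"received {total} samples")  # 5 chunks x 2400 samples = 12000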