Spaces: Running on L4
Update app.py
app.py CHANGED
@@ -1,3 +1,7 @@
+"""
+VibeVoice Simple Chat Interface - Streamlined Audio Generation Demo
+"""
+
 import argparse
 import os
 import tempfile
@@ -88,21 +92,6 @@ logging.set_verbosity_info()
 logger = logging.get_logger(__name__)
 
 
-def convert_to_16_bit_wav(data):
-    """Convert audio data to 16-bit WAV format."""
-    if torch.is_tensor(data):
-        data = data.detach().cpu().numpy()
-
-    data = np.array(data, dtype=np.float32)
-
-    # Normalize to -1 to 1 if necessary
-    if np.max(np.abs(data)) > 1.0:
-        data = data / np.max(np.abs(data))
-
-    data = (data * 32767).astype(np.int16)
-    return data
-
-
 class VibeVoiceChat:
     def __init__(self, model_path: str, device: str = "cuda", inference_steps: int = 5):
         """Initialize the VibeVoice chat model."""
@@ -186,8 +175,7 @@ class VibeVoiceChat:
 
     def setup_voice_presets(self):
        """Setup voice presets from the voices directory."""
-
-        voices_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "voices")
+        voices_dir = os.path.join(os.path.dirname(__file__), "voices")
 
         # Create voices directory if it doesn't exist
         if not os.path.exists(voices_dir):
@@ -218,14 +206,15 @@ class VibeVoiceChat:
     def read_audio(self, audio_path: str, target_sr: int = 24000) -> np.ndarray:
         """Read and preprocess audio file."""
         try:
-
-            wav
+            wav, sr = sf.read(audio_path)
+            if len(wav.shape) > 1:
+                wav = np.mean(wav, axis=1)
             if sr != target_sr:
-                wav = librosa.resample(
+                wav = librosa.resample(wav, orig_sr=sr, target_sr=target_sr)
             return wav
         except Exception as e:
             print(f"Error reading audio {audio_path}: {e}")
-            return np.zeros(
+            return np.zeros(24000)  # Return 1 second of silence as fallback
 
     def format_script(self, message: str, num_speakers: int = 2) -> str:
         """Format input message into a script with speaker assignments."""
@@ -237,7 +226,7 @@ class VibeVoiceChat:
             if not line:
                 continue
 
-            # Check if already formatted
+            # Check if already formatted
             if line.startswith('Speaker ') and ':' in line:
                 formatted_lines.append(line)
             else:
@@ -256,18 +245,14 @@ class VibeVoiceChat:
         num_speakers: int,
         cfg_scale: float
     ) -> Iterator[tuple]:
-        """
-        Generate audio stream from text input.
-        Yields (sample_rate, audio_chunk_numpy) tuples as audio becomes available.
-        """
+        """Generate audio stream from text input."""
         try:
             self.stop_generation = False
             self.is_generating = True
 
             # Validate inputs
             if not message.strip():
-
-                yield None, "❌ Error: Please provide a message."
+                yield None
                 return
 
             # Format the script
@@ -275,6 +260,7 @@ class VibeVoiceChat:
             print(f"Formatted script:\n{formatted_script}")
             print(f"Using device: {self.device}")
 
+            # Start timing
             start_time = time.time()
 
             # Select voices based on number of speakers
@@ -286,19 +272,20 @@ class VibeVoiceChat:
 
             # Load voice samples
             voice_samples = []
-            target_sr = 24000
             for i in range(num_speakers):
+                # Use the appropriate voice for each speaker
                 if i < len(selected_voices):
                     voice_name = selected_voices[i]
                     if voice_name in self.available_voices and self.available_voices[voice_name]:
-                        audio_data = self.read_audio(self.available_voices[voice_name]
+                        audio_data = self.read_audio(self.available_voices[voice_name])
                     else:
-                        audio_data = np.zeros(
+                        audio_data = np.zeros(24000)  # Default silence
                 else:
+                    # Use first voice or default if not enough voices selected
                     if selected_voices and selected_voices[0] in self.available_voices and self.available_voices[selected_voices[0]]:
-                        audio_data = self.read_audio(self.available_voices[selected_voices[0]]
+                        audio_data = self.read_audio(self.available_voices[selected_voices[0]])
                     else:
-                        audio_data = np.zeros(
+                        audio_data = np.zeros(24000)  # Default silence
 
                 voice_samples.append(audio_data)
 
@@ -317,6 +304,9 @@ class VibeVoiceChat:
             if self.device == "cuda":
                 inputs = {k: v.to(self.device) if torch.is_tensor(v) else v for k, v in inputs.items()}
                 print(f"✓ Inputs moved to GPU")
+                # Check GPU memory
+                if torch.cuda.is_available():
+                    print(f"GPU memory allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
 
             # Create audio streamer
             audio_streamer = AudioStreamer(
@@ -327,45 +317,31 @@ class VibeVoiceChat:
 
             self.current_streamer = audio_streamer
 
-            # Start generation in
+            # Start generation in separate thread
             generation_thread = threading.Thread(
                 target=self._generate_with_streamer,
                 args=(inputs, cfg_scale, audio_streamer)
             )
             generation_thread.start()
 
-            # Wait for generation to start
+            # Wait briefly for generation to start
             time.sleep(1)
 
-            #
-
-
-                generation_thread.join(timeout=5.0)
-                self.is_generating = False
-                yield None, "🛑 Generation stopped by user"
-                return
-
-            # Get the audio stream
-            audio_output_stream = audio_streamer.get_stream(0)
+            # Stream audio chunks
+            sample_rate = 24000
+            audio_stream = audio_streamer.get_stream(0)
 
             all_audio_chunks = []
-            pending_chunks = []
             chunk_count = 0
-            last_yield_time = time.time()
-            min_yield_interval = 15
-            min_chunk_size = target_sr * 30
-
-            has_yielded_audio = False
-            has_received_chunks = False
 
-            for audio_chunk in
+            for audio_chunk in audio_stream:
                 if self.stop_generation:
                     audio_streamer.end()
                     break
-
+
                 chunk_count += 1
-                has_received_chunks = True
 
+                # Convert to numpy
                 if torch.is_tensor(audio_chunk):
                     if audio_chunk.dtype == torch.bfloat16:
                         audio_chunk = audio_chunk.float()
@@ -373,87 +349,43 @@ class VibeVoiceChat:
                 else:
                     audio_np = np.array(audio_chunk, dtype=np.float32)
 
+                # Ensure 1D
                 if len(audio_np.shape) > 1:
                     audio_np = audio_np.squeeze()
 
-
+                # Convert to 16-bit
+                audio_16bit = self.convert_to_16_bit_wav(audio_np)
                 all_audio_chunks.append(audio_16bit)
-                pending_chunks.append(audio_16bit)
-
-                pending_audio_size = sum(len(chunk) for chunk in pending_chunks)
-                current_time = time.time()
-                time_since_last_yield = current_time - last_yield_time
 
-
-                if
-
-
-                elif has_yielded_audio and (pending_audio_size >= min_chunk_size or time_since_last_yield >= min_yield_interval):
-                    should_yield = True
-
-                if should_yield and pending_chunks:
-                    new_audio = np.concatenate(pending_chunks)
-                    total_duration = sum(len(chunk) for chunk in all_audio_chunks) / target_sr
-
-                    log_update = f"🎵 Streaming: {total_duration:.1f}s generated (chunk {chunk_count})"
-                    yield (target_sr, new_audio), log_update
-
-                    pending_chunks = []
-                    last_yield_time = current_time
-
-            # Yield any remaining chunks
-            if pending_chunks:
-                final_new_audio = np.concatenate(pending_chunks)
-                total_duration = sum(len(chunk) for chunk in all_audio_chunks) / target_sr
-                log_update = f"🎵 Streaming final chunk: {total_duration:.1f}s total"
-                yield (target_sr, final_new_audio), log_update
-                has_yielded_audio = True
+                # Yield accumulated audio
+                if all_audio_chunks:
+                    complete_audio = np.concatenate(all_audio_chunks)
+                    yield (sample_rate, complete_audio)
 
             # Wait for generation to complete
             generation_thread.join(timeout=5.0)
 
-
-
-
-
-
+            # Final yield with complete audio
+            if all_audio_chunks:
+                complete_audio = np.concatenate(all_audio_chunks)
+                generation_time = time.time() - start_time
+                audio_duration = len(complete_audio) / sample_rate
+                print(f"✓ Generation complete:")
+                print(f"  Time taken: {generation_time:.2f} seconds")
+                print(f"  Audio duration: {audio_duration:.2f} seconds")
+                print(f"  Real-time factor: {audio_duration/generation_time:.2f}x")
+                yield (sample_rate, complete_audio)
+
             self.current_streamer = None
             self.is_generating = False
 
-            generation_time = time.time() - start_time
-
-            if self.stop_generation:
-                yield None, "🛑 Generation stopped by user"
-                return
-
-            if not has_received_chunks:
-                yield None, f"❌ Error: No audio chunks were received. Generation time: {generation_time:.2f}s"
-                return
-
-            if not has_yielded_audio:
-                yield None, f"❌ Error: Audio was generated but not streamed. Chunk count: {chunk_count}"
-                return
-
-            if all_audio_chunks:
-                complete_audio = np.concatenate(all_audio_chunks)
-                final_duration = len(complete_audio) / target_sr
-
-                final_log = f"⏱️ Generation completed in {generation_time:.2f} seconds\n"
-                final_log += f"🎵 Final audio duration: {final_duration:.2f} seconds\n"
-                final_log += f"📊 Total chunks: {chunk_count}\n"
-                final_log += "✨ Generation successful!"
-
-                yield None, final_log
-            else:
-                yield None, "❌ No audio was generated."
-
         except Exception as e:
             print(f"Error in generation: {e}")
             import traceback
             traceback.print_exc()
             self.is_generating = False
             self.current_streamer = None
-            yield None
+            yield None
 
     def _generate_with_streamer(self, inputs, cfg_scale, audio_streamer):
         """Helper method to run generation with streamer."""
@@ -461,9 +393,10 @@ class VibeVoiceChat:
             def check_stop():
                 return self.stop_generation
 
+            # Use torch.cuda.amp for mixed precision if available
            if self.device == "cuda" and torch.cuda.is_available():
                 with torch.cuda.amp.autocast(dtype=torch.bfloat16):
-                    self.model.generate(
+                    outputs = self.model.generate(
                         **inputs,
                         max_new_tokens=None,
                         cfg_scale=cfg_scale,
@@ -475,7 +408,7 @@ class VibeVoiceChat:
                         refresh_negative=True,
                     )
             else:
-                self.model.generate(
+                outputs = self.model.generate(
                     **inputs,
                     max_new_tokens=None,
                     cfg_scale=cfg_scale,
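
Note on the autocast context kept above: recent PyTorch releases (2.x) deprecate torch.cuda.amp.autocast in favor of torch.amp.autocast. A minimal equivalent sketch, assuming PyTorch 2.x; it is not part of this commit:

    import torch

    # Same bfloat16 mixed-precision context, written with the newer torch.amp API.
    with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
        ...  # self.model.generate(**inputs, ...) would run here unchanged
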
@@ -490,25 +423,35 @@ class VibeVoiceChat:
             print(f"Error in generation thread: {e}")
             import traceback
             traceback.print_exc()
-        finally:
             audio_streamer.end()
 
+    def convert_to_16_bit_wav(self, data):
+        """Convert audio data to 16-bit WAV format."""
+        if torch.is_tensor(data):
+            data = data.detach().cpu().numpy()
+
+        data = np.array(data)
+
+        if np.max(np.abs(data)) > 1.0:
+            data = data / np.max(np.abs(data))
+
+        data = (data * 32767).astype(np.int16)
+        return data
+
     def stop_audio_generation(self):
-        """
-
-
-
-
-
-
-                    self.current_streamer.end()
-            except Exception as e:
-                print(f"Error ending streamer: {e}")
+        """Stop the current audio generation."""
+        self.stop_generation = True
+        if self.current_streamer:
+            try:
+                self.current_streamer.end()
+            except:
+                pass
 
 
 def create_chat_interface(chat_instance: VibeVoiceChat):
-    """Create a Gradio ChatInterface for VibeVoice
+    """Create a simplified Gradio ChatInterface for VibeVoice."""
 
+    # Get available voices
     voice_options = list(chat_instance.available_voices.keys())
     if not voice_options:
         voice_options = ["Default"]
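
The convert_to_16_bit_wav method added above rescales float audio only when its peak exceeds 1.0 and then maps it to int16 full scale. A quick standalone check of that scaling, in plain numpy (illustrative only, not part of app.py):

    import numpy as np

    def to_int16(data):
        # Mirrors the method above for plain numpy input: normalize if needed, then scale.
        data = np.array(data)
        if np.max(np.abs(data)) > 1.0:
            data = data / np.max(np.abs(data))
        return (data * 32767).astype(np.int16)

    print(to_int16(np.array([0.0, 0.5, -1.0])))  # -> [     0  16383 -32767]
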
@@ -516,40 +459,51 @@ def create_chat_interface(chat_instance: VibeVoiceChat):
     default_voice_1 = voice_options[0] if len(voice_options) > 0 else "Default"
     default_voice_2 = voice_options[1] if len(voice_options) > 1 else voice_options[0]
 
-    #
-
-
-
-
-
-
-
-
-
-
-
-    @keyframes pulse {
-        0% { opacity: 1; transform: scale(1); }
-        50% { opacity: 0.5; transform: scale(1.1); }
-        100% { opacity: 1; transform: scale(1); }
-    }
-
-    .streaming-status {
-        background: linear-gradient(135deg, #dcfce7 0%, #bbf7d0 100%);
-        border: 1px solid rgba(34, 197, 94, 0.3);
-        border-radius: 8px;
-        padding: 0.75rem;
-        margin: 0.5rem 0;
-        text-align: center;
-        font-size: 0.9rem;
-        color: #166534;
-    }
-    """
-
-    with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="purple"),
-                   css=custom_css, fill_height=True) as interface:
+    # Define the chat function that returns audio
+    def chat_fn(message: str, history: list, voice_1: str, voice_2: str, num_speakers: int, cfg_scale: float):
+        """Process chat message and generate audio response."""
+
+        # Extract text from message
+        if isinstance(message, dict):
+            text = message.get("text", "")
+        else:
+            text = message
+
+        if not text.strip():
+            return ""
 
-
+        try:
+            # Generate audio stream
+            audio_generator = chat_instance.generate_audio_stream(
+                text, history, voice_1, voice_2, num_speakers, cfg_scale
+            )
+
+            # Collect all audio data
+            audio_data = None
+            for audio_chunk in audio_generator:
+                if audio_chunk is not None:
+                    audio_data = audio_chunk
+
+            # Return audio file path or error message
+            if audio_data:
+                # Save audio to temporary file
+                with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
+                    sample_rate, audio_array = audio_data
+                    sf.write(tmp_file.name, audio_array, sample_rate)
+                    # Return the file path directly
+                    return tmp_file.name
+            else:
+                return "Failed to generate audio"
+
+        except Exception as e:
+            print(f"Error in chat_fn: {e}")
+            import traceback
+            traceback.print_exc()
+            return f"Error: {str(e)}"
+
+    # Create the interface using Blocks for more control
+    with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="purple"), fill_height=True) as interface:
+        gr.Markdown("# 🎙️ VibeVoice Chat\nGenerate natural dialogue audio with AI voices")
 
         with gr.Row():
             with gr.Column(scale=1):
@@ -601,33 +555,18 @@ def create_chat_interface(chat_instance: VibeVoiceChat):
                     lines=3
                 )
 
-                # Streaming status indicator
-                streaming_status = gr.HTML(
-                    value="""
-                    <div class="streaming-status">
-                        <span class="streaming-indicator"></span>
-                        <strong>LIVE STREAMING</strong> - Audio is being generated in real-time
-                    </div>
-                    """,
-                    visible=False,
-                    elem_id="streaming-status"
-                )
-
-                # Audio output with streaming enabled
                 audio_output = gr.Audio(
                     label="Generated Audio",
-                    type="
-                    streaming=True,
+                    type="filepath",
                     autoplay=True,
-                    show_download_button=False,
                     visible=False
                 )
 
                 with gr.Row():
-
-
-                    clear_btn = gr.Button("🗑️ Clear")
+                    submit = gr.Button("🎵 Generate Audio", variant="primary")
+                    clear = gr.Button("🗑️ Clear")
 
+                # Example messages
                 gr.Examples(
                     examples=[
                         "Hello! How are you doing today?",
@@ -639,90 +578,39 @@ def create_chat_interface(chat_instance: VibeVoiceChat):
                     label="Example Messages"
                 )
 
-
-
+        # Set up event handlers
+        def process_and_display(message, history, voice_1, voice_2, num_speakers, cfg_scale):
+            """Process message and update both chatbot and audio."""
+            # Add user message to history
             history = history or []
-            history.append({"role": "user", "content":
+            history.append({"role": "user", "content": message})
 
-            #
-
-
-            # Generate audio stream
-            for audio_chunk, log_message in chat_instance.generate_audio_stream(
-                message_text, history, voice_1, voice_2, num_speakers, cfg_scale
-            ):
-                if audio_chunk is not None:
-                    # Streaming audio chunk received
-                    yield history, gr.update(interactive=False), audio_chunk, gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)
-                else:
-                    # Final status message or error
-                    if log_message and "❌" in log_message:
-                        # Error case
-                        history.append({"role": "assistant", "content": log_message})
-                        yield history, gr.update(interactive=True), None, gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
-                    elif log_message and "🛑" in log_message:
-                        # Stopped case
-                        history.append({"role": "assistant", "content": log_message})
-                        yield history, gr.update(interactive=True), None, gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
-                    elif log_message:
-                        # Final success message
-                        history.append({"role": "assistant", "content": "🎵 Audio generated successfully!"})
-                        yield history, gr.update(interactive=True), None, gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
+            # Generate audio
+            audio_path = chat_fn(message, history, voice_1, voice_2, num_speakers, cfg_scale)
 
-            #
-
-
-
-
-
-
+            # Add assistant response with audio
+            if audio_path and audio_path.endswith('.wav'):
+                history.append({"role": "assistant", "content": f"🎵 Audio generated successfully"})
+                return history, audio_path, gr.update(visible=True), ""
+            else:
+                history.append({"role": "assistant", "content": audio_path or "Failed to generate audio"})
+                return history, None, gr.update(visible=False), ""
 
-
-
-            return None
-
-        # Event handlers
-        submit_btn.click(
-            fn=clear_audio_outputs,
-            inputs=[],
-            outputs=[audio_output],
-            queue=False
-        ).then(
-            fn=process_and_display_stream,
+        submit.click(
+            fn=process_and_display,
             inputs=[msg, chatbot, voice_1, voice_2, num_speakers, cfg_scale],
-            outputs=[chatbot,
+            outputs=[chatbot, audio_output, audio_output, msg],
             queue=True
         )
 
         msg.submit(
-            fn=
-            inputs=[],
-            outputs=[audio_output],
-            queue=False
-        ).then(
-            fn=process_and_display_stream,
+            fn=process_and_display,
             inputs=[msg, chatbot, voice_1, voice_2, num_speakers, cfg_scale],
-            outputs=[chatbot,
+            outputs=[chatbot, audio_output, audio_output, msg],
             queue=True
         )
 
-
-            fn=stop_generation_handler,
-            inputs=[],
-            outputs=[streaming_status, submit_btn, stop_btn],
-            queue=False
-        ).then(
-            fn=lambda: None,
-            inputs=[],
-            outputs=[audio_output],
-            queue=False
-        )
-
-        clear_btn.click(
-            lambda: ([], gr.update(value="", interactive=True), gr.update(visible=False, value=None),
-                     gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)),
-            outputs=[chatbot, msg, audio_output, streaming_status, submit_btn, stop_btn]
-        )
+        clear.click(lambda: ([], None, gr.update(visible=False)), outputs=[chatbot, audio_output, audio_output])
 
     return interface
 
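
The rewiring above replaces the streamed-chunk audio component with a temporary WAV file whose path is handed to gr.Audio(type="filepath"). A self-contained sketch of that pattern, with illustrative names and assuming gradio, soundfile, and numpy are installed (it is not the app's actual code):

    import tempfile
    import numpy as np
    import soundfile as sf
    import gradio as gr

    def speak(_message: str) -> str:
        sr = 24000
        t = np.linspace(0, 1.0, sr, endpoint=False)
        wav = 0.2 * np.sin(2 * np.pi * 440 * t)  # one second of placeholder audio
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
            sf.write(f.name, wav, sr)             # write the WAV to a temp file
            return f.name                         # Gradio loads it back by path

    with gr.Blocks() as demo:
        box = gr.Textbox(label="Message")
        audio = gr.Audio(label="Generated Audio", type="filepath", autoplay=True)
        box.submit(speak, inputs=box, outputs=audio)

    # demo.launch()  # uncomment to run locally
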
@@ -777,6 +665,9 @@ def main():
 
     if chat_instance.device == "cpu":
         print("\n⚠️ WARNING: Running on CPU - generation will be VERY slow!")
+        print("   For faster generation, ensure you have:")
+        print("   1. NVIDIA GPU with CUDA support")
+        print("   2. PyTorch with CUDA installed: pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118")
 
     # Launch the interface
     interface.queue(max_size=10).launch(