|
""" |
|
VibeVoice Gradio Demo - High-Quality Dialogue Generation Interface with Streaming Support |
|
""" |
|
|
|
import argparse |
|
import json |
|
import os |
|
import sys |
|
import tempfile |
|
import time |
|
from pathlib import Path |
|
from typing import Iterator
|
from datetime import datetime |
|
import threading |
|
import numpy as np |
|
import gradio as gr |
|
import librosa |
|
import soundfile as sf |
|
import torch |
|
|
import traceback |
|
|
|
from vibevoice.modular.configuration_vibevoice import VibeVoiceConfig |
|
from vibevoice.modular.modeling_vibevoice_inference import VibeVoiceForConditionalGenerationInference |
|
from vibevoice.processor.vibevoice_processor import VibeVoiceProcessor |
|
from vibevoice.modular.streamer import AudioStreamer |
|
from transformers.utils import logging |
|
from transformers import set_seed |
|
|
|
logging.set_verbosity_info() |
|
logger = logging.get_logger(__name__) |
|
|
|
|
|
class VibeVoiceDemo: |
|
def __init__(self, model_path: str, device: str = "cuda", inference_steps: int = 5): |
|
"""Initialize the VibeVoice demo with model loading.""" |
|
self.model_path = model_path |
|
self.device = device |
|
self.inference_steps = inference_steps |
|
self.is_generating = False |
|
self.stop_generation = False |
|
self.current_streamer = None |
|
self.load_model() |
|
self.setup_voice_presets() |
|
self.load_example_scripts() |
|
|
|
def load_model(self): |
|
"""Load the VibeVoice model and processor.""" |
|
print(f"Loading processor & model from {self.model_path}") |
|
|
|
|
|
self.processor = VibeVoiceProcessor.from_pretrained( |
|
self.model_path, |
|
) |
|
|
|
|
|
self.model = VibeVoiceForConditionalGenerationInference.from_pretrained( |
|
self.model_path, |
|
torch_dtype=torch.bfloat16, |
|
            device_map=self.device,  # honor the --device argument instead of hardcoding 'cuda'
|
attn_implementation="flash_attention_2", |
|
) |
|
self.model.eval() |
|
|
|
|
|
self.model.model.noise_scheduler = self.model.model.noise_scheduler.from_config( |
|
self.model.model.noise_scheduler.config, |
|
algorithm_type='sde-dpmsolver++', |
|
beta_schedule='squaredcos_cap_v2' |
|
) |
|
self.model.set_ddpm_inference_steps(num_steps=self.inference_steps) |
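
        # Note: the SDE DPM-Solver++ scheduler configured above is what makes such a
        # small step count (e.g. 5-10) usable; vanilla DDPM sampling would need far more steps.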
|
|
|
if hasattr(self.model.model, 'language_model'): |
|
print(f"Language model attention: {self.model.model.language_model.config._attn_implementation}") |
|
|
|
def setup_voice_presets(self): |
|
"""Setup voice presets by scanning the voices directory.""" |
|
voices_dir = os.path.join(os.path.dirname(__file__), "voices") |
|
|
|
|
|
if not os.path.exists(voices_dir): |
|
print(f"Warning: Voices directory not found at {voices_dir}") |
|
self.voice_presets = {} |
|
self.available_voices = {} |
|
return |
|
|
|
|
|
self.voice_presets = {} |
|
|
|
|
|
        audio_files = [f for f in os.listdir(voices_dir)
                       if f.lower().endswith(('.wav', '.mp3', '.flac', '.ogg', '.m4a', '.aac'))
                       and os.path.isfile(os.path.join(voices_dir, f))]

        for audio_file in audio_files:
            name = os.path.splitext(audio_file)[0]
            full_path = os.path.join(voices_dir, audio_file)
            self.voice_presets[name] = full_path
|
|
|
|
|
self.voice_presets = dict(sorted(self.voice_presets.items())) |
|
|
|
|
|
self.available_voices = { |
|
name: path for name, path in self.voice_presets.items() |
|
if os.path.exists(path) |
|
} |
|
|
|
if not self.available_voices: |
|
raise gr.Error("No voice presets found. Please add .wav files to the demo/voices directory.") |
|
|
|
print(f"Found {len(self.available_voices)} voice files in {voices_dir}") |
|
print(f"Available voices: {', '.join(self.available_voices.keys())}") |
|
|
|
def read_audio(self, audio_path: str, target_sr: int = 24000) -> np.ndarray: |
|
"""Read and preprocess audio file.""" |
|
try: |
|
wav, sr = sf.read(audio_path) |
|
if len(wav.shape) > 1: |
|
wav = np.mean(wav, axis=1) |
|
if sr != target_sr: |
|
wav = librosa.resample(wav, orig_sr=sr, target_sr=target_sr) |
|
return wav |
|
except Exception as e: |
|
print(f"Error reading audio {audio_path}: {e}") |
|
return np.array([]) |
|
|
|
def generate_podcast_streaming(self, |
|
num_speakers: int, |
|
script: str, |
|
speaker_1: str = None, |
|
speaker_2: str = None, |
|
speaker_3: str = None, |
|
speaker_4: str = None, |
|
cfg_scale: float = 1.3) -> Iterator[tuple]: |
|
try: |
|
|
|
self.stop_generation = False |
|
self.is_generating = True |
|
|
|
|
|
if not script.strip(): |
|
self.is_generating = False |
|
raise gr.Error("Error: Please provide a script.") |
|
|
|
if num_speakers < 1 or num_speakers > 4: |
|
self.is_generating = False |
|
raise gr.Error("Error: Number of speakers must be between 1 and 4.") |
|
|
|
|
|
selected_speakers = [speaker_1, speaker_2, speaker_3, speaker_4][:num_speakers] |
|
|
|
|
|
for i, speaker in enumerate(selected_speakers): |
|
if not speaker or speaker not in self.available_voices: |
|
self.is_generating = False |
|
raise gr.Error(f"Error: Please select a valid speaker for Speaker {i+1}.") |
|
|
|
|
|
log = f"ποΈ Generating podcast with {num_speakers} speakers\n" |
|
log += f"π Parameters: CFG Scale={cfg_scale}, Inference Steps={self.inference_steps}\n" |
|
log += f"π Speakers: {', '.join(selected_speakers)}\n" |
|
|
|
|
|
if self.stop_generation: |
|
self.is_generating = False |
|
yield None, "π Generation stopped by user", gr.update(visible=False) |
|
return |
|
|
|
|
|
voice_samples = [] |
|
for speaker_name in selected_speakers: |
|
audio_path = self.available_voices[speaker_name] |
|
audio_data = self.read_audio(audio_path) |
|
if len(audio_data) == 0: |
|
self.is_generating = False |
|
raise gr.Error(f"Error: Failed to load audio for {speaker_name}") |
|
voice_samples.append(audio_data) |
|
|
|
|
|
|
|
|
|
if self.stop_generation: |
|
self.is_generating = False |
|
yield None, "π Generation stopped by user", gr.update(visible=False) |
|
return |
|
|
|
|
|
lines = script.strip().split('\n') |
|
formatted_script_lines = [] |
|
|
|
for line in lines: |
|
line = line.strip() |
|
if not line: |
|
continue |
|
|
|
|
|
if line.startswith('Speaker ') and ':' in line: |
|
formatted_script_lines.append(line) |
|
else: |
|
|
|
speaker_id = len(formatted_script_lines) % num_speakers |
|
formatted_script_lines.append(f"Speaker {speaker_id}: {line}") |
|
|
|
formatted_script = '\n'.join(formatted_script_lines) |
|
log += f"π Formatted script with {len(formatted_script_lines)} turns\n\n" |
|
log += "π Processing with VibeVoice (streaming mode)...\n" |
|
|
|
|
|
if self.stop_generation: |
|
self.is_generating = False |
|
yield None, "π Generation stopped by user", gr.update(visible=False) |
|
return |
|
|
|
start_time = time.time() |
|
|
|
inputs = self.processor( |
|
text=[formatted_script], |
|
voice_samples=[voice_samples], |
|
padding=True, |
|
return_tensors="pt", |
|
return_attention_mask=True, |
|
) |
|
|
|
|
|
audio_streamer = AudioStreamer( |
|
batch_size=1, |
|
stop_signal=None, |
|
timeout=None |
|
) |
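
            # The streamer bridges the background generation thread and this generator:
            # chunks pushed by model.generate() are consumed below via get_stream(0).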
|
|
|
|
|
self.current_streamer = audio_streamer |
|
|
|
|
|
generation_thread = threading.Thread( |
|
target=self._generate_with_streamer, |
|
args=(inputs, cfg_scale, audio_streamer) |
|
) |
|
generation_thread.start() |
|
|
|
|
|
            time.sleep(1)  # give the generation thread a moment to start producing audio
|
|
|
|
|
if self.stop_generation: |
|
audio_streamer.end() |
|
generation_thread.join(timeout=5.0) |
|
self.is_generating = False |
|
yield None, "π Generation stopped by user", gr.update(visible=False) |
|
return |
|
|
|
|
|
            sample_rate = 24000
            all_audio_chunks = []  # everything received so far, for the final complete file
            pending_chunks = []    # received but not yet pushed to the UI
            chunk_count = 0
            last_yield_time = time.time()
            min_yield_interval = 15  # seconds between streaming updates
            min_chunk_size = sample_rate * 30  # buffer ~30 seconds of audio per update
|
|
|
|
|
audio_stream = audio_streamer.get_stream(0) |
|
|
|
has_yielded_audio = False |
|
has_received_chunks = False |
|
|
|
for audio_chunk in audio_stream: |
|
|
|
if self.stop_generation: |
|
audio_streamer.end() |
|
break |
|
|
|
chunk_count += 1 |
|
has_received_chunks = True |
|
|
|
|
|
if torch.is_tensor(audio_chunk): |
|
|
|
if audio_chunk.dtype == torch.bfloat16: |
|
audio_chunk = audio_chunk.float() |
|
audio_np = audio_chunk.cpu().numpy().astype(np.float32) |
|
else: |
|
audio_np = np.array(audio_chunk, dtype=np.float32) |
|
|
|
|
|
if len(audio_np.shape) > 1: |
|
audio_np = audio_np.squeeze() |
|
|
|
|
|
audio_16bit = convert_to_16_bit_wav(audio_np) |
|
|
|
|
|
all_audio_chunks.append(audio_16bit) |
|
|
|
|
|
pending_chunks.append(audio_16bit) |
|
|
|
|
|
pending_audio_size = sum(len(chunk) for chunk in pending_chunks) |
|
current_time = time.time() |
|
time_since_last_yield = current_time - last_yield_time |
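
                # Yield policy: buffer ~30 s of audio before the first UI update, then
                # update whenever the buffer refills or 15 s elapse, whichever comes first.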
|
|
|
|
|
should_yield = False |
|
if not has_yielded_audio and pending_audio_size >= min_chunk_size: |
|
|
|
should_yield = True |
|
has_yielded_audio = True |
|
elif has_yielded_audio and (pending_audio_size >= min_chunk_size or time_since_last_yield >= min_yield_interval): |
|
|
|
should_yield = True |
|
|
|
if should_yield and pending_chunks: |
|
|
|
new_audio = np.concatenate(pending_chunks) |
|
new_duration = len(new_audio) / sample_rate |
|
total_duration = sum(len(chunk) for chunk in all_audio_chunks) / sample_rate |
|
|
|
log_update = log + f"π΅ Streaming: {total_duration:.1f}s generated (chunk {chunk_count})\n" |
|
|
|
|
|
yield (sample_rate, new_audio), None, log_update, gr.update(visible=True) |
|
|
|
|
|
pending_chunks = [] |
|
last_yield_time = current_time |
|
|
|
|
|
if pending_chunks: |
|
final_new_audio = np.concatenate(pending_chunks) |
|
total_duration = sum(len(chunk) for chunk in all_audio_chunks) / sample_rate |
|
log_update = log + f"π΅ Streaming final chunk: {total_duration:.1f}s total\n" |
|
yield (sample_rate, final_new_audio), None, log_update, gr.update(visible=True) |
|
has_yielded_audio = True |
|
|
|
|
|
generation_thread.join(timeout=5.0) |
|
|
|
|
|
if generation_thread.is_alive(): |
|
print("Warning: Generation thread did not complete within timeout") |
|
audio_streamer.end() |
|
generation_thread.join(timeout=5.0) |
|
|
|
|
|
self.current_streamer = None |
|
self.is_generating = False |
|
|
|
generation_time = time.time() - start_time |
|
|
|
|
|
if self.stop_generation: |
|
yield None, None, "π Generation stopped by user", gr.update(visible=False) |
|
return |
|
|
|
|
|
|
|
|
|
|
|
if has_received_chunks and not has_yielded_audio and all_audio_chunks: |
|
|
|
complete_audio = np.concatenate(all_audio_chunks) |
|
final_duration = len(complete_audio) / sample_rate |
|
|
|
final_log = log + f"β±οΈ Generation completed in {generation_time:.2f} seconds\n" |
|
final_log += f"π΅ Final audio duration: {final_duration:.2f} seconds\n" |
|
final_log += f"π Total chunks: {chunk_count}\n" |
|
final_log += "β¨ Generation successful! Complete audio is ready.\n" |
|
final_log += "π‘ Not satisfied? You can regenerate or adjust the CFG scale for different results." |
|
|
|
|
|
yield None, (sample_rate, complete_audio), final_log, gr.update(visible=False) |
|
return |
|
|
|
if not has_received_chunks: |
|
error_log = log + f"\nβ Error: No audio chunks were received from the model. Generation time: {generation_time:.2f}s" |
|
yield None, None, error_log, gr.update(visible=False) |
|
return |
|
|
|
if not has_yielded_audio: |
|
error_log = log + f"\nβ Error: Audio was generated but not streamed. Chunk count: {chunk_count}" |
|
yield None, None, error_log, gr.update(visible=False) |
|
return |
|
|
|
|
|
if all_audio_chunks: |
|
complete_audio = np.concatenate(all_audio_chunks) |
|
final_duration = len(complete_audio) / sample_rate |
|
|
|
final_log = log + f"β±οΈ Generation completed in {generation_time:.2f} seconds\n" |
|
final_log += f"π΅ Final audio duration: {final_duration:.2f} seconds\n" |
|
final_log += f"π Total chunks: {chunk_count}\n" |
|
final_log += "β¨ Generation successful! Complete audio is ready in the 'Complete Audio' tab.\n" |
|
final_log += "π‘ Not satisfied? You can regenerate or adjust the CFG scale for different results." |
|
|
|
|
|
yield None, (sample_rate, complete_audio), final_log, gr.update(visible=False) |
|
else: |
|
final_log = log + "β No audio was generated." |
|
yield None, None, final_log, gr.update(visible=False) |
|
|
|
except gr.Error as e: |
|
|
|
self.is_generating = False |
|
self.current_streamer = None |
|
error_msg = f"β Input Error: {str(e)}" |
|
print(error_msg) |
|
yield None, None, error_msg, gr.update(visible=False) |
|
|
|
except Exception as e: |
|
self.is_generating = False |
|
self.current_streamer = None |
|
error_msg = f"β An unexpected error occurred: {str(e)}" |
|
print(error_msg) |
|
import traceback |
|
traceback.print_exc() |
|
yield None, None, error_msg, gr.update(visible=False) |
|
|
|
def _generate_with_streamer(self, inputs, cfg_scale, audio_streamer): |
|
"""Helper method to run generation with streamer in a separate thread.""" |
|
try: |
|
|
|
if self.stop_generation: |
|
audio_streamer.end() |
|
return |
|
|
|
|
|
def check_stop_generation(): |
|
return self.stop_generation |
|
|
|
outputs = self.model.generate( |
|
**inputs, |
|
max_new_tokens=None, |
|
cfg_scale=cfg_scale, |
|
tokenizer=self.processor.tokenizer, |
|
generation_config={ |
|
'do_sample': False, |
|
}, |
|
audio_streamer=audio_streamer, |
|
stop_check_fn=check_stop_generation, |
|
verbose=False, |
|
refresh_negative=True, |
|
) |
|
|
|
except Exception as e: |
|
print(f"Error in generation thread: {e}") |
|
traceback.print_exc() |
|
|
|
audio_streamer.end() |
|
|
|
def stop_audio_generation(self): |
|
"""Stop the current audio generation process.""" |
|
self.stop_generation = True |
|
if self.current_streamer is not None: |
|
try: |
|
self.current_streamer.end() |
|
except Exception as e: |
|
print(f"Error stopping streamer: {e}") |
|
print("π Audio generation stop requested") |
|
|
|
def load_example_scripts(self): |
|
"""Load example scripts from the text_examples directory.""" |
|
examples_dir = os.path.join(os.path.dirname(__file__), "text_examples") |
|
self.example_scripts = [] |
|
|
|
|
|
if not os.path.exists(examples_dir): |
|
print(f"Warning: text_examples directory not found at {examples_dir}") |
|
return |
|
|
|
|
|
txt_files = sorted([f for f in os.listdir(examples_dir) |
|
if f.lower().endswith('.txt') and os.path.isfile(os.path.join(examples_dir, f))]) |
|
|
|
        import re

        for txt_file in txt_files:
            file_path = os.path.join(examples_dir, txt_file)

            # Skip examples whose filename advertises a duration over 15 minutes
            time_pattern = re.search(r'(\d+)min', txt_file.lower())
|
if time_pattern: |
|
minutes = int(time_pattern.group(1)) |
|
if minutes > 15: |
|
print(f"Skipping {txt_file}: duration {minutes} minutes exceeds 15-minute limit") |
|
continue |
|
|
|
try: |
|
with open(file_path, 'r', encoding='utf-8') as f: |
|
script_content = f.read().strip() |
|
|
|
|
|
script_content = '\n'.join(line for line in script_content.split('\n') if line.strip()) |
|
|
|
if not script_content: |
|
continue |
|
|
|
|
|
num_speakers = self._get_num_speakers_from_script(script_content) |
|
|
|
|
|
self.example_scripts.append([num_speakers, script_content]) |
|
print(f"Loaded example: {txt_file} with {num_speakers} speakers") |
|
|
|
except Exception as e: |
|
print(f"Error loading example script {txt_file}: {e}") |
|
|
|
if self.example_scripts: |
|
print(f"Successfully loaded {len(self.example_scripts)} example scripts") |
|
else: |
|
print("No example scripts were loaded") |
|
|
|
def _get_num_speakers_from_script(self, script: str) -> int: |
|
"""Determine the number of unique speakers in a script.""" |
|
import re |
|
speakers = set() |
|
|
|
lines = script.strip().split('\n') |
|
for line in lines: |
|
|
|
match = re.match(r'^Speaker\s+(\d+)\s*:', line.strip(), re.IGNORECASE) |
|
if match: |
|
speaker_id = int(match.group(1)) |
|
speakers.add(speaker_id) |
|
|
|
|
|
if not speakers: |
|
return 1 |
|
|
|
|
|
|
|
max_speaker = max(speakers) |
|
min_speaker = min(speakers) |
|
|
|
if min_speaker == 0: |
|
return max_speaker + 1 |
|
else: |
|
|
|
return len(speakers) |
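
    # Example: lines "Speaker 0:" and "Speaker 1:" -> 2 (0-indexed, so max id + 1),
    # while "Speaker 1:" and "Speaker 2:" -> 2 (1-indexed, so the count of distinct ids).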
|
|
|
|
|
def create_demo_interface(demo_instance: VibeVoiceDemo): |
|
"""Create the Gradio interface with streaming support.""" |
|
|
|
|
|
custom_css = """ |
|
/* Modern light theme with gradients */ |
|
.gradio-container { |
|
background: linear-gradient(135deg, #f8fafc 0%, #e2e8f0 100%); |
|
font-family: 'SF Pro Display', -apple-system, BlinkMacSystemFont, sans-serif; |
|
} |
|
|
|
/* Header styling */ |
|
.main-header { |
|
background: linear-gradient(90deg, #667eea 0%, #764ba2 100%); |
|
padding: 2rem; |
|
border-radius: 20px; |
|
margin-bottom: 2rem; |
|
text-align: center; |
|
box-shadow: 0 10px 40px rgba(102, 126, 234, 0.3); |
|
} |
|
|
|
.main-header h1 { |
|
color: white; |
|
font-size: 2.5rem; |
|
font-weight: 700; |
|
margin: 0; |
|
text-shadow: 0 2px 4px rgba(0,0,0,0.3); |
|
} |
|
|
|
.main-header p { |
|
color: rgba(255,255,255,0.9); |
|
font-size: 1.1rem; |
|
margin: 0.5rem 0 0 0; |
|
} |
|
|
|
/* Card styling */ |
|
.settings-card, .generation-card { |
|
background: rgba(255, 255, 255, 0.8); |
|
backdrop-filter: blur(10px); |
|
border: 1px solid rgba(226, 232, 240, 0.8); |
|
border-radius: 16px; |
|
padding: 1.5rem; |
|
margin-bottom: 1rem; |
|
box-shadow: 0 8px 32px rgba(0, 0, 0, 0.1); |
|
} |
|
|
|
/* Speaker selection styling */ |
|
.speaker-grid { |
|
display: grid; |
|
gap: 1rem; |
|
margin-bottom: 1rem; |
|
} |
|
|
|
.speaker-item { |
|
background: linear-gradient(135deg, #e2e8f0 0%, #cbd5e1 100%); |
|
border: 1px solid rgba(148, 163, 184, 0.4); |
|
border-radius: 12px; |
|
padding: 1rem; |
|
color: #374151; |
|
font-weight: 500; |
|
} |
|
|
|
/* Streaming indicator */ |
|
.streaming-indicator { |
|
display: inline-block; |
|
width: 10px; |
|
height: 10px; |
|
background: #22c55e; |
|
border-radius: 50%; |
|
margin-right: 8px; |
|
animation: pulse 1.5s infinite; |
|
} |
|
|
|
@keyframes pulse { |
|
0% { opacity: 1; transform: scale(1); } |
|
50% { opacity: 0.5; transform: scale(1.1); } |
|
100% { opacity: 1; transform: scale(1); } |
|
} |
|
|
|
/* Queue status styling */ |
|
.queue-status { |
|
background: linear-gradient(135deg, #f0f9ff 0%, #e0f2fe 100%); |
|
border: 1px solid rgba(14, 165, 233, 0.3); |
|
border-radius: 8px; |
|
padding: 0.75rem; |
|
margin: 0.5rem 0; |
|
text-align: center; |
|
font-size: 0.9rem; |
|
color: #0369a1; |
|
} |
|
|
|
.generate-btn { |
|
background: linear-gradient(135deg, #059669 0%, #0d9488 100%); |
|
border: none; |
|
border-radius: 12px; |
|
padding: 1rem 2rem; |
|
color: white; |
|
font-weight: 600; |
|
font-size: 1.1rem; |
|
box-shadow: 0 4px 20px rgba(5, 150, 105, 0.4); |
|
transition: all 0.3s ease; |
|
} |
|
|
|
.generate-btn:hover { |
|
transform: translateY(-2px); |
|
box-shadow: 0 6px 25px rgba(5, 150, 105, 0.6); |
|
} |
|
|
|
.stop-btn { |
|
background: linear-gradient(135deg, #ef4444 0%, #dc2626 100%); |
|
border: none; |
|
border-radius: 12px; |
|
padding: 1rem 2rem; |
|
color: white; |
|
font-weight: 600; |
|
font-size: 1.1rem; |
|
box-shadow: 0 4px 20px rgba(239, 68, 68, 0.4); |
|
transition: all 0.3s ease; |
|
} |
|
|
|
.stop-btn:hover { |
|
transform: translateY(-2px); |
|
box-shadow: 0 6px 25px rgba(239, 68, 68, 0.6); |
|
} |
|
|
|
/* Audio player styling */ |
|
.audio-output { |
|
background: linear-gradient(135deg, #f1f5f9 0%, #e2e8f0 100%); |
|
border-radius: 16px; |
|
padding: 1.5rem; |
|
border: 1px solid rgba(148, 163, 184, 0.3); |
|
} |
|
|
|
.complete-audio-section { |
|
margin-top: 1rem; |
|
padding: 1rem; |
|
background: linear-gradient(135deg, #f0fdf4 0%, #dcfce7 100%); |
|
border: 1px solid rgba(34, 197, 94, 0.3); |
|
border-radius: 12px; |
|
} |
|
|
|
/* Text areas */ |
|
.script-input, .log-output { |
|
background: rgba(255, 255, 255, 0.9) !important; |
|
border: 1px solid rgba(148, 163, 184, 0.4) !important; |
|
border-radius: 12px !important; |
|
color: #1e293b !important; |
|
font-family: 'JetBrains Mono', monospace !important; |
|
} |
|
|
|
.script-input::placeholder { |
|
color: #64748b !important; |
|
} |
|
|
|
/* Sliders */ |
|
.slider-container { |
|
background: rgba(248, 250, 252, 0.8); |
|
border: 1px solid rgba(226, 232, 240, 0.6); |
|
border-radius: 8px; |
|
padding: 1rem; |
|
margin: 0.5rem 0; |
|
} |
|
|
|
/* Labels and text */ |
|
.gradio-container label { |
|
color: #374151 !important; |
|
font-weight: 600 !important; |
|
} |
|
|
|
.gradio-container .markdown { |
|
color: #1f2937 !important; |
|
} |
|
|
|
/* Responsive design */ |
|
@media (max-width: 768px) { |
|
.main-header h1 { font-size: 2rem; } |
|
.settings-card, .generation-card { padding: 1rem; } |
|
} |
|
|
|
/* Random example button styling - more subtle professional color */ |
|
.random-btn { |
|
background: linear-gradient(135deg, #64748b 0%, #475569 100%); |
|
border: none; |
|
border-radius: 12px; |
|
padding: 1rem 1.5rem; |
|
color: white; |
|
font-weight: 600; |
|
font-size: 1rem; |
|
box-shadow: 0 4px 20px rgba(100, 116, 139, 0.3); |
|
transition: all 0.3s ease; |
|
display: inline-flex; |
|
align-items: center; |
|
gap: 0.5rem; |
|
} |
|
|
|
.random-btn:hover { |
|
transform: translateY(-2px); |
|
box-shadow: 0 6px 25px rgba(100, 116, 139, 0.4); |
|
background: linear-gradient(135deg, #475569 0%, #334155 100%); |
|
} |
|
""" |
|
|
|
with gr.Blocks( |
|
title="VibeVoice - AI Podcast Generator", |
|
css=custom_css, |
|
theme=gr.themes.Soft( |
|
primary_hue="blue", |
|
secondary_hue="purple", |
|
neutral_hue="slate", |
|
) |
|
) as interface: |
|
|
|
|
|
gr.HTML(""" |
|
<div class="main-header"> |
|
        <h1>🎙️ Vibe Podcasting</h1>
        <p>Generate long-form, multi-speaker AI podcasts with VibeVoice</p>
|
</div> |
|
""") |
|
|
|
with gr.Row(): |
|
|
|
with gr.Column(scale=1, elem_classes="settings-card"): |
|
gr.Markdown("### ποΈ **Podcast Settings**") |
|
|
|
|
|
num_speakers = gr.Slider( |
|
minimum=1, |
|
maximum=4, |
|
value=2, |
|
step=1, |
|
label="Number of Speakers", |
|
elem_classes="slider-container" |
|
) |
|
|
|
|
|
gr.Markdown("### π **Speaker Selection**") |
|
|
|
available_speaker_names = list(demo_instance.available_voices.keys()) |
|
|
|
default_speakers = ['en-Alice_woman', 'en-Carter_man', 'en-Frank_man', 'en-Maya_woman'] |
|
|
|
speaker_selections = [] |
|
for i in range(4): |
|
default_value = default_speakers[i] if i < len(default_speakers) else None |
|
speaker = gr.Dropdown( |
|
choices=available_speaker_names, |
|
value=default_value, |
|
label=f"Speaker {i+1}", |
|
visible=(i < 2), |
|
elem_classes="speaker-item" |
|
) |
|
speaker_selections.append(speaker) |
|
|
|
|
|
gr.Markdown("### βοΈ **Advanced Settings**") |
|
|
|
|
|
with gr.Accordion("Generation Parameters", open=False): |
|
cfg_scale = gr.Slider( |
|
minimum=1.0, |
|
maximum=2.0, |
|
value=1.3, |
|
step=0.05, |
|
label="CFG Scale (Guidance Strength)", |
|
|
|
elem_classes="slider-container" |
|
) |
|
|
|
|
|
with gr.Column(scale=2, elem_classes="generation-card"): |
|
gr.Markdown("### π **Script Input**") |
|
|
|
script_input = gr.Textbox( |
|
label="Conversation Script", |
|
placeholder="""Enter your podcast script here. You can format it as: |
|
|
|
Speaker 0: Welcome to our podcast today! |
|
Speaker 1: Thanks for having me. I'm excited to discuss... |
|
|
|
Or paste text directly and it will auto-assign speakers.""", |
|
lines=12, |
|
max_lines=20, |
|
elem_classes="script-input" |
|
) |
|
|
|
|
|
with gr.Row(): |
|
|
|
random_example_btn = gr.Button( |
|
"π² Random Example", |
|
size="lg", |
|
variant="secondary", |
|
elem_classes="random-btn", |
|
scale=1 |
|
) |
|
|
|
|
|
generate_btn = gr.Button( |
|
"π Generate Podcast", |
|
size="lg", |
|
variant="primary", |
|
elem_classes="generate-btn", |
|
scale=2 |
|
) |
|
|
|
|
|
stop_btn = gr.Button( |
|
"π Stop Generation", |
|
size="lg", |
|
variant="stop", |
|
elem_classes="stop-btn", |
|
visible=False |
|
) |
|
|
|
|
|
streaming_status = gr.HTML( |
|
value=""" |
|
<div style="background: linear-gradient(135deg, #dcfce7 0%, #bbf7d0 100%); |
|
border: 1px solid rgba(34, 197, 94, 0.3); |
|
border-radius: 8px; |
|
padding: 0.75rem; |
|
margin: 0.5rem 0; |
|
text-align: center; |
|
font-size: 0.9rem; |
|
color: #166534;"> |
|
<span class="streaming-indicator"></span> |
|
<strong>LIVE STREAMING</strong> - Audio is being generated in real-time |
|
</div> |
|
""", |
|
visible=False, |
|
elem_id="streaming-status" |
|
) |
|
|
|
|
|
gr.Markdown("### π΅ **Generated Podcast**") |
|
|
|
|
|
audio_output = gr.Audio( |
|
label="Streaming Audio (Real-time)", |
|
type="numpy", |
|
elem_classes="audio-output", |
|
streaming=True, |
|
autoplay=True, |
|
show_download_button=False, |
|
visible=True |
|
) |
|
|
|
|
|
complete_audio_output = gr.Audio( |
|
label="Complete Podcast (Download after generation)", |
|
type="numpy", |
|
elem_classes="audio-output complete-audio-section", |
|
streaming=False, |
|
autoplay=False, |
|
show_download_button=True, |
|
visible=False |
|
) |
|
|
|
gr.Markdown(""" |
|
*π‘ **Streaming**: Audio plays as it's being generated (may have slight pauses) |
|
*π‘ **Complete Audio**: Will appear below after generation finishes* |
|
""") |
|
|
|
|
|
log_output = gr.Textbox( |
|
label="Generation Log", |
|
lines=8, |
|
max_lines=15, |
|
interactive=False, |
|
elem_classes="log-output" |
|
) |
|
|
|
def update_speaker_visibility(num_speakers): |
|
updates = [] |
|
for i in range(4): |
|
updates.append(gr.update(visible=(i < num_speakers))) |
|
return updates |
|
|
|
num_speakers.change( |
|
fn=update_speaker_visibility, |
|
inputs=[num_speakers], |
|
outputs=speaker_selections |
|
) |
|
|
|
|
|
def generate_podcast_wrapper(num_speakers, script, *speakers_and_params): |
|
"""Wrapper function to handle the streaming generation call.""" |
|
try: |
|
|
|
speakers = speakers_and_params[:4] |
|
cfg_scale = speakers_and_params[4] |
|
|
|
|
|
                yield None, gr.update(value=None, visible=False), "🎙️ Starting generation...", gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)
|
|
|
|
|
final_log = "Starting generation..." |
|
|
|
for streaming_audio, complete_audio, log, streaming_visible in demo_instance.generate_podcast_streaming( |
|
num_speakers=int(num_speakers), |
|
script=script, |
|
speaker_1=speakers[0], |
|
speaker_2=speakers[1], |
|
speaker_3=speakers[2], |
|
speaker_4=speakers[3], |
|
cfg_scale=cfg_scale |
|
): |
|
final_log = log |
|
|
|
|
|
if complete_audio is not None: |
|
|
|
yield None, gr.update(value=complete_audio, visible=True), log, gr.update(visible=False), gr.update(visible=True), gr.update(visible=False) |
|
else: |
|
|
|
if streaming_audio is not None: |
|
yield streaming_audio, gr.update(visible=False), log, streaming_visible, gr.update(visible=False), gr.update(visible=True) |
|
else: |
|
|
|
yield None, gr.update(visible=False), log, streaming_visible, gr.update(visible=False), gr.update(visible=True) |
|
|
|
            except Exception as e:
                error_msg = f"❌ A critical error occurred in the wrapper: {str(e)}"
                print(error_msg)
                traceback.print_exc()  # module-level import; no local re-import needed

                yield None, gr.update(value=None, visible=False), error_msg, gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)
|
|
|
def stop_generation_handler(): |
|
"""Handle stopping generation.""" |
|
demo_instance.stop_audio_generation() |
|
|
|
return "π Generation stopped.", gr.update(visible=False), gr.update(visible=True), gr.update(visible=False) |
|
|
|
|
|
def clear_audio_outputs(): |
|
"""Clear both audio outputs before starting new generation.""" |
|
return None, gr.update(value=None, visible=False) |
|
|
|
|
|
generate_btn.click( |
|
fn=clear_audio_outputs, |
|
inputs=[], |
|
outputs=[audio_output, complete_audio_output], |
|
queue=False |
|
).then( |
|
fn=generate_podcast_wrapper, |
|
inputs=[num_speakers, script_input] + speaker_selections + [cfg_scale], |
|
outputs=[audio_output, complete_audio_output, log_output, streaming_status, generate_btn, stop_btn], |
|
queue=True |
|
) |
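
        # The .then() chain clears both players immediately (queue=False), then runs
        # the streaming generator on the queue so partial audio reaches the client.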
|
|
|
|
|
stop_btn.click( |
|
fn=stop_generation_handler, |
|
inputs=[], |
|
outputs=[log_output, streaming_status, generate_btn, stop_btn], |
|
queue=False |
|
).then( |
|
|
|
fn=lambda: (None, None), |
|
inputs=[], |
|
outputs=[audio_output, complete_audio_output], |
|
queue=False |
|
) |
|
|
|
|
|
def load_random_example(): |
|
"""Randomly select and load an example script.""" |
|
import random |
|
|
|
|
|
if hasattr(demo_instance, 'example_scripts') and demo_instance.example_scripts: |
|
example_scripts = demo_instance.example_scripts |
|
else: |
|
|
|
example_scripts = [ |
|
[2, "Speaker 0: Welcome to our AI podcast demonstration!\nSpeaker 1: Thanks for having me. This is exciting!"] |
|
] |
|
|
|
|
|
if example_scripts: |
|
selected = random.choice(example_scripts) |
|
num_speakers_value = selected[0] |
|
script_value = selected[1] |
|
|
|
|
|
return num_speakers_value, script_value |
|
|
|
|
|
return 2, "" |
|
|
|
|
|
random_example_btn.click( |
|
fn=load_random_example, |
|
inputs=[], |
|
outputs=[num_speakers, script_input], |
|
queue=False |
|
) |
|
|
|
|
|
gr.Markdown(""" |
|
        ### 💡 **Usage Tips**

        - Click **🚀 Generate Podcast** to start audio generation
        - The **Live Streaming** player plays audio as it's generated (may have slight pauses)
        - The **Complete Audio** player provides the full, uninterrupted podcast after generation
        - During generation, you can click **🛑 Stop Generation** to interrupt the process
|
- The streaming indicator shows real-time generation progress |
|
""") |
|
|
|
|
|
gr.Markdown("### π **Example Scripts**") |
|
|
|
|
|
if hasattr(demo_instance, 'example_scripts') and demo_instance.example_scripts: |
|
example_scripts = demo_instance.example_scripts |
|
else: |
|
|
|
example_scripts = [ |
|
[1, "Speaker 1: Welcome to our AI podcast demonstration! This is a sample script showing how VibeVoice can generate natural-sounding speech."] |
|
] |
|
|
|
gr.Examples( |
|
examples=example_scripts, |
|
inputs=[num_speakers, script_input], |
|
label="Try these example scripts:" |
|
) |
|
|
|
return interface |
|
|
|
|
|
def convert_to_16_bit_wav(data):
    """Convert float audio (tensor or ndarray) to int16 PCM, peak-normalizing if it clips."""
    if torch.is_tensor(data):
        data = data.detach().cpu().numpy()

    data = np.array(data)

    # Normalize only when the signal exceeds the [-1, 1] range
    peak = np.max(np.abs(data))
    if peak > 1.0:
        data = data / peak

    data = (data * 32767).astype(np.int16)
    return data
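
# Worked example (hypothetical values): an input of [0.0, 0.5, -1.5] peaks above 1.0,
# so it is scaled to [0.0, 0.333..., -1.0] and quantized to int16 as [0, 10922, -32767].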
|
|
|
|
|
def parse_args(): |
|
parser = argparse.ArgumentParser(description="VibeVoice Gradio Demo") |
|
parser.add_argument( |
|
"--model_path", |
|
type=str, |
|
default="/tmp/vibevoice-model", |
|
help="Path to the VibeVoice model directory", |
|
) |
|
parser.add_argument( |
|
"--device", |
|
type=str, |
|
default="cuda" if torch.cuda.is_available() else "cpu", |
|
help="Device for inference", |
|
) |
|
parser.add_argument( |
|
"--inference_steps", |
|
type=int, |
|
default=10, |
|
help="Number of inference steps for DDPM (not exposed to users)", |
|
) |
|
parser.add_argument( |
|
"--share", |
|
action="store_true", |
|
help="Share the demo publicly via Gradio", |
|
) |
|
parser.add_argument( |
|
"--port", |
|
type=int, |
|
default=7860, |
|
help="Port to run the demo on", |
|
) |
|
|
|
return parser.parse_args() |
|
|
|
|
|
def main(): |
|
"""Main function to run the demo.""" |
|
args = parse_args() |
|
|
|
set_seed(42) |
|
|
|
print("ποΈ Initializing VibeVoice Demo with Streaming Support...") |
|
|
|
|
|
demo_instance = VibeVoiceDemo( |
|
model_path=args.model_path, |
|
device=args.device, |
|
inference_steps=args.inference_steps |
|
) |
|
|
|
|
|
interface = create_demo_interface(demo_instance) |
|
|
|
print(f"π Launching demo on port {args.port}") |
|
print(f"π Model path: {args.model_path}") |
|
print(f"π Available voices: {len(demo_instance.available_voices)}") |
|
print(f"π΄ Streaming mode: ENABLED") |
|
print(f"π Session isolation: ENABLED") |
|
|
|
|
|
try: |
|
interface.queue( |
|
max_size=20, |
|
default_concurrency_limit=1 |
|
        ).launch(
            share=args.share,
            server_port=args.port,  # honor --port; it was previously parsed but never passed
            server_name="0.0.0.0" if args.share else "127.0.0.1",
            show_error=True,
            show_api=False
|
) |
|
except KeyboardInterrupt: |
|
print("\nπ Shutting down gracefully...") |
|
except Exception as e: |
|
print(f"β Server error: {e}") |
|
raise |
|
|
|
|
|
if __name__ == "__main__": |
|
main() |
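
# Usage sketch (the script name, model path, and port below are illustrative, not fixed):
#   python gradio_demo.py --model_path /path/to/vibevoice-model --port 7860 --share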