import gradio as gr
import torch
import torchaudio
import numpy as np
import tempfile
import time
from pathlib import Path
from huggingface_hub import hf_hub_download
import os
import spaces
from transformers import pipeline

# Import the inference module
from infer import DMOInference

# Global variables
model_paths = {"student": None, "duration": None}
asr_pipe = None
model_downloaded = False


# Download models on startup (CPU)
def download_models():
    """Download models from HuggingFace Hub."""
    global model_downloaded, model_paths
    try:
        print("Downloading models from HuggingFace...")
        # Download student model
        student_path = hf_hub_download(
            repo_id="yl4579/DMOSpeech2",
            filename="model_85000.pt",
            cache_dir="./models",
        )
        # Download duration predictor
        duration_path = hf_hub_download(
            repo_id="yl4579/DMOSpeech2",
            filename="model_1500.pt",
            cache_dir="./models",
        )
        model_paths["student"] = student_path
        model_paths["duration"] = duration_path
        model_downloaded = True
        print("✓ Models downloaded successfully")
        return True
    except Exception as e:
        print(f"Error downloading models: {e}")
        return False


# Initialize ASR pipeline on CPU
def initialize_asr_pipeline():
    """Initialize the ASR pipeline on startup."""
    global asr_pipe
    print("Initializing ASR pipeline...")
    try:
        asr_pipe = pipeline(
            "automatic-speech-recognition",
            model="openai/whisper-large-v3-turbo",
            torch_dtype=torch.float32,
            device="cpu",  # Always use CPU for ASR to save GPU memory
        )
        print("✓ ASR pipeline initialized successfully")
        return True
    except Exception as e:
        print(f"Error initializing ASR pipeline: {e}")
        return False


# Transcribe function
def transcribe(ref_audio, language=None):
    """Transcribe audio using the pre-loaded ASR pipeline."""
    global asr_pipe
    if asr_pipe is None:
        return ""
    try:
        result = asr_pipe(
            ref_audio,
            chunk_length_s=30,
            batch_size=128,
            generate_kwargs=(
                {"task": "transcribe", "language": language}
                if language
                else {"task": "transcribe"}
            ),
            return_timestamps=False,
        )
        return result["text"].strip()
    except Exception as e:
        print(f"Transcription error: {e}")
        return ""


# Initialize on startup
print("Starting DMOSpeech 2...")
models_ready = download_models()
asr_ready = initialize_asr_pipeline()
status_message = f"Models: {'✅' if models_ready else '❌'} | ASR: {'✅' if asr_ready else '❌'}"
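

# ZeroGPU Spaces only attach a CUDA device inside functions decorated with
# @spaces.GPU. The decorator is assumed here because the Space runs on Zero
# and `spaces` is imported above; a dedicated-GPU Space could drop it.
@spaces.GPU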
def generate_speech_gpu(
    prompt_audio,
    prompt_text,
    target_text,
    mode,
    temperature,
    custom_teacher_steps,
    custom_teacher_stopping_time,
    custom_student_start_step,
    verbose,
):
    """Generate speech with GPU acceleration."""
    if not model_downloaded:
        return None, "❌ Models not downloaded! Please refresh the page.", "", "", prompt_text
    if prompt_audio is None:
        return None, "❌ Please upload a reference audio clip!", "", "", prompt_text
    if not target_text:
        return None, "❌ Please enter text to generate!", "", "", prompt_text

    try:
        # Initialize model on GPU
        device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Initializing model on {device}...")
        model = DMOInference(
            student_checkpoint_path=model_paths["student"],
            duration_predictor_path=model_paths["duration"],
            device=device,
            model_type="F5TTS_Base",
        )
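
        # Note: the model is rebuilt on every request rather than cached
        # globally. Under ZeroGPU the CUDA device is only guaranteed to exist
        # inside this decorated call, so per-call init is the safe pattern
        # (at the cost of some startup latency per generation).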

        # Auto-transcribe if needed (this happens on CPU)
        transcribed_text = prompt_text  # Default to provided text
        if not prompt_text.strip():
            print("Auto-transcribing reference audio...")
            transcribed_text = transcribe(prompt_audio)
            print(f"Transcribed: {transcribed_text}")

        start_time = time.time()

        # Configure parameters based on mode
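        # A rough sketch of what these knobs mean (inferred from the
        # DMOInference API used below, not from its documentation):
        #   teacher_steps         - sampling steps run by the teacher model
        #   teacher_stopping_time - flow time at which the teacher hands off
        #                           to the distilled student
        #   student_start_step    - first step index taken by the student
        # "Student Only" skips the teacher entirely for the fastest path.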
        configs = {
            "Student Only (4 steps)": {
                "teacher_steps": 0,
                "student_start_step": 0,
                "teacher_stopping_time": 1.0,
            },
            "Teacher-Guided (8 steps)": {
                "teacher_steps": 16,
                "teacher_stopping_time": 0.07,
                "student_start_step": 1,
            },
            "High Diversity (16 steps)": {
                "teacher_steps": 24,
                "teacher_stopping_time": 0.3,
                "student_start_step": 2,
            },
            "Custom": {
                "teacher_steps": custom_teacher_steps,
                "teacher_stopping_time": custom_teacher_stopping_time,
                "student_start_step": custom_student_start_step,
            },
        }

        config = configs[mode]

        # Generate speech
        generated_audio = model.generate(
            gen_text=target_text,
            audio_path=prompt_audio,
            prompt_text=transcribed_text if transcribed_text else None,
            teacher_steps=config["teacher_steps"],
            teacher_stopping_time=config["teacher_stopping_time"],
            student_start_step=config["student_start_step"],
            temperature=temperature,
            verbose=verbose,
        )

        end_time = time.time()

        # Calculate metrics
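        # Real-time factor: processing seconds per second of audio produced.
        # RTF < 1 means faster than realtime; 1/RTF is the speed-up factor
        # reported below. The 24000 divisor is the model's 24 kHz output rate
        # (the same rate passed to torchaudio.save further down).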
        processing_time = end_time - start_time
        audio_duration = generated_audio.shape[-1] / 24000
        rtf = processing_time / audio_duration

        # Save audio (accept numpy or torch, mono or batched)
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
            output_path = tmp_file.name
        if isinstance(generated_audio, np.ndarray):
            generated_audio = torch.from_numpy(generated_audio)
        if generated_audio.dim() == 1:
            generated_audio = generated_audio.unsqueeze(0)
        torchaudio.save(output_path, generated_audio, 24000)

        # Format output
        metrics = (
            f"RTF: {rtf:.2f}x ({1/rtf:.2f}x faster than realtime)\n"
            f"Processing: {processing_time:.2f}s for {audio_duration:.2f}s audio\n"
            f"Device: {device.upper()}"
        )

        info = f"Mode: {mode}"
        if not prompt_text.strip():
            info += " | Auto-transcribed"

        # Clean up GPU memory
        del model
        if device == "cuda":
            torch.cuda.empty_cache()

        # Return transcribed text to update the textbox
        return output_path, "✅ Success!", metrics, info, transcribed_text

    except Exception as e:
        import traceback

        print(traceback.format_exc())
        return None, f"❌ Error: {e}", "", "", prompt_text


# Create Gradio interface
with gr.Blocks(
    title="DMOSpeech 2 - Zero-Shot TTS",
    theme=gr.themes.Soft(),
    css="""
    .gradio-container { max-width: 1200px !important; }
    """,
) as demo:
    gr.Markdown(f"""
    <div style="text-align: center;">
        <h1>🎙️ DMOSpeech 2: Zero-Shot Text-to-Speech</h1>
        <p>Generate natural speech in any voice with just a 3-10 second reference!</p>
        <p><b>System Status:</b> {status_message}</p>
    </div>
    """)

    with gr.Row():
        with gr.Column(scale=1):
            # Inputs
            prompt_audio = gr.Audio(
                label="📎 Reference Audio (3-10 seconds)",
                type="filepath",
                sources=["upload", "microphone"],
            )
            prompt_text = gr.Textbox(
                label="📝 Reference Text (leave empty for auto-transcription)",
                placeholder="The text spoken in the reference audio...",
                lines=2,
            )
            target_text = gr.Textbox(
                label="✍️ Text to Generate",
                placeholder="Enter the text you want to synthesize...",
                lines=4,
            )
            mode = gr.Radio(
                choices=[
                    "Student Only (4 steps)",
                    "Teacher-Guided (8 steps)",
                    "High Diversity (16 steps)",
                    "Custom",
                ],
                value="Teacher-Guided (8 steps)",
                label="🚀 Generation Mode",
                info="Speed vs. quality tradeoff",
            )

            # Advanced settings
            with gr.Accordion("⚙️ Advanced Settings", open=False):
                temperature = gr.Slider(
                    minimum=0.0,
                    maximum=2.0,
                    value=0.0,
                    step=0.1,
                    label="Duration Temperature",
                    info="0 = consistent, >0 = varied rhythm",
                )
                with gr.Group(visible=False) as custom_group:
                    custom_teacher_steps = gr.Slider(
                        minimum=0, maximum=32, value=16, step=1, label="Teacher Steps"
                    )
                    custom_teacher_stopping_time = gr.Slider(
                        minimum=0.0, maximum=1.0, value=0.07, step=0.01, label="Stopping Time"
                    )
                    custom_student_start_step = gr.Slider(
                        minimum=0, maximum=4, value=1, step=1, label="Student Start Step"
                    )
                verbose = gr.Checkbox(value=False, label="Verbose Output")

            generate_btn = gr.Button("🎵 Generate Speech", variant="primary", size="lg")

        with gr.Column(scale=1):
            # Outputs
            output_audio = gr.Audio(
                label="🔊 Generated Speech",
                type="filepath",
                autoplay=True,
            )
            status = gr.Textbox(label="Status", interactive=False)
            metrics = gr.Textbox(label="Performance", interactive=False, lines=3)
            info = gr.Textbox(label="Info", interactive=False)

    # Guide
    gr.Markdown("""
    ### 💡 Quick Guide

    | Mode | Speed | Quality | Use Case |
    |------|-------|---------|----------|
    | Student Only | 20x realtime | Good | Real-time apps |
    | Teacher-Guided | 10x realtime | Better | General use |
    | High Diversity | 5x realtime | Best | Production |

    **Tips:**
    - Leave the reference text empty to auto-transcribe it
    - Auto-transcription runs only once; the transcribed text is filled back into the reference box
    - Use temperature > 0 for more natural rhythm variation
    - Custom mode lets you fine-tune all parameters
    """)

    # Examples
    gr.Markdown("### 🎯 Example Texts")
    gr.Markdown("""
    <details>
    <summary>English Example</summary>

    **Reference:** "Some call me nature, others call me mother nature."

    **Target:** "I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring."
    </details>

    <details>
    <summary>Chinese Example</summary>

    **Reference:** "对，这就是我，万人敬仰的太乙真人。"

    **Target:** "突然，身边一阵笑声。我看着他们，意气风发地挺直了胸膛，甩了甩那稍显肉感的双臂，轻笑道：'我身上的肉，是为了掩饰我爆棚的魅力，否则，岂不吓坏了你们呢？'"
    </details>
    """)

    # Event handlers
    def toggle_custom(mode):
        return gr.update(visible=(mode == "Custom"))

    mode.change(toggle_custom, [mode], [custom_group])

    generate_btn.click(
        generate_speech_gpu,
        inputs=[
            prompt_audio,
            prompt_text,
            target_text,
            mode,
            temperature,
            custom_teacher_steps,
            custom_teacher_stopping_time,
            custom_student_start_step,
            verbose,
        ],
        outputs=[
            output_audio,
            status,
            metrics,
            info,
            prompt_text,  # Update the prompt_text textbox with the transcribed text
        ],
    )
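

# For a quick check outside the UI, DMOInference can be driven directly.
# A minimal sketch reusing the exact API called in generate_speech_gpu above
# ("ref.wav" is a placeholder reference clip, not a file in this repo):
#
#   model = DMOInference(
#       student_checkpoint_path=model_paths["student"],
#       duration_predictor_path=model_paths["duration"],
#       device="cuda" if torch.cuda.is_available() else "cpu",
#       model_type="F5TTS_Base",
#   )
#   audio = model.generate(gen_text="Hello world!", audio_path="ref.wav")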

# Launch
if __name__ == "__main__":
    demo.launch()