# Hugging Face Space status when this file was captured: Sleeping
import os | |
import time | |
import torch | |
import random | |
import numpy as np | |
import soundfile as sf | |
import tempfile | |
import uuid | |
import logging | |
import requests | |
import io | |
from typing import Optional, Dict, Any | |
from pathlib import Path | |
import gradio as gr | |
import spaces | |
from fastapi import FastAPI, HTTPException | |
from fastapi.responses import StreamingResponse | |
from fastapi.middleware.cors import CORSMiddleware | |
from pydantic import BaseModel | |
# ChatterboxTTS is an optional dependency that must be installed separately.
# When the import fails, a mock class with the same name and call surface is
# defined instead so the rest of the module can still run for demonstration.
try:
    from chatterbox.src.chatterbox.tts import ChatterboxTTS
    CHATTERBOX_AVAILABLE = True
except ImportError:
    CHATTERBOX_AVAILABLE = False
    print("β οΈ ChatterboxTTS not found. Using mock implementation.")
    print("π¦ Install ChatterboxTTS: pip install chatterbox-tts")

    class ChatterboxTTS:
        """Mock stand-in for ChatterboxTTS that synthesizes a decaying tone.

        Exposes the members this module uses on the model: ``from_pretrained``,
        ``to``, ``generate`` and the ``sr`` (sample rate) attribute.
        """

        def __init__(self, device="cpu"):
            self.device = device
            self.sr = 24000  # sample rate in Hz of the generated waveform

        # Must be a classmethod: the module calls
        # ``ChatterboxTTS.from_pretrained(DEVICE)`` on the class itself, which
        # would otherwise bind DEVICE to ``cls`` and fail with a TypeError.
        @classmethod
        def from_pretrained(cls, device):
            """Return a new mock instance bound to ``device``."""
            return cls(device)

        def to(self, device):
            """Record ``device`` and return ``self`` so calls can be chained."""
            self.device = device
            return self

        def generate(self, text, audio_prompt_path=None, exaggeration=0.5,
                     temperature=0.8, cfg_weight=0.5):
            """Produce mock audio for ``text``.

            Args:
                text: Input text; its length sets the duration (0.1 s per
                    character, capped at 10 s) and its word count sets the
                    base frequency.
                audio_prompt_path: Accepted for interface compatibility; unused.
                exaggeration: Scales the amplitude by ``0.5 + exaggeration``.
                temperature: Values above 1.0 add Gaussian noise.
                cfg_weight: Accepted for interface compatibility; unused.

            Returns:
                A ``torch.Tensor`` of shape ``(1, num_samples)``.
            """
            # Generate mock audio - replace this with real ChatterboxTTS
            duration = min(len(text) * 0.1, 10.0)
            t = np.linspace(0, duration, int(self.sr * duration))

            # Vary the pitch by content so different texts sound different.
            words = len(text.split())
            freq_base = 150 + (words % 50) * 5

            # The decay envelope does not depend on the harmonic index, so it
            # is computed once instead of on every loop iteration.
            envelope = np.exp(-t / (duration * 0.7))

            # Sum three harmonics of the base frequency.
            audio = np.zeros_like(t)
            for i in range(3):
                freq = freq_base * (i + 1)
                audio += 0.2 * np.sin(2 * np.pi * freq * t + i) * envelope

            # Add some variation based on parameters
            audio *= (0.5 + exaggeration)
            if temperature > 1.0:
                audio += np.random.normal(0, 0.05, len(audio))
            return torch.tensor(audio).unsqueeze(0)
# Logging: INFO-level root configuration plus a logger named after this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Compute device: "cuda" when torch reports an available GPU, "cpu" otherwise.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"
logger.info(f"π Running on device: {DEVICE}")

# Module-wide model handle; stays None until a model has been loaded.
MODEL = None

# Directory for generated WAV files, created up front if it is missing.
AUDIO_DIR = "generated_audio"
os.makedirs(AUDIO_DIR, exist_ok=True)

# In-memory index of generated audio, keyed by audio id.
audio_cache = {}
def get_or_load_model():
    """Return the shared TTS model, loading it on first use.

    The loaded instance is stored in the module-level ``MODEL`` so later calls
    return it without loading again.

    Returns:
        The object produced by ``ChatterboxTTS.from_pretrained(DEVICE)``.

    Raises:
        Exception: Whatever loading raised, re-raised after being logged.
    """
    global MODEL

    # Fast path: a model is already in place.
    if MODEL is not None:
        return MODEL

    logger.info("Loading ChatterboxTTS model...")
    try:
        MODEL = ChatterboxTTS.from_pretrained(DEVICE)
        # Only move the model when the implementation offers a ``to`` method.
        if hasattr(MODEL, 'to'):
            MODEL.to(DEVICE)
        logger.info("β ChatterboxTTS model loaded successfully")
    except Exception as load_error:
        logger.error(f"β Error loading model: {load_error}")
        raise
    return MODEL
def set_seed(seed: int):
    """Seed Python, NumPy and PyTorch random generators for reproducibility.

    Args:
        seed: Any integer. NumPy only accepts seeds in ``[0, 2**32)``, so the
            value is reduced modulo ``2**32`` for NumPy alone; a negative or
            very large seed therefore no longer raises ``ValueError``.
    """
    torch.manual_seed(seed)
    # Equivalent to checking the module's device string, which is itself
    # derived from torch.cuda.is_available().
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    random.seed(seed)
    np.random.seed(seed % (2 ** 32))
def generate_id():
    """Return a fresh random UUID (version 4) in its string form."""
    new_id = uuid.uuid4()
    return f"{new_id}"
# Pydantic models for API
class TTSRequest(BaseModel):
    """Request body accepted by the speech synthesis endpoint."""

    # Text to synthesize; the only required field.
    text: str
    # Location of the reference audio; defaults to a hosted demo sample.
    audio_prompt_url: Optional[str] = (
        "https://storage.googleapis.com/chatterbox-demo-samples/prompts/female_shadowheart4.flac"
    )
    # Generation parameters passed through to the model's generate() call.
    exaggeration: Optional[float] = 0.5
    temperature: Optional[float] = 0.8
    cfg_weight: Optional[float] = 0.5
    # 0 leaves the random generators unseeded; any other value seeds them.
    seed: Optional[int] = 0
class TTSResponse(BaseModel):
    """Response body returned after a synthesis request."""

    # Outcome flag and human-readable message; both are required.
    success: bool
    # Identifier of the generated audio, when one was produced.
    audio_id: Optional[str] = None
    message: str
    # Properties of the generated audio: sample rate in Hz, duration in seconds.
    sample_rate: Optional[int] = None
    duration: Optional[float] = None
# Load model at startup. A failure is logged and leaves MODEL as None instead
# of aborting the import.
try:
    if not CHATTERBOX_AVAILABLE:
        # Mock path: build the stand-in model directly.
        MODEL = ChatterboxTTS.from_pretrained(DEVICE)
        print("β οΈ Using mock ChatterboxTTS implementation")
    else:
        get_or_load_model()
        print("β ChatterboxTTS model loaded successfully")
except Exception as startup_error:
    logger.error(f"Failed to load model on startup: {startup_error}")
    MODEL = None
# NOTE(review): `spaces` is imported by this file but never used. If this runs
# on Hugging Face ZeroGPU hardware this function may need `@spaces.GPU`; confirm.
def generate_tts_audio(
    text_input: str,
    audio_prompt_path_input: str,
    exaggeration_input: float,
    temperature_input: float,
    seed_num_input: int,
    cfgw_input: float
) -> tuple[int, np.ndarray]:
    """Generate TTS audio using the ChatterboxTTS model.

    Args:
        text_input: Text to synthesize. Only the first 300 characters are
            passed to the model; longer input is truncated with a warning.
        audio_prompt_path_input: Reference audio location, forwarded to the
            model as ``audio_prompt_path``.
        exaggeration_input: Forwarded to the model as ``exaggeration``.
        temperature_input: Forwarded to the model as ``temperature``.
        seed_num_input: Seed for the random generators. ``0`` or ``None``
            leaves them unseeded.
        cfgw_input: Forwarded to the model as ``cfg_weight``.

    Returns:
        ``(sample_rate, samples)``: the model's ``sr`` attribute and the
        generated waveform as a NumPy array.

    Raises:
        RuntimeError: If no model is available.
        Exception: Any error raised during generation, re-raised after logging.
    """
    max_chars = 300

    current_model = get_or_load_model()
    if current_model is None:
        raise RuntimeError("TTS model is not loaded")

    # The API request model declares the seed as Optional, so None can arrive
    # here; treat it like 0 rather than failing in int(None).
    if seed_num_input is not None and seed_num_input != 0:
        set_seed(int(seed_num_input))

    if len(text_input) > max_chars:
        # Make the truncation visible instead of silently dropping text.
        logger.warning(
            "Input text has %d characters; only the first %d are synthesized",
            len(text_input), max_chars,
        )

    logger.info(f"π΅ Generating audio for: '{text_input[:50]}...'")
    try:
        wav = current_model.generate(
            text_input[:max_chars],  # Limit text length
            audio_prompt_path=audio_prompt_path_input,
            exaggeration=exaggeration_input,
            temperature=temperature_input,
            cfg_weight=cfgw_input,
        )
        logger.info("β Audio generation complete")
        # detach().cpu() keeps the conversion valid for tensors that live on
        # the GPU or track gradients; it is a no-op for plain CPU tensors.
        return (current_model.sr, wav.squeeze(0).detach().cpu().numpy())
    except Exception as e:
        logger.error(f"β Audio generation failed: {e}")
        raise
# FastAPI app for API endpoints
app = FastAPI(
    title="ChatterboxTTS API",
    description="High-quality text-to-speech synthesis using ChatterboxTTS",
    version="1.0.0"
)

# CORS: the API is open to every origin. Credentialed requests are disabled
# because wildcard origins must not be combined with allow_credentials=True
# (per the FastAPI CORS documentation), and nothing in this module relies on
# cookies or other credentials.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Registered on the app explicitly; without a route decorator this handler is
# never reachable over HTTP.
@app.get("/")
async def root():
    """API status endpoint.

    Returns:
        A dict with the service name and version, whether a model is loaded,
        the compute device, and the main endpoint paths.
    """
    return {
        "service": "ChatterboxTTS API",
        "version": "1.0.0",
        "status": "operational" if MODEL else "model_loading",
        "model_loaded": MODEL is not None,
        "device": DEVICE,
        "endpoints": {
            "synthesize": "/api/tts/synthesize",
            "audio": "/api/audio/{audio_id}",
            "health": "/health"
        }
    }
# Registered at the path this module advertises for health checks; without a
# route decorator the handler is never reachable over HTTP.
@app.get("/health")
async def health_check():
    """Health check endpoint.

    Returns:
        A dict with "healthy"/"unhealthy" depending on whether a model is
        loaded, the compute device, and the current Unix timestamp.
    """
    return {
        "status": "healthy" if MODEL else "unhealthy",
        "model_loaded": MODEL is not None,
        "device": DEVICE,
        "timestamp": time.time()
    }
# Registered at the path this module advertises for synthesis; without a
# route decorator the handler is never reachable over HTTP.
@app.post("/api/tts/synthesize", response_model=TTSResponse)
async def synthesize_speech(request: TTSRequest):
    """Synthesize speech from text and store the result as a WAV file.

    Args:
        request: Text and generation parameters.

    Returns:
        A ``TTSResponse`` carrying the new audio id, sample rate and duration.

    Raises:
        HTTPException: 503 when no model is loaded, 400 for empty or
            over-long text, 500 when generation or saving fails.
    """
    # Validate before doing any work so client errors never reach the model.
    if MODEL is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    if not request.text.strip():
        raise HTTPException(status_code=400, detail="Text cannot be empty")
    # NOTE(review): up to 500 characters are accepted here, but
    # generate_tts_audio passes only the first 300 to the model.
    if len(request.text) > 500:
        raise HTTPException(status_code=400, detail="Text too long (max 500 characters)")

    try:
        start_time = time.time()
        # Generate audio
        sample_rate, audio_data = generate_tts_audio(
            request.text,
            request.audio_prompt_url,
            request.exaggeration,
            request.temperature,
            request.seed,
            request.cfg_weight
        )
        generation_time = time.time() - start_time
        duration = len(audio_data) / sample_rate

        # Save audio file
        audio_id = generate_id()
        audio_path = os.path.join(AUDIO_DIR, f"{audio_id}.wav")
        sf.write(audio_path, audio_data, sample_rate)

        # Cache audio info so the download and listing handlers can find it.
        audio_cache[audio_id] = {
            "path": audio_path,
            "text": request.text,
            "sample_rate": sample_rate,
            "duration": duration,
            "generated_at": time.time(),
            "generation_time": generation_time
        }
        logger.info(f"β Audio saved: {audio_id} ({generation_time:.2f}s)")
        return TTSResponse(
            success=True,
            audio_id=audio_id,
            message="Speech synthesized successfully",
            sample_rate=sample_rate,
            duration=duration
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"β Synthesis failed: {e}")
        # Chain the cause so the original traceback is preserved in logs.
        raise HTTPException(status_code=500, detail=f"Synthesis failed: {str(e)}") from e
# Registered at the path this module advertises for downloads; without a
# route decorator the handler is never reachable over HTTP.
@app.get("/api/audio/{audio_id}")
async def get_audio(audio_id: str):
    """Download a generated audio file.

    Args:
        audio_id: Identifier of the audio, as stored in ``audio_cache``.

    Returns:
        A streaming ``audio/wav`` response sent as an attachment.

    Raises:
        HTTPException: 404 when the id is unknown or the file is missing.
    """
    if audio_id not in audio_cache:
        raise HTTPException(status_code=404, detail="Audio not found")
    audio_path = audio_cache[audio_id]["path"]
    if not os.path.exists(audio_path):
        raise HTTPException(status_code=404, detail="Audio file not found on disk")

    def iterfile():
        # Read fixed-size chunks: iterating a binary file splits on newline
        # bytes, which yields arbitrarily sized pieces for WAV data.
        with open(audio_path, "rb") as f:
            while chunk := f.read(64 * 1024):
                yield chunk

    return StreamingResponse(
        iterfile(),
        media_type="audio/wav",
        headers={
            "Content-Disposition": f"attachment; filename=tts_{audio_id}.wav"
        }
    )
# Without a route decorator this handler is never reachable over HTTP.
# NOTE(review): the path for this endpoint is not stated anywhere in the
# file; "/api/audio/{audio_id}/info" is an assumption to confirm.
@app.get("/api/audio/{audio_id}/info")
async def get_audio_info(audio_id: str):
    """Get the cached information for a generated audio file.

    Args:
        audio_id: Identifier of the audio, as stored in ``audio_cache``.

    Returns:
        The ``audio_cache`` entry for ``audio_id``.

    Raises:
        HTTPException: 404 when the id is unknown.
    """
    if audio_id not in audio_cache:
        raise HTTPException(status_code=404, detail="Audio not found")
    return audio_cache[audio_id]
# Registered at the path this module advertises for listing; without a
# route decorator the handler is never reachable over HTTP.
@app.get("/api/audio")
async def list_audio():
    """List all generated audio files.

    Returns:
        A dict with one summary per cached audio (id, text preview limited to
        50 characters, duration, creation time) and the total count.
    """
    return {
        "audio_files": [
            {
                "audio_id": audio_id,
                # Parenthesized to make the precedence explicit: the ellipsis
                # is only appended when the text was actually shortened.
                "text": (info["text"][:50] + "...") if len(info["text"]) > 50 else info["text"],
                "duration": info["duration"],
                "generated_at": info["generated_at"]
            }
            for audio_id, info in audio_cache.items()
        ],
        "total": len(audio_cache)
    }
# Gradio interface
def create_gradio_interface():
    """Build the Gradio Blocks UI for ChatterboxTTS.

    The page holds the input controls, example presets, an audio player with
    a status box, the API documentation, and a system status summary taken
    from module state at build time.

    Returns:
        The constructed ``gr.Blocks`` object (not launched).
    """
    default_prompt_url = "https://storage.googleapis.com/chatterbox-demo-samples/prompts/female_shadowheart4.flac"

    # Each row follows the order of the inputs list passed to gr.Examples:
    # text, prompt URL, exaggeration, temperature, seed, cfg weight.
    example_rows = [
        [
            "Welcome to our podcast! Today we're discussing the latest developments in artificial intelligence.",
            default_prompt_url,
            0.6, 0.8, 0, 0.5
        ],
        [
            "Good morning! I hope you're having a wonderful day. Let me tell you about our exciting new features.",
            default_prompt_url,
            0.7, 0.9, 0, 0.6
        ],
        [
            "In today's tutorial, we'll learn how to build a machine learning model from scratch using Python.",
            default_prompt_url,
            0.4, 0.7, 0, 0.4
        ]
    ]

    def generate_speech_ui(text, prompt_url, exag, temp, seed_val, cfg):
        """Button callback: return ``(audio, status)`` for the two outputs.

        ``audio`` is ``(sample_rate, samples)`` on success and ``None`` on any
        failure; ``status`` is the text shown in the status box.
        """
        try:
            if not text.strip():
                return None, "β Please enter some text"
            if len(text) > 300:
                return None, "β Text too long (max 300 characters)"

            started_at = time.time()
            # Generate audio
            sample_rate, audio_data = generate_tts_audio(
                text, prompt_url, exag, temp, int(seed_val), cfg
            )
            generation_time = time.time() - started_at

            sample_count = len(audio_data)
            duration = sample_count / sample_rate
            status = (
                "β Speech generated successfully!\n"
                f"β±οΈ Generation time: {generation_time:.2f}s\n"
                f"π΅ Audio duration: {duration:.2f}s\n"
                f"π Sample rate: {sample_rate} Hz\n"
                f"π Audio samples: {sample_count:,}\n"
            )
            return (sample_rate, audio_data), status
        except Exception as e:
            logger.error(f"UI generation failed: {e}")
            return None, f"β Generation failed: {e}"

    with gr.Blocks(title="ChatterboxTTS", theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
# π΅ ChatterboxTTS
High-quality text-to-speech synthesis with voice cloning capabilities.
""")
        with gr.Row():
            # Left column: inputs and the generate button.
            with gr.Column():
                text_box = gr.Textbox(
                    value="Hello, this is ChatterboxTTS. I can generate natural-sounding speech from any text you provide.",
                    label="Text to synthesize (max 300 characters)",
                    max_lines=5,
                    placeholder="Enter your text here..."
                )
                prompt_box = gr.Textbox(
                    value=default_prompt_url,
                    label="Reference Audio URL",
                    placeholder="URL to reference audio file"
                )
                with gr.Row():
                    exaggeration_slider = gr.Slider(
                        minimum=0.25,
                        maximum=2,
                        step=0.05,
                        label="Exaggeration",
                        value=0.5,
                        info="Controls expressiveness (0.5 = neutral)"
                    )
                    cfg_slider = gr.Slider(
                        minimum=0.2,
                        maximum=1,
                        step=0.05,
                        label="CFG Weight",
                        value=0.5,
                        info="Controls pace and clarity"
                    )
                with gr.Accordion("Advanced Settings", open=False):
                    temperature_slider = gr.Slider(
                        minimum=0.05,
                        maximum=5,
                        step=0.05,
                        label="Temperature",
                        value=0.8,
                        info="Controls randomness"
                    )
                    seed_box = gr.Number(
                        value=0,
                        label="Seed (0 = random)",
                        info="Set to non-zero for reproducible results"
                    )
                generate_btn = gr.Button("π΅ Generate Speech", variant="primary")
            # Right column: outputs.
            with gr.Column():
                audio_player = gr.Audio(label="Generated Speech")
                status_box = gr.Textbox(
                    label="Status",
                    interactive=False,
                    placeholder="Click 'Generate Speech' to start..."
                )

        # The same ordering is used for the examples and the click handler.
        input_components = [
            text_box, prompt_box, exaggeration_slider,
            temperature_slider, seed_box, cfg_slider,
        ]

        # Examples
        gr.Examples(examples=example_rows, inputs=input_components)

        generate_btn.click(
            fn=generate_speech_ui,
            inputs=input_components,
            outputs=[audio_player, status_box]
        )

        # API Documentation
        gr.Markdown("""
## π API Endpoints
### POST `/api/tts/synthesize`
Generate speech from text
```json
{
"text": "Your text here",
"audio_prompt_url": "URL to reference audio",
"exaggeration": 0.5,
"temperature": 0.8,
"cfg_weight": 0.5,
"seed": 0
}
```
### GET `/api/audio/{audio_id}`
Download generated audio file
### GET `/api/audio`
List all generated audio files
### GET `/health`
Check service health
""")

        # System info, evaluated once while the interface is being built.
        if CHATTERBOX_AVAILABLE and MODEL:
            model_status = "β Real ChatterboxTTS"
        elif MODEL:
            model_status = "β οΈ Mock Implementation"
        else:
            model_status = "β Not Loaded"

        install_note = ""
        if not CHATTERBOX_AVAILABLE:
            install_note = "**Note**: Install ChatterboxTTS for production use: `pip install chatterbox-tts`"

        gr.Markdown(f"""
### π System Status
- **Model**: {model_status}
- **Device**: {DEVICE}
- **Generated Files**: {len(audio_cache)}
- **ChatterboxTTS Available**: {CHATTERBOX_AVAILABLE}
{install_note}
""")
    return demo
# Main execution
if __name__ == "__main__":
    logger.info("π Starting ChatterboxTTS Service...")

    # Model status: no model, the real package, or the mock stand-in.
    if not MODEL:
        model_status = "β No Model Loaded"
    elif CHATTERBOX_AVAILABLE:
        model_status = "β Real ChatterboxTTS Loaded"
    else:
        model_status = "β οΈ Mock ChatterboxTTS (Install real package for production)"

    logger.info(f"Model Status: {model_status}")
    logger.info(f"Device: {DEVICE}")
    logger.info(f"ChatterboxTTS Available: {CHATTERBOX_AVAILABLE}")

    if not os.getenv("SPACE_ID"):
        # Local development - run both FastAPI and Gradio
        import uvicorn
        import threading

        def run_fastapi():
            uvicorn.run(app, host="0.0.0.0", port=8000, log_level="info")

        # Start FastAPI in a daemon thread so it stops with the main process.
        api_thread = threading.Thread(target=run_fastapi, daemon=True)
        api_thread.start()
        logger.info("π FastAPI: http://localhost:8000")
        logger.info("π API Docs: http://localhost:8000/docs")

        # Start Gradio
        demo = create_gradio_interface()
        demo.launch(share=True)
    else:
        # Running in Hugging Face Spaces: only the Gradio app is launched.
        logger.info("π Running in Hugging Face Spaces")
        demo = create_gradio_interface()
        demo.launch(
            server_name="0.0.0.0",
            server_port=7860,
            show_error=True
        )