# Chatterbox_AP / app.py
# Author: ceymox — commit e1ff5d6 "Create app.py" (verified)
# Source view: raw | history | blame — 15.4 kB
import logging
import os
import random
import tempfile
import time
import uuid
from pathlib import Path
from typing import Any, Dict, Optional

import gradio as gr
import numpy as np
import soundfile as sf
import spaces
import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, StreamingResponse
from pydantic import BaseModel

# Import ChatterboxTTS
from chatterbox.src.chatterbox.tts import ChatterboxTTS
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Device configuration: use the GPU whenever torch can see one.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
logger.info(f"πŸš€ Running on device: {DEVICE}")

# Global model variable; stays None until get_or_load_model() succeeds.
MODEL = None

# Storage for generated audio: WAV files are written under AUDIO_DIR, and
# audio_cache maps each audio_id to the metadata record the synthesize
# endpoint stores for it.
AUDIO_DIR = "generated_audio"
Path(AUDIO_DIR).mkdir(parents=True, exist_ok=True)
audio_cache: Dict[str, Dict[str, Any]] = {}
def get_or_load_model():
    """Return the process-wide ChatterboxTTS model, loading it on first use.

    The model is cached in the module-level ``MODEL`` global, so subsequent
    calls return the same instance without reloading.

    Returns:
        The cached ``ChatterboxTTS`` instance.

    Raises:
        Exception: whatever loading or the device move raises. ``MODEL`` is
            left untouched (``None``) in that case, so a later call retries.
    """
    global MODEL
    if MODEL is not None:
        return MODEL

    logger.info("Loading ChatterboxTTS model...")
    try:
        # Build into a local first so a failure part-way through never leaves
        # a half-initialized object in the global.
        model = ChatterboxTTS.from_pretrained(DEVICE)
        if hasattr(model, 'to'):
            # NOTE(review): the return value is discarded, i.e. this assumes
            # .to() moves in place (as torch.nn.Module.to does) — confirm for
            # ChatterboxTTS.
            model.to(DEVICE)
    except Exception as e:
        # logger.exception keeps the traceback that logger.error dropped.
        logger.exception(f"❌ Error loading model: {e}")
        raise

    MODEL = model
    logger.info("βœ… ChatterboxTTS model loaded successfully")
    return MODEL
def set_seed(seed: int):
    """Seed every RNG that generation may draw from, for reproducible output.

    Seeds Python's ``random`` module, NumPy's global RNG, and torch's CPU and
    CUDA generators.

    Args:
        seed: Seed value applied to all of the generators above.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seeds every visible CUDA device (a superset of torch.cuda.manual_seed).
    # PyTorch silently ignores this when CUDA is unavailable, so no device
    # check is needed.
    torch.cuda.manual_seed_all(seed)
def generate_id():
    """Return a new random (version 4) UUID in its 36-character string form."""
    new_id = uuid.uuid4()
    return f"{new_id}"
# Pydantic models for API
class TTSRequest(BaseModel):
    """Request body for POST /api/tts/synthesize."""

    text: str
    # Reference audio forwarded to the model as ``audio_prompt_path``.
    audio_prompt_url: Optional[str] = "https://storage.googleapis.com/chatterbox-demo-samples/prompts/female_shadowheart4.flac"
    # The generation knobs are non-Optional on purpose: an explicit JSON null
    # would otherwise pass validation and fail later (a null seed reaches
    # int(None) in generate_tts_audio), turning a client error into a 500.
    # Omitted fields still fall back to these defaults.
    exaggeration: float = 0.5
    temperature: float = 0.8
    cfg_weight: float = 0.5
    # 0 means "do not reseed the RNGs".
    seed: int = 0


class TTSResponse(BaseModel):
    """Response body for POST /api/tts/synthesize."""

    success: bool
    # Id to pass to /api/audio/{audio_id}; None when nothing was generated.
    audio_id: Optional[str] = None
    message: str
    sample_rate: Optional[int] = None
    # Clip length in seconds.
    duration: Optional[float] = None
# Load model at startup so the first request does not pay the load cost.
# A failure is logged but not fatal: the service still starts, MODEL stays
# None (reported by / and /health), and later get_or_load_model() calls retry.
try:
    get_or_load_model()
except Exception as e:
    # logger.exception records the traceback, not just the message.
    logger.exception(f"Failed to load model on startup: {e}")
@spaces.GPU
def generate_tts_audio(
    text_input: str,
    audio_prompt_path_input: str,
    exaggeration_input: float,
    temperature_input: float,
    seed_num_input: int,
    cfgw_input: float,
    *,
    max_chars: int = 300,
) -> tuple[int, np.ndarray]:
    """
    Generate TTS audio using the ChatterboxTTS model.

    Args:
        text_input: Text to speak. Only the first ``max_chars`` characters are
            synthesized; a warning is logged when the text is cut.
        audio_prompt_path_input: Reference audio, forwarded to the model as
            ``audio_prompt_path`` (the UI and API both pass a URL here).
        exaggeration_input: Forwarded as ``exaggeration``.
        temperature_input: Forwarded as ``temperature``.
        seed_num_input: RNG seed; 0 leaves the RNG state untouched.
        cfgw_input: Forwarded as ``cfg_weight``.
        max_chars: Maximum number of characters handed to the model.

    Returns:
        ``(sample_rate, samples)``: the model's sample rate and the generated
        waveform (leading dimension squeezed) as a NumPy array.

    Raises:
        RuntimeError: if no model is available.
        Exception: anything the model raises is logged and re-raised.
    """
    current_model = get_or_load_model()
    if current_model is None:
        raise RuntimeError("TTS model is not loaded")

    if seed_num_input != 0:
        set_seed(int(seed_num_input))

    # Truncation used to be silent; make it visible in the logs.
    if len(text_input) > max_chars:
        logger.warning(
            f"Input text is {len(text_input)} characters; only the first "
            f"{max_chars} will be synthesized"
        )

    logger.info(f"🎡 Generating audio for: '{text_input[:50]}...'")
    try:
        wav = current_model.generate(
            text_input[:max_chars],  # Limit text length
            audio_prompt_path=audio_prompt_path_input,
            exaggeration=exaggeration_input,
            temperature=temperature_input,
            cfg_weight=cfgw_input,
        )
        logger.info("βœ… Audio generation complete")
        # detach()/cpu() keep the NumPy conversion valid even if the model
        # returns a GPU tensor or one that tracks gradients.
        return (current_model.sr, wav.squeeze(0).detach().cpu().numpy())
    except Exception as e:
        logger.exception(f"❌ Audio generation failed: {e}")
        raise
# FastAPI app for API endpoints
app = FastAPI(
    title="ChatterboxTTS API",
    description="High-quality text-to-speech synthesis using ChatterboxTTS",
    version="1.0.0",
)

# Open CORS for any origin, but without credentials. No endpoint in this
# service reads cookies or auth headers, so credentials are unnecessary. The
# CORS spec also forbids a wildcard origin on credentialed requests, and
# FastAPI's docs say allow_origins=["*"] must not be combined with
# allow_credentials=True.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.get("/")
async def root():
    """API status endpoint: service identity, model/device state and routes."""
    # Test MODEL against None explicitly (as "model_loaded" always did) so the
    # two status fields can never disagree because of the object's truthiness.
    model_loaded = MODEL is not None
    return {
        "service": "ChatterboxTTS API",
        "version": "1.0.0",
        # "model_loading" is reported whenever MODEL is None, which also
        # covers a failed startup load.
        "status": "operational" if model_loaded else "model_loading",
        "model_loaded": model_loaded,
        "device": DEVICE,
        "endpoints": {
            "synthesize": "/api/tts/synthesize",
            "audio": "/api/audio/{audio_id}",
            "health": "/health",
            "audio_info": "/api/audio/{audio_id}/info",
            "list_audio": "/api/audio",
        },
    }
@app.get("/health")
async def health_check():
    """Health check endpoint.

    Always answers HTTP 200; health is carried in the ``status`` field
    ("healthy" iff the model is loaded), alongside the device and server time.
    """
    # One explicit None check feeds both fields so they stay consistent.
    model_loaded = MODEL is not None
    return {
        "status": "healthy" if model_loaded else "unhealthy",
        "model_loaded": model_loaded,
        "device": DEVICE,
        "timestamp": time.time(),
    }
@app.post("/api/tts/synthesize", response_model=TTSResponse)
async def synthesize_speech(request: TTSRequest):
    """
    Synthesize speech from text.

    Validates the request, generates audio with generate_tts_audio, saves it
    as ``<AUDIO_DIR>/<audio_id>.wav`` and records its metadata in
    ``audio_cache`` so it can be downloaded via /api/audio/{audio_id}.

    Raises:
        HTTPException: 503 if the model is not loaded; 400 for empty text or
            text over 300 characters; 500 if generation or saving fails.
    """
    # generate_tts_audio only synthesizes the first 300 characters, so longer
    # text would be cut silently; reject it up front (matches the Gradio UI).
    max_text_length = 300

    if MODEL is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    if not request.text.strip():
        raise HTTPException(status_code=400, detail="Text cannot be empty")
    if len(request.text) > max_text_length:
        raise HTTPException(
            status_code=400,
            detail=f"Text too long (max {max_text_length} characters)",
        )

    try:
        start_time = time.perf_counter()
        # NOTE(review): this synchronous call runs on the event loop, so every
        # other request (including /health) waits for the whole generation.
        # Offloading it to a worker thread would allow concurrent model calls;
        # confirm the model tolerates that before changing it.
        sample_rate, audio_data = generate_tts_audio(
            request.text,
            request.audio_prompt_url,
            request.exaggeration,
            request.temperature,
            request.seed,
            request.cfg_weight,
        )
        generation_time = time.perf_counter() - start_time
        duration = len(audio_data) / sample_rate

        # Save audio file
        audio_id = generate_id()
        audio_path = os.path.join(AUDIO_DIR, f"{audio_id}.wav")
        sf.write(audio_path, audio_data, sample_rate)

        # Cache audio info.
        # NOTE(review): entries and files are never evicted, so memory and
        # disk usage grow with every request.
        audio_cache[audio_id] = {
            "path": audio_path,
            "text": request.text,
            "sample_rate": sample_rate,
            "duration": duration,
            "generated_at": time.time(),
            "generation_time": generation_time,
        }

        logger.info(f"βœ… Audio saved: {audio_id} ({generation_time:.2f}s)")
        return TTSResponse(
            success=True,
            audio_id=audio_id,
            message="Speech synthesized successfully",
            sample_rate=sample_rate,
            duration=duration,
        )
    except Exception as e:
        logger.exception(f"❌ Synthesis failed: {e}")
        raise HTTPException(status_code=500, detail=f"Synthesis failed: {str(e)}") from e
@app.get("/api/audio/{audio_id}")
async def get_audio(audio_id: str):
    """
    Download a generated audio file as a WAV attachment.

    Only ids recorded in ``audio_cache`` are served, and the file path comes
    from that record, never from the client-supplied id.

    Raises:
        HTTPException: 404 if the id is unknown or its file is missing.
    """
    audio_info = audio_cache.get(audio_id)
    if audio_info is None:
        raise HTTPException(status_code=404, detail="Audio not found")

    audio_path = audio_info["path"]
    if not os.path.isfile(audio_path):
        raise HTTPException(status_code=404, detail="Audio file not found on disk")

    # FileResponse streams in fixed-size chunks and sets Content-Length (and
    # Content-Disposition from `filename`). Iterating a binary file object
    # instead splits the WAV on b"\n" bytes, giving arbitrarily sized chunks
    # and no Content-Length.
    return FileResponse(
        audio_path,
        media_type="audio/wav",
        filename=f"tts_{audio_id}.wav",
    )
@app.get("/api/audio/{audio_id}/info")
async def get_audio_info(audio_id: str):
    """
    Return the cached metadata record for one generated clip.

    The record is returned as stored, including its server-side file "path".

    Raises:
        HTTPException: 404 when ``audio_id`` is not in the cache.
    """
    record = audio_cache.get(audio_id)
    if record is None:
        raise HTTPException(status_code=404, detail="Audio not found")
    return record
@app.get("/api/audio")
async def list_audio():
    """
    Summarize every clip in the in-memory cache.

    Each entry carries the id, a text preview (first 50 characters plus
    "..." when longer), the duration and the generation timestamp.
    """
    def _summarize(clip_id, record):
        full_text = record["text"]
        preview = full_text if len(full_text) <= 50 else full_text[:50] + "..."
        return {
            "audio_id": clip_id,
            "text": preview,
            "duration": record["duration"],
            "generated_at": record["generated_at"],
        }

    summaries = [_summarize(clip_id, record) for clip_id, record in audio_cache.items()]
    return {"audio_files": summaries, "total": len(audio_cache)}
# Gradio interface
def create_gradio_interface():
    """Create simple Gradio interface.

    Builds (but does not launch) a ``gr.Blocks`` app with an input column
    (text, reference-audio URL, generation controls), an output column (audio
    player and status box), clickable examples, static API documentation and
    a system-status panel.

    Returns:
        The constructed ``gr.Blocks`` instance.
    """
    with gr.Blocks(title="ChatterboxTTS", theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
# 🎡 ChatterboxTTS
High-quality text-to-speech synthesis with voice cloning capabilities.
""")
        # Components are placed in creation order inside the enclosing
        # Row/Column/Accordion context, so statement order here is the layout.
        with gr.Row():
            # Left column: inputs and generation controls.
            with gr.Column():
                text_input = gr.Textbox(
                    value="Hello, this is ChatterboxTTS. I can generate natural-sounding speech from any text you provide.",
                    label="Text to synthesize (max 300 characters)",
                    max_lines=5,
                    placeholder="Enter your text here..."
                )
                audio_prompt = gr.Textbox(
                    value="https://storage.googleapis.com/chatterbox-demo-samples/prompts/female_shadowheart4.flac",
                    label="Reference Audio URL",
                    placeholder="URL to reference audio file"
                )
                with gr.Row():
                    exaggeration = gr.Slider(
                        0.25, 2,
                        step=0.05,
                        label="Exaggeration",
                        value=0.5,
                        info="Controls expressiveness (0.5 = neutral)"
                    )
                    cfg_weight = gr.Slider(
                        0.2, 1,
                        step=0.05,
                        label="CFG Weight",
                        value=0.5,
                        info="Controls pace and clarity"
                    )
                with gr.Accordion("Advanced Settings", open=False):
                    temperature = gr.Slider(
                        0.05, 5,
                        step=0.05,
                        label="Temperature",
                        value=0.8,
                        info="Controls randomness"
                    )
                    seed = gr.Number(
                        value=0,
                        label="Seed (0 = random)",
                        info="Set to non-zero for reproducible results"
                    )
                generate_btn = gr.Button("🎡 Generate Speech", variant="primary")
            # Right column: generated audio and status text.
            with gr.Column():
                audio_output = gr.Audio(label="Generated Speech")
                status_text = gr.Textbox(
                    label="Status",
                    interactive=False,
                    placeholder="Click 'Generate Speech' to start..."
                )
        # Examples
        # Each example row follows the `inputs` order: text, reference URL,
        # exaggeration, temperature, seed, CFG weight.
        gr.Examples(
            examples=[
                [
                    "Welcome to our podcast! Today we're discussing the latest developments in artificial intelligence.",
                    "https://storage.googleapis.com/chatterbox-demo-samples/prompts/female_shadowheart4.flac",
                    0.6, 0.8, 0, 0.5
                ],
                [
                    "Good morning! I hope you're having a wonderful day. Let me tell you about our exciting new features.",
                    "https://storage.googleapis.com/chatterbox-demo-samples/prompts/female_shadowheart4.flac",
                    0.7, 0.9, 0, 0.6
                ],
                [
                    "In today's tutorial, we'll learn how to build a machine learning model from scratch using Python.",
                    "https://storage.googleapis.com/chatterbox-demo-samples/prompts/female_shadowheart4.flac",
                    0.4, 0.7, 0, 0.4
                ]
            ],
            inputs=[text_input, audio_prompt, exaggeration, temperature, seed, cfg_weight]
        )

        def generate_speech_ui(text, prompt_url, exag, temp, seed_val, cfg):
            """Generate speech from UI.

            Returns:
                ``((sample_rate, samples), status_message)`` on success, or
                ``(None, error_message)`` on a validation or generation
                failure. Errors are shown in the status box, never raised.
            """
            try:
                if not text.strip():
                    return None, "❌ Please enter some text"
                # Mirrors the 300-character cap generate_tts_audio applies.
                if len(text) > 300:
                    return None, "❌ Text too long (max 300 characters)"
                start_time = time.time()
                # Generate audio
                # NOTE(review): int(seed_val) raises if the Number field yields
                # None (presumably when cleared) — confirm; the except below
                # would then report it as a generation failure.
                sample_rate, audio_data = generate_tts_audio(
                    text, prompt_url, exag, temp, int(seed_val), cfg
                )
                generation_time = time.time() - start_time
                duration = len(audio_data) / sample_rate
                status = f"""βœ… Speech generated successfully!
⏱️ Generation time: {generation_time:.2f}s
🎡 Audio duration: {duration:.2f}s
πŸ“Š Sample rate: {sample_rate} Hz
πŸ”Š Audio samples: {len(audio_data):,}
"""
                return (sample_rate, audio_data), status
            except Exception as e:
                logger.error(f"UI generation failed: {e}")
                return None, f"❌ Generation failed: {str(e)}"

        generate_btn.click(
            fn=generate_speech_ui,
            inputs=[text_input, audio_prompt, exaggeration, temperature, seed, cfg_weight],
            outputs=[audio_output, status_text]
        )
        # API Documentation
        # NOTE(review): these routes are served by the FastAPI `app`, which the
        # __main__ block starts only in local mode, not when SPACE_ID is set.
        gr.Markdown("""
## πŸ”Œ API Endpoints
### POST `/api/tts/synthesize`
Generate speech from text
```json
{
"text": "Your text here",
"audio_prompt_url": "URL to reference audio",
"exaggeration": 0.5,
"temperature": 0.8,
"cfg_weight": 0.5,
"seed": 0
}
```
### GET `/api/audio/{audio_id}`
Download generated audio file
### GET `/api/audio`
List all generated audio files
### GET `/health`
Check service health
""")
        # System info
        # NOTE(review): evaluated once when the UI is built, so this panel is a
        # snapshot. "Generated Files" also counts only API-generated clips,
        # because generate_speech_ui never writes to audio_cache.
        model_status = "βœ… Loaded" if MODEL else "❌ Not Loaded"
        gr.Markdown(f"""
### πŸ“Š System Status
- **Model**: {model_status}
- **Device**: {DEVICE}
- **Generated Files**: {len(audio_cache)}
""")
    return demo
# Main execution
if __name__ == "__main__":
    logger.info("πŸŽ‰ Starting ChatterboxTTS Service...")
    # Report the outcome of the import-time model load.
    model_status = "βœ… Loaded" if MODEL else "❌ Not Loaded"
    logger.info(f"Model Status: {model_status}")
    logger.info(f"Device: {DEVICE}")

    on_spaces = bool(os.getenv("SPACE_ID"))
    if not on_spaces:
        # Local development: serve the FastAPI routes on :8000 from a
        # background thread, then run the Gradio UI on the main thread.
        import threading
        import uvicorn

        # daemon=True: this thread alone will not keep the process alive.
        api_thread = threading.Thread(
            target=uvicorn.run,
            args=(app,),
            kwargs={"host": "0.0.0.0", "port": 8000, "log_level": "info"},
            daemon=True,
        )
        api_thread.start()
        logger.info("🌐 FastAPI: http://localhost:8000")
        logger.info("πŸ“š API Docs: http://localhost:8000/docs")

        demo = create_gradio_interface()
        # NOTE(review): share=True asks Gradio for a public share link, which
        # exposes this dev server beyond localhost — confirm that is intended.
        demo.launch(share=True)
    else:
        # Hugging Face Spaces: only the Gradio UI is launched.
        # NOTE(review): the FastAPI routes on `app` are not served in this
        # branch, even though the UI documents them.
        logger.info("🏠 Running in Hugging Face Spaces")
        demo = create_gradio_interface()
        demo.launch(
            server_name="0.0.0.0",
            server_port=7860,
            show_error=True,
        )