Chatterbox_AP / app.py
ceymox's picture
Update app.py
e36284f verified
raw
history blame
18.1 kB
import os
import time
import torch
import random
import numpy as np
import soundfile as sf
import tempfile
import uuid
import logging
import requests
import io
from typing import Optional, Dict, Any
from pathlib import Path
import gradio as gr
import spaces
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
# ChatterboxTTS import - the real package must be installed separately.
# If it cannot be imported, a lightweight mock exposing the same interface is
# defined instead so the UI and API can still be exercised end to end.
try:
    try:
        # Layout when the chatterbox repository is vendored next to this app.
        from chatterbox.src.chatterbox.tts import ChatterboxTTS
    except ImportError:
        # Layout of the published `chatterbox-tts` pip package.
        from chatterbox.tts import ChatterboxTTS
    CHATTERBOX_AVAILABLE = True
except ImportError:
    CHATTERBOX_AVAILABLE = False
    print("⚠️ ChatterboxTTS not found. Using mock implementation.")
    print("πŸ“¦ Install ChatterboxTTS: pip install chatterbox-tts")

    class ChatterboxTTS:
        """Stand-in for ChatterboxTTS that synthesizes a decaying harmonic tone.

        Mirrors the subset of the real interface this app relies on:
        ``from_pretrained``, ``to``, the ``sr`` sample-rate attribute and
        ``generate``.
        """

        def __init__(self, device="cpu"):
            self.device = device
            self.sr = 24000

        @classmethod
        def from_pretrained(cls, device):
            return cls(device)

        def to(self, device):
            self.device = device
            return self

        def generate(self, text, audio_prompt_path=None, exaggeration=0.5,
                     temperature=0.8, cfg_weight=0.5):
            """Return mock audio as a float32 tensor of shape (1, num_samples).

            Duration is 0.1 s per character (capped at 10 s), so empty text
            yields zero samples. ``exaggeration`` scales the amplitude and a
            ``temperature`` above 1.0 adds Gaussian noise. ``audio_prompt_path``
            and ``cfg_weight`` are accepted for interface compatibility only.
            Samples are peak-normalized into [-1, 1] so they can be written to
            a PCM WAV file without clipping.
            """
            duration = min(len(text) * 0.1, 10.0)
            t = np.linspace(0, duration, int(self.sr * duration))
            words = len(text.split())
            freq_base = 150 + (words % 50) * 5  # Vary pitch by content
            audio = np.zeros_like(t)
            if t.size:
                # Guarded: for empty text duration is 0 and the envelope
                # would divide by zero.
                envelope = np.exp(-t / (duration * 0.7))
                for i in range(3):  # Fundamental plus two harmonics
                    freq = freq_base * (i + 1)
                    audio += 0.2 * np.sin(2 * np.pi * freq * t + i) * envelope
            # Add some variation based on parameters
            audio *= (0.5 + exaggeration)
            if temperature > 1.0:
                audio += np.random.normal(0, 0.05, len(audio))
            # Up to 3 * 0.2 * (0.5 + exaggeration) plus noise can exceed 1.0.
            peak = float(np.abs(audio).max()) if audio.size else 0.0
            if peak > 1.0:
                audio /= peak
            # float32 keeps the mock consistent with typical model output.
            return torch.tensor(audio, dtype=torch.float32).unsqueeze(0)
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Device configuration: use the GPU whenever PyTorch reports one.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
logger.info(f"πŸš€ Running on device: {DEVICE}")

# Global model handle; populated by get_or_load_model().
MODEL = None

# Directory that receives generated WAV files (created eagerly at import).
AUDIO_DIR = "generated_audio"
os.makedirs(AUDIO_DIR, exist_ok=True)

# audio_id -> metadata dict for every clip produced by the API.
audio_cache = dict()
def get_or_load_model():
    """Return the global ChatterboxTTS model, loading it on first use.

    The module-level ``MODEL`` is only assigned after the model has been
    created *and* moved to ``DEVICE``. A failure part-way through therefore
    leaves ``MODEL`` as ``None``, and the next call retries the load instead
    of returning a half-initialized model.

    Raises:
        Exception: whatever ``ChatterboxTTS.from_pretrained`` or ``.to``
            raised; it is logged with its traceback and re-raised.
    """
    global MODEL
    if MODEL is not None:
        return MODEL
    logger.info("Loading ChatterboxTTS model...")
    try:
        model = ChatterboxTTS.from_pretrained(DEVICE)
        # Not every implementation exposes .to(); move it when possible.
        if hasattr(model, 'to'):
            model.to(DEVICE)
    except Exception as e:
        logger.exception(f"❌ Error loading model: {e}")
        raise
    MODEL = model
    logger.info("βœ… ChatterboxTTS model loaded successfully")
    return MODEL
def set_seed(seed: int):
    """Seed Python, NumPy and PyTorch RNGs with *seed* for reproducible output.

    The CUDA generators are seeded as well when ``DEVICE`` is "cuda".
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if DEVICE != "cuda":
        return
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
def generate_id():
    """Return a fresh random UUID4 rendered as its canonical 36-char string."""
    return f"{uuid.uuid4()}"
# Pydantic models for API
class TTSRequest(BaseModel):
    """Request body for ``POST /api/tts/synthesize``.

    The numeric tuning fields are plain (non-Optional) types with defaults.
    Omitting a field uses its default. An explicit JSON ``null`` is rejected
    with a validation error instead of reaching the model and failing there.
    """
    # Text to synthesize; emptiness and length are checked by the endpoint.
    text: str
    # Reference voice, forwarded to generate_tts_audio as the audio prompt.
    audio_prompt_url: Optional[str] = "https://storage.googleapis.com/chatterbox-demo-samples/prompts/female_shadowheart4.flac"
    exaggeration: float = 0.5
    temperature: float = 0.8
    cfg_weight: float = 0.5
    # 0 means "do not reseed"; any other value reseeds all RNGs.
    seed: int = 0


class TTSResponse(BaseModel):
    """Response body for ``POST /api/tts/synthesize``."""
    success: bool
    # ID to use with GET /api/audio/{audio_id}.
    audio_id: Optional[str] = None
    message: str
    sample_rate: Optional[int] = None
    # Clip length in seconds (number of samples / sample_rate).
    duration: Optional[float] = None
# Load the model eagerly at import time. A failure is logged instead of
# raised and MODEL is reset to None, so the endpoints report the model as
# not loaded rather than the app failing to start.
try:
    if not CHATTERBOX_AVAILABLE:
        MODEL = ChatterboxTTS.from_pretrained(DEVICE)
        print("⚠️ Using mock ChatterboxTTS implementation")
    else:
        get_or_load_model()
        print("βœ… ChatterboxTTS model loaded successfully")
except Exception as e:
    logger.error(f"Failed to load model on startup: {e}")
    MODEL = None
# Largest reference-audio download accepted as a voice prompt, in bytes.
_MAX_PROMPT_BYTES = 20 * 1024 * 1024


def _is_http_url(value) -> bool:
    """Return True if *value* is a string that starts with http:// or https://."""
    return isinstance(value, str) and value.strip().lower().startswith(("http://", "https://"))


def _download_audio_prompt(url: str) -> str:
    """Download *url* to a new temporary file and return that file's path.

    The caller owns the returned file and must delete it. The partial file is
    removed if anything fails. Raises ``ValueError`` when the body is larger
    than ``_MAX_PROMPT_BYTES``; HTTP/network errors from ``requests``
    propagate unchanged.
    """
    from urllib.parse import urlparse

    url = url.strip()
    # Keep the source extension (e.g. ".flac") in case the audio loader
    # relies on it; fall back to ".wav".
    suffix = Path(urlparse(url).path).suffix or ".wav"
    fd, path = tempfile.mkstemp(suffix=suffix)
    try:
        with os.fdopen(fd, "wb") as out, requests.get(url, stream=True, timeout=30) as resp:
            resp.raise_for_status()
            received = 0
            for chunk in resp.iter_content(chunk_size=64 * 1024):
                received += len(chunk)
                if received > _MAX_PROMPT_BYTES:
                    raise ValueError(
                        f"Reference audio is larger than {_MAX_PROMPT_BYTES} bytes"
                    )
                out.write(chunk)
    except BaseException:
        os.unlink(path)
        raise
    return path


@spaces.GPU
def generate_tts_audio(
    text_input: str,
    audio_prompt_path_input: str,
    exaggeration_input: float,
    temperature_input: float,
    seed_num_input: int,
    cfgw_input: float
) -> tuple[int, np.ndarray]:
    """Synthesize speech for *text_input* with the global ChatterboxTTS model.

    Args:
        text_input: Text to speak; only the first 300 characters are used
            (a warning is logged when text is cut).
        audio_prompt_path_input: Reference voice. With the real package, an
            http(s) URL is first downloaded to a temporary file that is
            deleted afterwards. Any other value (local path, empty string,
            None) is passed to the model unchanged. The mock ignores prompts,
            so no download happens in that case.
        exaggeration_input: Forwarded as ``exaggeration``.
        temperature_input: Forwarded as ``temperature``.
        seed_num_input: Non-zero values reseed all RNGs via ``set_seed``;
            0 leaves RNG state untouched.
        cfgw_input: Forwarded as ``cfg_weight``.

    Returns:
        ``(sample_rate, samples)``, where ``samples`` is the model output with
        its leading dimension squeezed away, as a NumPy array.

    Raises:
        RuntimeError: if no model is available.
        Exception: download or generation failures are logged and re-raised.
    """
    current_model = get_or_load_model()
    if current_model is None:
        raise RuntimeError("TTS model is not loaded")
    if seed_num_input != 0:
        set_seed(int(seed_num_input))
    if len(text_input) > 300:
        logger.warning(f"Text truncated from {len(text_input)} to 300 characters")
    logger.info(f"🎡 Generating audio for: '{text_input[:50]}...'")
    downloaded_prompt = None
    try:
        prompt_path = audio_prompt_path_input
        if CHATTERBOX_AVAILABLE and _is_http_url(prompt_path):
            downloaded_prompt = _download_audio_prompt(prompt_path)
            prompt_path = downloaded_prompt
        wav = current_model.generate(
            text_input[:300],  # Limit text length
            audio_prompt_path=prompt_path,
            exaggeration=exaggeration_input,
            temperature=temperature_input,
            cfg_weight=cfgw_input,
        )
        logger.info("βœ… Audio generation complete")
        # detach()/cpu() make .numpy() safe for CUDA or grad-tracking tensors.
        return (current_model.sr, wav.squeeze(0).detach().cpu().numpy())
    except Exception as e:
        logger.exception(f"❌ Audio generation failed: {e}")
        raise
    finally:
        if downloaded_prompt is not None:
            try:
                os.remove(downloaded_prompt)
            except OSError:
                logger.warning(f"Could not delete temporary prompt file {downloaded_prompt}")
# FastAPI app for API endpoints
app = FastAPI(
    title="ChatterboxTTS API",
    description="High-quality text-to-speech synthesis using ChatterboxTTS",
    version="1.0.0"
)
# Allow browser clients from any origin. Credentials stay disabled:
# - none of the endpoints in this file read cookies or auth headers;
# - FastAPI's CORS docs note that wildcard origins cannot be combined with
#   credentialed requests.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.get("/")
async def root():
    """API status endpoint: service/version, model and device state, routes."""
    # Compare with None (like "model_loaded") instead of relying on the
    # model object's truthiness.
    model_loaded = MODEL is not None
    return {
        "service": "ChatterboxTTS API",
        "version": "1.0.0",
        "status": "operational" if model_loaded else "model_loading",
        "model_loaded": model_loaded,
        "device": DEVICE,
        "endpoints": {
            "synthesize": "/api/tts/synthesize",
            "audio": "/api/audio/{audio_id}",
            "health": "/health",
            # Routes defined below but previously missing from this map.
            "audio_info": "/api/audio/{audio_id}/info",
            "list": "/api/audio"
        }
    }
@app.get("/health")
async def health_check():
    """Health check endpoint: "healthy" once a model is loaded, else "unhealthy"."""
    # Same None check as the "model_loaded" field, rather than truthiness.
    model_loaded = MODEL is not None
    return {
        "status": "healthy" if model_loaded else "unhealthy",
        "model_loaded": model_loaded,
        "device": DEVICE,
        "timestamp": time.time()
    }
@app.post("/api/tts/synthesize", response_model=TTSResponse)
async def synthesize_speech(request: TTSRequest):
    """Synthesize speech for ``request.text`` and store it as a WAV file.

    The clip is written to ``AUDIO_DIR`` as ``<audio_id>.wav`` and its
    metadata is recorded in ``audio_cache`` under a fresh UUID. That UUID is
    returned as ``audio_id`` for use with ``GET /api/audio/{audio_id}``.

    Raises:
        HTTPException: 503 if the model is not loaded, 400 for empty text or
            text longer than 300 characters, 500 if synthesis fails.
    """
    try:
        if MODEL is None:
            raise HTTPException(status_code=503, detail="Model not loaded")
        if not request.text.strip():
            raise HTTPException(status_code=400, detail="Text cannot be empty")
        # generate_tts_audio only uses the first 300 characters. Reject longer
        # input instead of silently returning audio for truncated text; this
        # matches the limit enforced by the Gradio UI.
        if len(request.text) > 300:
            raise HTTPException(status_code=400, detail="Text too long (max 300 characters)")
        start_time = time.time()
        # NOTE: generation is synchronous inside this async handler, so it
        # blocks the event loop until it finishes.
        sample_rate, audio_data = generate_tts_audio(
            request.text,
            request.audio_prompt_url,
            request.exaggeration,
            request.temperature,
            request.seed,
            request.cfg_weight
        )
        generation_time = time.time() - start_time
        duration = len(audio_data) / sample_rate
        # Save audio file
        audio_id = generate_id()
        audio_path = os.path.join(AUDIO_DIR, f"{audio_id}.wav")
        sf.write(audio_path, audio_data, sample_rate)
        # Cache audio info
        audio_cache[audio_id] = {
            "path": audio_path,
            "text": request.text,
            "sample_rate": sample_rate,
            "duration": duration,
            "generated_at": time.time(),
            "generation_time": generation_time
        }
        logger.info(f"βœ… Audio saved: {audio_id} ({generation_time:.2f}s)")
        return TTSResponse(
            success=True,
            audio_id=audio_id,
            message="Speech synthesized successfully",
            sample_rate=sample_rate,
            duration=duration
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"❌ Synthesis failed: {e}")
        raise HTTPException(status_code=500, detail=f"Synthesis failed: {str(e)}") from e
@app.get("/api/audio/{audio_id}")
async def get_audio(audio_id: str):
    """Download a previously generated WAV file as an attachment.

    Raises:
        HTTPException: 404 if ``audio_id`` is unknown or its file is no
            longer on disk.
    """
    audio_info = audio_cache.get(audio_id)
    if audio_info is None:
        raise HTTPException(status_code=404, detail="Audio not found")
    audio_path = audio_info["path"]
    if not os.path.exists(audio_path):
        raise HTTPException(status_code=404, detail="Audio file not found on disk")

    def iterfile(chunk_size: int = 64 * 1024):
        # Fixed-size reads: iterating a binary file object splits on b"\n",
        # which yields arbitrarily sized chunks for audio data.
        with open(audio_path, "rb") as f:
            while chunk := f.read(chunk_size):
                yield chunk

    return StreamingResponse(
        iterfile(),
        media_type="audio/wav",
        headers={
            "Content-Disposition": f"attachment; filename=tts_{audio_id}.wav"
        }
    )
@app.get("/api/audio/{audio_id}/info")
async def get_audio_info(audio_id: str):
    """Return the metadata dict recorded for ``audio_id`` at generation time.

    Raises:
        HTTPException: 404 if ``audio_id`` is not in ``audio_cache``.
    """
    info = audio_cache.get(audio_id)
    if info is None:
        raise HTTPException(status_code=404, detail="Audio not found")
    return info
@app.get("/api/audio")
async def list_audio():
    """Summarize every cached clip: id, text preview, duration and timestamp.

    Text longer than 50 characters is shortened to its first 50 characters
    followed by "...".
    """
    summaries = []
    for audio_id, info in audio_cache.items():
        full_text = info["text"]
        if len(full_text) > 50:
            preview = full_text[:50] + "..."
        else:
            preview = full_text
        summaries.append({
            "audio_id": audio_id,
            "text": preview,
            "duration": info["duration"],
            "generated_at": info["generated_at"],
        })
    return {"audio_files": summaries, "total": len(audio_cache)}
# Gradio interface
def create_gradio_interface():
    """Create simple Gradio interface

    Builds (but does not launch) a ``gr.Blocks`` app containing:
    - an input column: text, reference-audio URL, exaggeration and CFG
      sliders, and an "Advanced Settings" accordion with temperature and seed;
    - an output column: generated audio and a status box;
    - clickable examples, an API endpoint reference and a system-status panel.

    Returns:
        The constructed ``gr.Blocks`` instance.
    """
    with gr.Blocks(title="ChatterboxTTS", theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
# 🎡 ChatterboxTTS
High-quality text-to-speech synthesis with voice cloning capabilities.
""")
        with gr.Row():
            # Left column: inputs and the generate button.
            with gr.Column():
                # The 300-character limit matches the check in
                # generate_speech_ui and the [:300] slice in generate_tts_audio.
                text_input = gr.Textbox(
                    value="Hello, this is ChatterboxTTS. I can generate natural-sounding speech from any text you provide.",
                    label="Text to synthesize (max 300 characters)",
                    max_lines=5,
                    placeholder="Enter your text here..."
                )
                audio_prompt = gr.Textbox(
                    value="https://storage.googleapis.com/chatterbox-demo-samples/prompts/female_shadowheart4.flac",
                    label="Reference Audio URL",
                    placeholder="URL to reference audio file"
                )
                with gr.Row():
                    exaggeration = gr.Slider(
                        0.25, 2,
                        step=0.05,
                        label="Exaggeration",
                        value=0.5,
                        info="Controls expressiveness (0.5 = neutral)"
                    )
                    cfg_weight = gr.Slider(
                        0.2, 1,
                        step=0.05,
                        label="CFG Weight",
                        value=0.5,
                        info="Controls pace and clarity"
                    )
                with gr.Accordion("Advanced Settings", open=False):
                    temperature = gr.Slider(
                        0.05, 5,
                        step=0.05,
                        label="Temperature",
                        value=0.8,
                        info="Controls randomness"
                    )
                    # 0 makes generate_tts_audio skip set_seed(), leaving RNG
                    # state as-is.
                    seed = gr.Number(
                        value=0,
                        label="Seed (0 = random)",
                        info="Set to non-zero for reproducible results"
                    )
                generate_btn = gr.Button("🎡 Generate Speech", variant="primary")
            # Right column: outputs.
            with gr.Column():
                audio_output = gr.Audio(label="Generated Speech")
                status_text = gr.Textbox(
                    label="Status",
                    interactive=False,
                    placeholder="Click 'Generate Speech' to start..."
                )
        # Examples
        # Each example row follows the `inputs` order:
        # [text, prompt URL, exaggeration, temperature, seed, cfg_weight].
        gr.Examples(
            examples=[
                [
                    "Welcome to our podcast! Today we're discussing the latest developments in artificial intelligence.",
                    "https://storage.googleapis.com/chatterbox-demo-samples/prompts/female_shadowheart4.flac",
                    0.6, 0.8, 0, 0.5
                ],
                [
                    "Good morning! I hope you're having a wonderful day. Let me tell you about our exciting new features.",
                    "https://storage.googleapis.com/chatterbox-demo-samples/prompts/female_shadowheart4.flac",
                    0.7, 0.9, 0, 0.6
                ],
                [
                    "In today's tutorial, we'll learn how to build a machine learning model from scratch using Python.",
                    "https://storage.googleapis.com/chatterbox-demo-samples/prompts/female_shadowheart4.flac",
                    0.4, 0.7, 0, 0.4
                ]
            ],
            inputs=[text_input, audio_prompt, exaggeration, temperature, seed, cfg_weight]
        )

        def generate_speech_ui(text, prompt_url, exag, temp, seed_val, cfg):
            """Generate speech from UI

            Returns ``((sample_rate, samples), status_message)`` on success.
            Returns ``(None, error_message)`` for empty or over-long text, or
            when generation raises; such exceptions are logged, not
            propagated.
            """
            try:
                if not text.strip():
                    return None, "❌ Please enter some text"
                if len(text) > 300:
                    return None, "❌ Text too long (max 300 characters)"
                start_time = time.time()
                # Generate audio
                sample_rate, audio_data = generate_tts_audio(
                    text, prompt_url, exag, temp, int(seed_val), cfg
                )
                generation_time = time.time() - start_time
                duration = len(audio_data) / sample_rate
                status = f"""βœ… Speech generated successfully!
⏱️ Generation time: {generation_time:.2f}s
🎡 Audio duration: {duration:.2f}s
πŸ“Š Sample rate: {sample_rate} Hz
πŸ”Š Audio samples: {len(audio_data):,}
"""
                return (sample_rate, audio_data), status
            except Exception as e:
                logger.error(f"UI generation failed: {e}")
                return None, f"❌ Generation failed: {str(e)}"

        generate_btn.click(
            fn=generate_speech_ui,
            inputs=[text_input, audio_prompt, exaggeration, temperature, seed, cfg_weight],
            outputs=[audio_output, status_text]
        )
        # API Documentation
        gr.Markdown("""
## πŸ”Œ API Endpoints
### POST `/api/tts/synthesize`
Generate speech from text
```json
{
"text": "Your text here",
"audio_prompt_url": "URL to reference audio",
"exaggeration": 0.5,
"temperature": 0.8,
"cfg_weight": 0.5,
"seed": 0
}
```
### GET `/api/audio/{audio_id}`
Download generated audio file
### GET `/api/audio`
List all generated audio files
### GET `/health`
Check service health
""")
        # System info
        # Evaluated once while the UI is being built, so this panel (including
        # the "Generated Files" count) is a snapshot and does not refresh.
        model_status = "βœ… Real ChatterboxTTS" if CHATTERBOX_AVAILABLE and MODEL else "⚠️ Mock Implementation" if MODEL else "❌ Not Loaded"
        gr.Markdown(f"""
### πŸ“Š System Status
- **Model**: {model_status}
- **Device**: {DEVICE}
- **Generated Files**: {len(audio_cache)}
- **ChatterboxTTS Available**: {CHATTERBOX_AVAILABLE}
{"" if CHATTERBOX_AVAILABLE else "**Note**: Install ChatterboxTTS for production use: `pip install chatterbox-tts`"}
""")
    return demo
# Main execution
if __name__ == "__main__":
    logger.info("πŸŽ‰ Starting ChatterboxTTS Service...")

    # Summarize which model, if any, ended up loaded.
    if not MODEL:
        model_status = "❌ No Model Loaded"
    elif CHATTERBOX_AVAILABLE:
        model_status = "βœ… Real ChatterboxTTS Loaded"
    else:
        model_status = "⚠️ Mock ChatterboxTTS (Install real package for production)"

    for startup_line in (
        f"Model Status: {model_status}",
        f"Device: {DEVICE}",
        f"ChatterboxTTS Available: {CHATTERBOX_AVAILABLE}",
    ):
        logger.info(startup_line)

    on_spaces = bool(os.getenv("SPACE_ID"))
    if on_spaces:
        # Hugging Face Spaces: serve only the Gradio UI on the standard port.
        logger.info("🏠 Running in Hugging Face Spaces")
        demo = create_gradio_interface()
        demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)
    else:
        # Local development: FastAPI on :8000 in a daemon thread, Gradio in
        # the foreground.
        import uvicorn
        import threading

        api_thread = threading.Thread(
            target=uvicorn.run,
            args=(app,),
            kwargs={"host": "0.0.0.0", "port": 8000, "log_level": "info"},
            daemon=True,
        )
        api_thread.start()
        logger.info("🌐 FastAPI: http://localhost:8000")
        logger.info("πŸ“š API Docs: http://localhost:8000/docs")

        demo = create_gradio_interface()
        demo.launch(share=True)