"""ChatterboxTTS service: a FastAPI JSON API plus a Gradio UI around one TTS model.

If the real ChatterboxTTS package cannot be imported, a small mock model that
synthesises a harmonic tone is used instead so the service still starts.
"""

import contextlib
import io
import logging
import os
import random
import tempfile
import threading
import time
import uuid
from pathlib import Path
from typing import Any, Dict, Optional
from urllib.parse import urlparse

import gradio as gr
import numpy as np
import requests
import soundfile as sf
import spaces
import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, StreamingResponse
from pydantic import BaseModel

# ChatterboxTTS must be installed separately. The pip package (`chatterbox-tts`)
# exposes `chatterbox.tts`; a vendored source checkout exposes
# `chatterbox.src.chatterbox.tts`. Try both before falling back to the mock.
try:
    from chatterbox.tts import ChatterboxTTS

    CHATTERBOX_AVAILABLE = True
except ImportError:
    try:
        from chatterbox.src.chatterbox.tts import ChatterboxTTS

        CHATTERBOX_AVAILABLE = True
    except ImportError:
        CHATTERBOX_AVAILABLE = False
        print("⚠️ ChatterboxTTS not found. Using mock implementation.")
        print("📦 Install ChatterboxTTS: pip install chatterbox-tts")

        class ChatterboxTTS:
            """Stand-in for ChatterboxTTS that emits a decaying harmonic tone.

            Mirrors the subset of the real API used by this service:
            ``from_pretrained(device)``, ``to(device)``, ``sr`` and
            ``generate(...)`` returning a ``(1, n_samples)`` tensor.
            """

            def __init__(self, device="cpu"):
                self.device = device
                self.sr = 24000

            @classmethod
            def from_pretrained(cls, device):
                return cls(device)

            def to(self, device):
                self.device = device
                return self

            def generate(self, text, audio_prompt_path=None, exaggeration=0.5,
                         temperature=0.8, cfg_weight=0.5):
                """Return mock audio shaped ``(1, n_samples)``.

                ``audio_prompt_path`` and ``cfg_weight`` are accepted for API
                compatibility but ignored.
                """
                # 0.1 s per character, capped at 10 s.
                duration = min(len(text) * 0.1, 10.0)
                t = np.linspace(0, duration, int(self.sr * duration))

                # Vary the base pitch by word count so different texts sound different.
                words = len(text.split())
                freq_base = 150 + (words % 50) * 5

                # The decay envelope is identical for every harmonic; compute it once.
                envelope = np.exp(-t / (duration * 0.7))
                audio = np.zeros_like(t)
                for i in range(3):  # fundamental + two harmonics
                    freq = freq_base * (i + 1)
                    audio += 0.2 * np.sin(2 * np.pi * freq * t + i) * envelope

                audio *= 0.5 + exaggeration
                if temperature > 1.0:
                    audio += np.random.normal(0, 0.05, len(audio))
                return torch.tensor(audio).unsqueeze(0)


# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Device configuration
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"🚀 Running on device: {DEVICE}")

# Request defaults, shared by the API model, the generator and the UI.
DEFAULT_AUDIO_PROMPT_URL = (
    "https://storage.googleapis.com/chatterbox-demo-samples/prompts/female_shadowheart4.flac"
)
DEFAULT_EXAGGERATION = 0.5
DEFAULT_TEMPERATURE = 0.8
DEFAULT_CFG_WEIGHT = 0.5
DEFAULT_SEED = 0

# Single text-length limit for every entry point. The generator used to silently
# truncate to 300 characters while the API accepted 500.
MAX_TEXT_LENGTH = 300

# Reference-audio download limits (the URL may come from an API client).
PROMPT_DOWNLOAD_TIMEOUT_S = 30
MAX_PROMPT_BYTES = 20 * 1024 * 1024

# Generated files kept on disk/in memory before the oldest are evicted.
MAX_CACHED_AUDIO = 1000

# Global model variable
MODEL = None

# Serialises model loading + seeding + generation. set_seed() mutates global RNG
# state, so a concurrent request could otherwise reseed mid-generation. It also
# keeps concurrent API/UI calls from stacking GPU work.
_GENERATION_LOCK = threading.Lock()

# Storage for generated audio
AUDIO_DIR = "generated_audio"
os.makedirs(AUDIO_DIR, exist_ok=True)
audio_cache = {}
# Guards audio_cache: it is written from worker threads and read by async endpoints.
_CACHE_LOCK = threading.Lock()


class AudioPromptError(Exception):
    """Raised when the reference-audio prompt URL cannot be fetched."""


def get_or_load_model():
    """Load ChatterboxTTS model if not already loaded.

    Returns:
        The global MODEL instance.

    Raises:
        Whatever ``ChatterboxTTS.from_pretrained`` raises (logged, then re-raised).
    """
    global MODEL
    if MODEL is None:
        logger.info("Loading ChatterboxTTS model...")
        try:
            MODEL = ChatterboxTTS.from_pretrained(DEVICE)
            if hasattr(MODEL, "to"):
                MODEL.to(DEVICE)
            logger.info("✅ ChatterboxTTS model loaded successfully")
        except Exception:
            logger.exception("❌ Error loading model")
            raise
    return MODEL


def set_seed(seed: int):
    """Seed torch (incl. CUDA), ``random`` and NumPy for reproducible output."""
    torch.manual_seed(seed)
    if DEVICE == "cuda":
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    random.seed(seed)
    np.random.seed(seed)


def generate_id():
    """Generate unique ID"""
    return str(uuid.uuid4())


@contextlib.contextmanager
def _local_audio_prompt(prompt: Optional[str]):
    """Yield a value usable as the model's ``audio_prompt_path``.

    - ``None`` or blank: yields ``None`` (model's built-in voice).
    - ``http(s)://`` URL with the real model: downloads to a temporary file and
      yields its path. The file is removed when the context exits.
    - Anything else, or any prompt with the mock model (which ignores it):
      yields the stripped string unchanged.

    Raises:
        AudioPromptError: the URL could not be downloaded or exceeds MAX_PROMPT_BYTES.
    """
    if prompt is None or not prompt.strip():
        yield None
        return
    prompt = prompt.strip()
    is_url = prompt.lower().startswith(("http://", "https://"))
    if not is_url or not CHATTERBOX_AVAILABLE:
        # Local path for the real model, or the mock (keeps mock mode offline).
        yield prompt
        return

    # NOTE(security): API clients choose this URL, so the server fetches arbitrary
    # http(s) URLs on their behalf (SSRF). Add an allow-list before exposing publicly.
    suffix = Path(urlparse(prompt).path).suffix or ".wav"
    fd, tmp_path = tempfile.mkstemp(suffix=suffix)
    try:
        try:
            with os.fdopen(fd, "wb") as fh, requests.get(
                prompt, stream=True, timeout=PROMPT_DOWNLOAD_TIMEOUT_S
            ) as response:
                response.raise_for_status()
                received = 0
                for chunk in response.iter_content(chunk_size=64 * 1024):
                    received += len(chunk)
                    if received > MAX_PROMPT_BYTES:
                        raise AudioPromptError(
                            f"Reference audio exceeds {MAX_PROMPT_BYTES} bytes"
                        )
                    fh.write(chunk)
        except requests.RequestException as exc:
            raise AudioPromptError(f"Could not download reference audio: {exc}") from exc
        yield tmp_path
    finally:
        with contextlib.suppress(OSError):
            os.remove(tmp_path)


def _coalesce(value, default):
    """Return ``default`` when ``value`` is None (JSON null), else ``value``."""
    return default if value is None else value


# Pydantic models for API
class TTSRequest(BaseModel):
    text: str
    audio_prompt_url: Optional[str] = DEFAULT_AUDIO_PROMPT_URL
    exaggeration: Optional[float] = DEFAULT_EXAGGERATION
    temperature: Optional[float] = DEFAULT_TEMPERATURE
    cfg_weight: Optional[float] = DEFAULT_CFG_WEIGHT
    seed: Optional[int] = DEFAULT_SEED


class TTSResponse(BaseModel):
    success: bool
    audio_id: Optional[str] = None
    message: str
    sample_rate: Optional[int] = None
    duration: Optional[float] = None


# Load model at startup
try:
    if CHATTERBOX_AVAILABLE:
        get_or_load_model()
        print("✅ ChatterboxTTS model loaded successfully")
    else:
        MODEL = ChatterboxTTS.from_pretrained(DEVICE)
        print("⚠️ Using mock ChatterboxTTS implementation")
except Exception:
    logger.exception("Failed to load model on startup")
    MODEL = None


@spaces.GPU
def generate_tts_audio(
    text_input: str,
    audio_prompt_path_input: str,
    exaggeration_input: float,
    temperature_input: float,
    seed_num_input: int,
    cfgw_input: float
) -> tuple[int, np.ndarray]:
    """Generate TTS audio using the ChatterboxTTS model.

    Args:
        text_input: Text to speak. Only the first MAX_TEXT_LENGTH characters are used.
        audio_prompt_path_input: Local reference-audio path, or empty/None for the
            default voice. URLs must be resolved first (see ``_synthesize``).
        exaggeration_input, temperature_input, cfgw_input: Forwarded to ``generate``.
        seed_num_input: Non-zero reseeds all RNGs; 0/None leaves them untouched.

    Returns:
        ``(sample_rate, samples)`` with ``samples`` a 1-D NumPy array.

    Raises:
        RuntimeError: the model is unavailable. Model errors are logged and re-raised.
    """
    logger.info(f"🎵 Generating audio for: '{text_input[:50]}...'")
    try:
        with _GENERATION_LOCK:
            current_model = get_or_load_model()
            if current_model is None:
                raise RuntimeError("TTS model is not loaded")
            # Seeding and generation must be atomic w.r.t. other requests.
            if seed_num_input:
                set_seed(int(seed_num_input))
            wav = current_model.generate(
                text_input[:MAX_TEXT_LENGTH],
                audio_prompt_path=audio_prompt_path_input or None,
                exaggeration=exaggeration_input,
                temperature=temperature_input,
                cfg_weight=cfgw_input,
            )
    except Exception:
        logger.exception("❌ Audio generation failed")
        raise
    logger.info("✅ Audio generation complete")
    # .cpu() guards against the model returning a CUDA tensor.
    return current_model.sr, wav.squeeze(0).detach().cpu().numpy()


def _synthesize(text, prompt, exaggeration, temperature, seed, cfg_weight):
    """Resolve the reference-audio prompt (downloading URLs), then generate audio.

    The download happens outside the ``@spaces.GPU`` function so it does not
    consume GPU allocation time.
    """
    with _local_audio_prompt(prompt) as prompt_path:
        return generate_tts_audio(text, prompt_path, exaggeration, temperature, seed, cfg_weight)


def _remember_audio(audio_id: str, info: Dict[str, Any]) -> None:
    """Record generated audio, evicting (and deleting) the oldest beyond MAX_CACHED_AUDIO."""
    with _CACHE_LOCK:
        audio_cache[audio_id] = info
        while len(audio_cache) > MAX_CACHED_AUDIO:
            oldest_id = next(iter(audio_cache))  # dicts keep insertion order
            evicted = audio_cache.pop(oldest_id)
            with contextlib.suppress(OSError):
                os.remove(evicted["path"])


# FastAPI app for API endpoints
app = FastAPI(
    title="ChatterboxTTS API",
    description="High-quality text-to-speech synthesis using ChatterboxTTS",
    version="1.0.0"
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/")
async def root():
    """API status endpoint"""
    model_loaded = MODEL is not None
    return {
        "service": "ChatterboxTTS API",
        "version": "1.0.0",
        "status": "operational" if model_loaded else "model_loading",
        "model_loaded": model_loaded,
        "device": DEVICE,
        "endpoints": {
            "synthesize": "/api/tts/synthesize",
            "audio": "/api/audio/{audio_id}",
            "health": "/health"
        }
    }


@app.get("/health")
async def health_check():
    """Health check endpoint"""
    model_loaded = MODEL is not None
    return {
        "status": "healthy" if model_loaded else "unhealthy",
        "model_loaded": model_loaded,
        "device": DEVICE,
        "timestamp": time.time()
    }


@app.post("/api/tts/synthesize", response_model=TTSResponse)
def synthesize_speech(request: TTSRequest):
    """Synthesize speech from text and store it under a new ``audio_id``.

    Declared with plain ``def`` so FastAPI runs this blocking inference in its
    threadpool instead of stalling the event loop.

    Raises:
        HTTPException: 503 if the model is not loaded; 400 for empty/too-long text
            or an unfetchable reference-audio URL; 500 for generation failures.
    """
    if MODEL is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    if not request.text.strip():
        raise HTTPException(status_code=400, detail="Text cannot be empty")
    if len(request.text) > MAX_TEXT_LENGTH:
        raise HTTPException(
            status_code=400, detail=f"Text too long (max {MAX_TEXT_LENGTH} characters)"
        )

    try:
        start_time = time.perf_counter()
        sample_rate, audio_data = _synthesize(
            request.text,
            request.audio_prompt_url,
            _coalesce(request.exaggeration, DEFAULT_EXAGGERATION),
            _coalesce(request.temperature, DEFAULT_TEMPERATURE),
            _coalesce(request.seed, DEFAULT_SEED),
            _coalesce(request.cfg_weight, DEFAULT_CFG_WEIGHT),
        )
        generation_time = time.perf_counter() - start_time

        audio_id = generate_id()
        audio_path = os.path.join(AUDIO_DIR, f"{audio_id}.wav")
        sf.write(audio_path, audio_data, sample_rate)
    except AudioPromptError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    except Exception as exc:
        logger.exception("❌ Synthesis failed")
        raise HTTPException(status_code=500, detail=f"Synthesis failed: {exc}") from exc

    duration = len(audio_data) / sample_rate
    _remember_audio(audio_id, {
        "path": audio_path,
        "text": request.text,
        "sample_rate": sample_rate,
        "duration": duration,
        "generated_at": time.time(),
        "generation_time": generation_time
    })

    logger.info(f"✅ Audio saved: {audio_id} ({generation_time:.2f}s)")

    return TTSResponse(
        success=True,
        audio_id=audio_id,
        message="Speech synthesized successfully",
        sample_rate=sample_rate,
        duration=duration
    )


@app.get("/api/audio/{audio_id}")
async def get_audio(audio_id: str):
    """Download a generated audio file as an attachment.

    ``audio_id`` is only used as a dict key, so it cannot traverse paths.
    """
    audio_info = audio_cache.get(audio_id)
    if audio_info is None:
        raise HTTPException(status_code=404, detail="Audio not found")

    audio_path = audio_info["path"]
    if not os.path.exists(audio_path):
        raise HTTPException(status_code=404, detail="Audio file not found on disk")

    # FileResponse streams in fixed-size chunks; iterating a binary file object
    # yielded arbitrary b"\n"-delimited "lines".
    return FileResponse(
        audio_path,
        media_type="audio/wav",
        filename=f"tts_{audio_id}.wav",
    )


@app.get("/api/audio/{audio_id}/info")
async def get_audio_info(audio_id: str):
    """Return the stored metadata for one generated audio file."""
    audio_info = audio_cache.get(audio_id)
    if audio_info is None:
        raise HTTPException(status_code=404, detail="Audio not found")
    return audio_info


@app.get("/api/audio")
async def list_audio():
    """List all generated audio files (text previews truncated to 50 chars)."""
    with _CACHE_LOCK:
        snapshot = list(audio_cache.items())
    return {
        "audio_files": [
            {
                "audio_id": audio_id,
                "text": info["text"][:50] + "..." if len(info["text"]) > 50 else info["text"],
                "duration": info["duration"],
                "generated_at": info["generated_at"]
            }
            for audio_id, info in snapshot
        ],
        "total": len(snapshot)
    }


def _model_status_label() -> str:
    """Short human-readable model status for the UI."""
    if MODEL is None:
        return "❌ Not Loaded"
    return "✅ Real ChatterboxTTS" if CHATTERBOX_AVAILABLE else "⚠️ Mock Implementation"


# Gradio interface
def create_gradio_interface():
    """Create simple Gradio interface"""
    with gr.Blocks(title="ChatterboxTTS", theme=gr.themes.Soft()) as demo:
        gr.Markdown(
            """
# 🎵 ChatterboxTTS

High-quality text-to-speech synthesis with voice cloning capabilities.
"""
        )

        with gr.Row():
            with gr.Column():
                text_input = gr.Textbox(
                    value="Hello, this is ChatterboxTTS. I can generate natural-sounding speech from any text you provide.",
                    label=f"Text to synthesize (max {MAX_TEXT_LENGTH} characters)",
                    max_lines=5,
                    placeholder="Enter your text here..."
                )

                audio_prompt = gr.Textbox(
                    value=DEFAULT_AUDIO_PROMPT_URL,
                    label="Reference Audio URL",
                    placeholder="URL to reference audio file"
                )

                with gr.Row():
                    exaggeration = gr.Slider(
                        0.25, 2, step=0.05,
                        label="Exaggeration",
                        value=DEFAULT_EXAGGERATION,
                        info="Controls expressiveness (0.5 = neutral)"
                    )
                    cfg_weight = gr.Slider(
                        0.2, 1, step=0.05,
                        label="CFG Weight",
                        value=DEFAULT_CFG_WEIGHT,
                        info="Controls pace and clarity"
                    )

                with gr.Accordion("Advanced Settings", open=False):
                    temperature = gr.Slider(
                        0.05, 5, step=0.05,
                        label="Temperature",
                        value=DEFAULT_TEMPERATURE,
                        info="Controls randomness"
                    )
                    seed = gr.Number(
                        value=DEFAULT_SEED,
                        label="Seed (0 = random)",
                        info="Set to non-zero for reproducible results"
                    )

                generate_btn = gr.Button("🎵 Generate Speech", variant="primary")

            with gr.Column():
                audio_output = gr.Audio(label="Generated Speech")
                status_text = gr.Textbox(
                    label="Status",
                    interactive=False,
                    placeholder="Click 'Generate Speech' to start..."
                )

        # Examples
        gr.Examples(
            examples=[
                [
                    "Welcome to our podcast! Today we're discussing the latest developments in artificial intelligence.",
                    DEFAULT_AUDIO_PROMPT_URL,
                    0.6, 0.8, 0, 0.5
                ],
                [
                    "Good morning! I hope you're having a wonderful day. Let me tell you about our exciting new features.",
                    DEFAULT_AUDIO_PROMPT_URL,
                    0.7, 0.9, 0, 0.6
                ],
                [
                    "In today's tutorial, we'll learn how to build a machine learning model from scratch using Python.",
                    DEFAULT_AUDIO_PROMPT_URL,
                    0.4, 0.7, 0, 0.4
                ]
            ],
            inputs=[text_input, audio_prompt, exaggeration, temperature, seed, cfg_weight]
        )

        def generate_speech_ui(text, prompt_url, exag, temp, seed_val, cfg):
            """Generate speech from UI; returns ``(audio or None, status message)``."""
            text = text or ""
            if not text.strip():
                return None, "❌ Please enter some text"
            if len(text) > MAX_TEXT_LENGTH:
                return None, f"❌ Text too long (max {MAX_TEXT_LENGTH} characters)"

            try:
                start_time = time.perf_counter()
                # A cleared gr.Number yields None; treat it as "random" (0).
                sample_rate, audio_data = _synthesize(
                    text, prompt_url, exag, temp, int(seed_val or 0), cfg
                )
                generation_time = time.perf_counter() - start_time
            except Exception as e:
                logger.exception("UI generation failed")
                return None, f"❌ Generation failed: {str(e)}"

            duration = len(audio_data) / sample_rate
            status = (
                "✅ Speech generated successfully!\n"
                f"⏱️ Generation time: {generation_time:.2f}s\n"
                f"🎵 Audio duration: {duration:.2f}s\n"
                f"📊 Sample rate: {sample_rate} Hz\n"
                f"🔊 Audio samples: {len(audio_data):,}"
            )
            return (sample_rate, audio_data), status

        generate_btn.click(
            fn=generate_speech_ui,
            inputs=[text_input, audio_prompt, exaggeration, temperature, seed, cfg_weight],
            outputs=[audio_output, status_text]
        )

        # API Documentation
        gr.Markdown(
            """
## 🔌 API Endpoints

### POST `/api/tts/synthesize`
Generate speech from text
```json
{
    "text": "Your text here",
    "audio_prompt_url": "URL to reference audio",
    "exaggeration": 0.5,
    "temperature": 0.8,
    "cfg_weight": 0.5,
    "seed": 0
}
```

### GET `/api/audio/{audio_id}`
Download generated audio file

### GET `/api/audio`
List all generated audio files

### GET `/health`
Check service health
"""
        )

        # System info
        install_note = (
            "" if CHATTERBOX_AVAILABLE
            else "**Note**: Install ChatterboxTTS for production use: `pip install chatterbox-tts`"
        )
        gr.Markdown(
            f"""
### 📊 System Status
- **Model**: {_model_status_label()}
- **Device**: {DEVICE}
- **Generated Files**: {len(audio_cache)}
- **ChatterboxTTS Available**: {CHATTERBOX_AVAILABLE}

{install_note}
"""
        )

    return demo


# Main execution
if __name__ == "__main__":
    logger.info("🎉 Starting ChatterboxTTS Service...")

    # Model status
    if CHATTERBOX_AVAILABLE and MODEL is not None:
        model_status = "✅ Real ChatterboxTTS Loaded"
    elif MODEL is not None:
        model_status = "⚠️ Mock ChatterboxTTS (Install real package for production)"
    else:
        model_status = "❌ No Model Loaded"

    logger.info(f"Model Status: {model_status}")
    logger.info(f"Device: {DEVICE}")
    logger.info(f"ChatterboxTTS Available: {CHATTERBOX_AVAILABLE}")

    if os.getenv("SPACE_ID"):
        # Running in Hugging Face Spaces.
        # NOTE(review): only the Gradio UI is served here, so the FastAPI routes
        # documented in the UI are unreachable on Spaces. Confirm this is intended
        # (e.g. serve `app` with gr.mount_gradio_app if the API is needed).
        logger.info("🏠 Running in Hugging Face Spaces")
        demo = create_gradio_interface()
        demo.launch(
            server_name="0.0.0.0",
            server_port=7860,
            show_error=True
        )
    else:
        # Local development - run both FastAPI and Gradio
        import uvicorn

        def run_fastapi():
            uvicorn.run(app, host="0.0.0.0", port=8000, log_level="info")

        # Start FastAPI in background
        api_thread = threading.Thread(target=run_fastapi, daemon=True)
        api_thread.start()

        logger.info("🌐 FastAPI: http://localhost:8000")
        logger.info("📚 API Docs: http://localhost:8000/docs")

        # Start Gradio. share=True also opens a public gradio.live tunnel.
        demo = create_gradio_interface()
        demo.launch(share=True)