import os import time import torch import random import numpy as np import soundfile as sf import tempfile import uuid import logging import requests import io from typing import Optional, Dict, Any from pathlib import Path import gradio as gr import spaces from fastapi import FastAPI, HTTPException from fastapi.responses import StreamingResponse from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel # Configure logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) # Device configuration DEVICE = "cuda" if torch.cuda.is_available() else "cpu" logger.info(f"🚀 Running on device: {DEVICE}") # Global model variable MODEL = None CHATTERBOX_AVAILABLE = False # Storage for generated audio AUDIO_DIR = "generated_audio" os.makedirs(AUDIO_DIR, exist_ok=True) audio_cache = {} def load_chatterbox_model(): """Try multiple ways to load ChatterboxTTS from Resemble AI""" global MODEL, CHATTERBOX_AVAILABLE # Method 1: Try Resemble AI ChatterboxTTS (most likely) try: from chatterbox.src.chatterbox.tts import ChatterboxTTS logger.info("✅ Found Resemble AI ChatterboxTTS in chatterbox.src.chatterbox.tts") MODEL = ChatterboxTTS.from_pretrained(DEVICE) CHATTERBOX_AVAILABLE = True return True except ImportError as e: logger.warning(f"Method 1 (Resemble AI standard path) failed: {e}") except Exception as e: logger.warning(f"Method 1 failed with error: {e}") # Method 2: Try alternative import path for Resemble AI repo try: from chatterbox.tts import ChatterboxTTS logger.info("✅ Found ChatterboxTTS in chatterbox.tts") MODEL = ChatterboxTTS.from_pretrained(DEVICE) CHATTERBOX_AVAILABLE = True return True except ImportError as e: logger.warning(f"Method 2 failed: {e}") except Exception as e: logger.warning(f"Method 2 failed with error: {e}") # Method 3: Try direct chatterbox import try: import chatterbox if hasattr(chatterbox, 'ChatterboxTTS'): MODEL = chatterbox.ChatterboxTTS.from_pretrained(DEVICE) elif hasattr(chatterbox, 'tts') and hasattr(chatterbox.tts, 'ChatterboxTTS'): MODEL = chatterbox.tts.ChatterboxTTS.from_pretrained(DEVICE) else: raise ImportError("ChatterboxTTS not found in chatterbox module") logger.info("✅ Found ChatterboxTTS via direct import") CHATTERBOX_AVAILABLE = True return True except ImportError as e: logger.warning(f"Method 3 failed: {e}") except Exception as e: logger.warning(f"Method 3 failed with error: {e}") # Method 4: Try exploring the installed package try: import chatterbox import inspect # Log what's available in the chatterbox package logger.info(f"Chatterbox module path: {chatterbox.__file__}") logger.info(f"Chatterbox contents: {dir(chatterbox)}") # Try to find ChatterboxTTS class anywhere in the module for name, obj in inspect.getmembers(chatterbox): if name == 'ChatterboxTTS' or (inspect.isclass(obj) and 'TTS' in name): logger.info(f"Found potential TTS class: {name}") MODEL = obj.from_pretrained(DEVICE) CHATTERBOX_AVAILABLE = True return True raise ImportError("ChatterboxTTS class not found in chatterbox package") except ImportError as e: logger.warning(f"Method 4 failed: {e}") except Exception as e: logger.warning(f"Method 4 failed with error: {e}") # Method 5: Check if the GitHub repo was installed correctly try: import pkg_resources try: pkg_resources.get_distribution('chatterbox') logger.info("✅ Chatterbox package is installed") except pkg_resources.DistributionNotFound: logger.warning("❌ Chatterbox package not found in installed packages") # Try to import and inspect what we got import chatterbox chatterbox_path = chatterbox.__path__[0] if hasattr(chatterbox, '__path__') else str(chatterbox.__file__) logger.info(f"Chatterbox installed at: {chatterbox_path}") # List all available modules/classes import pkgutil for importer, modname, ispkg in pkgutil.walk_packages(chatterbox.__path__, chatterbox.__name__ + "."): logger.info(f"Available module: {modname}") except Exception as e: logger.warning(f"Package inspection failed: {e}") # If we get here, the GitHub repo might have a different structure logger.error("❌ Could not load ChatterboxTTS from Resemble AI repository") logger.error("💡 The GitHub repo might have a different structure than expected") logger.error("🔗 Repository: https://github.com/resemble-ai/chatterbox.git") logger.error("📋 Check the repo's README for correct import instructions") return False def get_or_load_model(): """Load ChatterboxTTS model if not already loaded""" global MODEL if MODEL is None: logger.info("Loading ChatterboxTTS model...") success = load_chatterbox_model() if success: if hasattr(MODEL, 'to'): MODEL.to(DEVICE) logger.info("✅ ChatterboxTTS model loaded successfully") else: logger.error("❌ Failed to load ChatterboxTTS - using fallback") # Create a better fallback that shows the issue create_fallback_model() return MODEL def create_fallback_model(): """Create a fallback model that explains the issue""" global MODEL class FallbackChatterboxTTS: def __init__(self, device="cpu"): self.device = device self.sr = 24000 @classmethod def from_pretrained(cls, device): return cls(device) def to(self, device): self.device = device return self def generate(self, text, audio_prompt_path=None, exaggeration=0.5, temperature=0.8, cfg_weight=0.5): logger.warning("🚨 USING FALLBACK MODEL - Real ChatterboxTTS not found!") logger.warning(f"📝 Text to synthesize: {text[:50]}...") # Generate a more obvious fallback sound duration = 2.0 # Fixed 2 seconds t = np.linspace(0, duration, int(self.sr * duration)) # Create a distinctive "missing model" sound pattern # Three beeps to indicate this is a fallback beep_freq = 800 # Higher frequency beep beep_pattern = np.zeros_like(t) # Three short beeps for i in range(3): start_time = i * 0.6 end_time = start_time + 0.2 mask = (t >= start_time) & (t < end_time) beep_pattern[mask] = 0.3 * np.sin(2 * np.pi * beep_freq * t[mask]) return torch.tensor(beep_pattern).unsqueeze(0) MODEL = FallbackChatterboxTTS(DEVICE) def set_seed(seed: int): """Set random seed for reproducibility""" torch.manual_seed(seed) if DEVICE == "cuda": torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) random.seed(seed) np.random.seed(seed) def generate_id(): """Generate unique ID""" return str(uuid.uuid4()) # Pydantic models for API class TTSRequest(BaseModel): text: str audio_prompt_url: Optional[str] = "https://storage.googleapis.com/chatterbox-demo-samples/prompts/female_shadowheart4.flac" exaggeration: Optional[float] = 0.5 temperature: Optional[float] = 0.8 cfg_weight: Optional[float] = 0.5 seed: Optional[int] = 0 class TTSResponse(BaseModel): success: bool audio_id: Optional[str] = None message: str sample_rate: Optional[int] = None duration: Optional[float] = None # Load model at startup try: get_or_load_model() if CHATTERBOX_AVAILABLE: logger.info("✅ Real ChatterboxTTS model loaded successfully") else: logger.warning("⚠️ Using fallback model - Upload ChatterboxTTS package for real synthesis") except Exception as e: logger.error(f"Failed to load any model: {e}") MODEL = None @spaces.GPU def generate_tts_audio( text_input: str, audio_prompt_path_input: str, exaggeration_input: float, temperature_input: float, seed_num_input: int, cfgw_input: float ) -> tuple[int, np.ndarray]: """ Generate TTS audio using ChatterboxTTS model """ current_model = get_or_load_model() if current_model is None: raise RuntimeError("No TTS model available") if seed_num_input != 0: set_seed(int(seed_num_input)) logger.info(f"🎵 Generating audio for: '{text_input[:50]}...'") if not CHATTERBOX_AVAILABLE: logger.warning("🚨 USING FALLBACK - Real ChatterboxTTS not found!") logger.warning("📋 To fix: Upload your ChatterboxTTS package to this Space") try: wav = current_model.generate( text_input[:300], # Limit text length audio_prompt_path=audio_prompt_path_input, exaggeration=exaggeration_input, temperature=temperature_input, cfg_weight=cfgw_input, ) if CHATTERBOX_AVAILABLE: logger.info("✅ Real ChatterboxTTS audio generation complete") else: logger.warning("⚠️ Fallback audio generated - upload ChatterboxTTS for real synthesis") return (current_model.sr, wav.squeeze(0).numpy()) except Exception as e: logger.error(f"❌ Audio generation failed: {e}") raise # FastAPI app for API endpoints app = FastAPI( title="ChatterboxTTS API", description="High-quality text-to-speech synthesis using ChatterboxTTS", version="1.0.0" ) app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) @app.get("/") async def root(): """API status endpoint""" return { "service": "ChatterboxTTS API", "version": "1.0.0", "status": "operational" if MODEL else "model_loading", "model_loaded": MODEL is not None, "real_chatterbox": CHATTERBOX_AVAILABLE, "device": DEVICE, "message": "Real ChatterboxTTS loaded" if CHATTERBOX_AVAILABLE else "Using fallback - upload ChatterboxTTS package", "endpoints": { "synthesize": "/api/tts/synthesize", "audio": "/api/audio/{audio_id}", "health": "/health" } } @app.get("/health") async def health_check(): """Health check endpoint""" return { "status": "healthy" if MODEL else "unhealthy", "model_loaded": MODEL is not None, "real_chatterbox": CHATTERBOX_AVAILABLE, "device": DEVICE, "timestamp": time.time(), "warning": None if CHATTERBOX_AVAILABLE else "Using fallback model - upload ChatterboxTTS for production" } @app.post("/api/tts/synthesize", response_model=TTSResponse) async def synthesize_speech(request: TTSRequest): """ Synthesize speech from text """ try: if MODEL is None: raise HTTPException(status_code=503, detail="Model not loaded") if not request.text.strip(): raise HTTPException(status_code=400, detail="Text cannot be empty") if len(request.text) > 500: raise HTTPException(status_code=400, detail="Text too long (max 500 characters)") start_time = time.time() # Generate audio sample_rate, audio_data = generate_tts_audio( request.text, request.audio_prompt_url, request.exaggeration, request.temperature, request.seed, request.cfg_weight ) generation_time = time.time() - start_time # Save audio file audio_id = generate_id() audio_path = os.path.join(AUDIO_DIR, f"{audio_id}.wav") sf.write(audio_path, audio_data, sample_rate) # Cache audio info audio_cache[audio_id] = { "path": audio_path, "text": request.text, "sample_rate": sample_rate, "duration": len(audio_data) / sample_rate, "generated_at": time.time(), "generation_time": generation_time, "real_chatterbox": CHATTERBOX_AVAILABLE } message = "Speech synthesized successfully" if not CHATTERBOX_AVAILABLE: message += " (using fallback - upload ChatterboxTTS for real synthesis)" logger.info(f"✅ Audio saved: {audio_id} ({generation_time:.2f}s)") return TTSResponse( success=True, audio_id=audio_id, message=message, sample_rate=sample_rate, duration=len(audio_data) / sample_rate ) except HTTPException: raise except Exception as e: logger.error(f"❌ Synthesis failed: {e}") raise HTTPException(status_code=500, detail=f"Synthesis failed: {str(e)}") @app.get("/api/audio/{audio_id}") async def get_audio(audio_id: str): """ Download generated audio file """ if audio_id not in audio_cache: raise HTTPException(status_code=404, detail="Audio not found") audio_info = audio_cache[audio_id] audio_path = audio_info["path"] if not os.path.exists(audio_path): raise HTTPException(status_code=404, detail="Audio file not found on disk") def iterfile(): with open(audio_path, "rb") as f: yield from f return StreamingResponse( iterfile(), media_type="audio/wav", headers={ "Content-Disposition": f"attachment; filename=tts_{audio_id}.wav" } ) @app.get("/api/audio/{audio_id}/info") async def get_audio_info(audio_id: str): """ Get audio file information """ if audio_id not in audio_cache: raise HTTPException(status_code=404, detail="Audio not found") return audio_cache[audio_id] @app.get("/api/audio") async def list_audio(): """ List all generated audio files """ return { "audio_files": [ { "audio_id": audio_id, "text": info["text"][:50] + "..." if len(info["text"]) > 50 else info["text"], "duration": info["duration"], "generated_at": info["generated_at"], "real_chatterbox": info.get("real_chatterbox", False) } for audio_id, info in audio_cache.items() ], "total": len(audio_cache) } # Gradio interface def create_gradio_interface(): """Create Gradio interface with better ChatterboxTTS status""" with gr.Blocks(title="ChatterboxTTS", theme=gr.themes.Soft()) as demo: # Status indicator at the top if CHATTERBOX_AVAILABLE: status_color = "green" status_message = "✅ Real ChatterboxTTS Loaded - Production Ready!" else: status_color = "orange" status_message = "⚠️ Fallback Mode - Upload ChatterboxTTS Package for Real Synthesis" gr.HTML(f"""