# Hugging Face Space page header captured with this file (Space status: Sleeping)
import os | |
import time | |
import torch | |
import random | |
import numpy as np | |
import soundfile as sf | |
import tempfile | |
import uuid | |
import logging | |
import requests | |
import io | |
from typing import Optional, Dict, Any | |
from pathlib import Path | |
import gradio as gr | |
import spaces | |
from fastapi import FastAPI, HTTPException | |
from fastapi.responses import StreamingResponse | |
from fastapi.middleware.cors import CORSMiddleware | |
from pydantic import BaseModel | |
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Device configuration: use the GPU whenever PyTorch reports one.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
logger.info(f"π Running on device: {DEVICE}")

# Global model state, filled in by get_or_load_model().
MODEL = None
CHATTERBOX_AVAILABLE = False

# Storage for generated audio: WAV files live in AUDIO_DIR, and audio_cache
# maps each audio_id to a metadata dict describing its file.
AUDIO_DIR = "generated_audio"
Path(AUDIO_DIR).mkdir(parents=True, exist_ok=True)
audio_cache = {}
def load_chatterbox_model():
    """Try several import paths to load ChatterboxTTS from Resemble AI.

    Strategies are attempted in order:
      1. ``chatterbox.src.chatterbox.tts.ChatterboxTTS``
      2. ``chatterbox.tts.ChatterboxTTS``
      3. ``ChatterboxTTS`` reached as an attribute of the ``chatterbox`` module
      4. any class in ``chatterbox`` named like a TTS class that exposes
         ``from_pretrained`` (``ChatterboxTTS`` is tried first)

    The first strategy that succeeds stores the instance in the global
    ``MODEL``, sets ``CHATTERBOX_AVAILABLE = True`` and returns True. Each
    failure is logged as a warning. If every strategy fails, diagnostics about
    the installed package are logged and False is returned.

    Returns:
        bool: True if a real ChatterboxTTS model was loaded, False otherwise.
    """
    global MODEL, CHATTERBOX_AVAILABLE

    # Method 1: Try Resemble AI ChatterboxTTS (most likely)
    try:
        from chatterbox.src.chatterbox.tts import ChatterboxTTS
        logger.info("β Found Resemble AI ChatterboxTTS in chatterbox.src.chatterbox.tts")
        MODEL = ChatterboxTTS.from_pretrained(DEVICE)
        CHATTERBOX_AVAILABLE = True
        return True
    except ImportError as e:
        logger.warning(f"Method 1 (Resemble AI standard path) failed: {e}")
    except Exception as e:
        logger.warning(f"Method 1 failed with error: {e}")

    # Method 2: Try alternative import path for Resemble AI repo
    try:
        from chatterbox.tts import ChatterboxTTS
        logger.info("β Found ChatterboxTTS in chatterbox.tts")
        MODEL = ChatterboxTTS.from_pretrained(DEVICE)
        CHATTERBOX_AVAILABLE = True
        return True
    except ImportError as e:
        logger.warning(f"Method 2 failed: {e}")
    except Exception as e:
        logger.warning(f"Method 2 failed with error: {e}")

    # Method 3: Try direct chatterbox import
    try:
        import chatterbox
        if hasattr(chatterbox, 'ChatterboxTTS'):
            MODEL = chatterbox.ChatterboxTTS.from_pretrained(DEVICE)
        elif hasattr(chatterbox, 'tts') and hasattr(chatterbox.tts, 'ChatterboxTTS'):
            MODEL = chatterbox.tts.ChatterboxTTS.from_pretrained(DEVICE)
        else:
            raise ImportError("ChatterboxTTS not found in chatterbox module")
        logger.info("β Found ChatterboxTTS via direct import")
        CHATTERBOX_AVAILABLE = True
        return True
    except ImportError as e:
        logger.warning(f"Method 3 failed: {e}")
    except Exception as e:
        logger.warning(f"Method 3 failed with error: {e}")

    # Method 4: Scan the installed package for a usable TTS class
    try:
        import chatterbox
        import inspect

        # __file__ can be absent (e.g. namespace packages); don't let the
        # diagnostic itself abort the scan.
        logger.info(f"Chatterbox module path: {getattr(chatterbox, '__file__', None)}")
        logger.info(f"Chatterbox contents: {dir(chatterbox)}")

        # Only objects that actually offer from_pretrained() are candidates,
        # so an unrelated "*TTS*" class no longer ends the search.
        candidates = [
            (name, obj)
            for name, obj in inspect.getmembers(chatterbox)
            if (name == 'ChatterboxTTS' or (inspect.isclass(obj) and 'TTS' in name))
            and callable(getattr(obj, 'from_pretrained', None))
        ]
        # Prefer the exact class name; sort is stable for the rest.
        candidates.sort(key=lambda item: item[0] != 'ChatterboxTTS')

        for name, obj in candidates:
            logger.info(f"Found potential TTS class: {name}")
            try:
                MODEL = obj.from_pretrained(DEVICE)
            except Exception as e:
                logger.warning(f"Method 4 candidate {name} failed: {e}")
                continue
            CHATTERBOX_AVAILABLE = True
            return True
        raise ImportError("ChatterboxTTS class not found in chatterbox package")
    except ImportError as e:
        logger.warning(f"Method 4 failed: {e}")
    except Exception as e:
        logger.warning(f"Method 4 failed with error: {e}")

    # Method 5: Diagnostics only - report what is installed so the logs show
    # why the imports above failed.
    try:
        # importlib.metadata replaces the deprecated pkg_resources API.
        from importlib import metadata
        import pkgutil

        # "chatterbox-tts" is the distribution name the Resemble AI project
        # publishes under (verify if the repo changes); "chatterbox" keeps
        # the name the previous check looked for.
        for dist_name in ('chatterbox-tts', 'chatterbox'):
            try:
                dist_version = metadata.version(dist_name)
            except metadata.PackageNotFoundError:
                continue
            logger.info(f"β Chatterbox package is installed ({dist_name} {dist_version})")
            break
        else:
            logger.warning("β Chatterbox package not found in installed packages")

        # Try to import and inspect what we got
        import chatterbox
        package_paths = list(getattr(chatterbox, '__path__', None) or [])
        if package_paths:
            chatterbox_path = package_paths[0]
        else:
            chatterbox_path = str(getattr(chatterbox, '__file__', None))
        logger.info(f"Chatterbox installed at: {chatterbox_path}")

        # List all available submodules (a plain module has none to walk).
        for _, modname, _ in pkgutil.walk_packages(package_paths, chatterbox.__name__ + "."):
            logger.info(f"Available module: {modname}")
    except Exception as e:
        logger.warning(f"Package inspection failed: {e}")

    # If we get here, the GitHub repo might have a different structure
    logger.error("β Could not load ChatterboxTTS from Resemble AI repository")
    logger.error("π‘ The GitHub repo might have a different structure than expected")
    logger.error("π Repository: https://github.com/resemble-ai/chatterbox.git")
    logger.error("π Check the repo's README for correct import instructions")
    return False
def get_or_load_model():
    """Return the global TTS model, loading it first if necessary.

    When ``MODEL`` is still None this runs ``load_chatterbox_model()``. On
    success the model is moved to ``DEVICE`` (when it has a ``.to`` method).
    On failure ``create_fallback_model()`` installs the beep-producing
    placeholder. Either way, the current global ``MODEL`` is returned.
    """
    global MODEL
    if MODEL is not None:
        return MODEL

    logger.info("Loading ChatterboxTTS model...")
    if not load_chatterbox_model():
        logger.error("β Failed to load ChatterboxTTS - using fallback")
        # Install a placeholder that makes the missing model audible.
        create_fallback_model()
        return MODEL

    if hasattr(MODEL, 'to'):
        MODEL.to(DEVICE)
    logger.info("β ChatterboxTTS model loaded successfully")
    return MODEL
def create_fallback_model():
    """Install a placeholder TTS model in the global ``MODEL``.

    The placeholder exposes ``sr``, ``to()``, ``from_pretrained()`` and a
    ``generate()`` accepting the keyword arguments used elsewhere in this
    file. Instead of speech, it always returns a fixed 2-second clip of three
    800 Hz beeps, so it is audibly obvious that the real model is missing.
    """
    global MODEL

    class FallbackChatterboxTTS:
        def __init__(self, device="cpu"):
            self.device = device
            self.sr = 24000  # output sample rate in Hz

        @classmethod
        def from_pretrained(cls, device):
            # Must be a classmethod: without the decorator, calling it on the
            # class raised TypeError (missing argument).
            return cls(device)

        def to(self, device):
            self.device = device
            return self

        def generate(self, text, audio_prompt_path=None, exaggeration=0.5,
                     temperature=0.8, cfg_weight=0.5):
            """Return a (1, num_samples) tensor of three beeps.

            Every argument except ``text`` (which is only logged) is ignored.
            """
            logger.warning("π¨ USING FALLBACK MODEL - Real ChatterboxTTS not found!")
            logger.warning(f"π Text to synthesize: {text[:50]}...")

            duration = 2.0  # Fixed 2 seconds
            num_samples = int(self.sr * duration)
            # endpoint=False makes consecutive samples exactly 1/sr apart.
            t = np.linspace(0, duration, num_samples, endpoint=False)

            beep_freq = 800  # Higher frequency beep
            beep_pattern = np.zeros_like(t)
            # Three 0.2 s beeps starting at 0.0 s, 0.6 s and 1.2 s.
            for i in range(3):
                start_time = i * 0.6
                end_time = start_time + 0.2
                mask = (t >= start_time) & (t < end_time)
                beep_pattern[mask] = 0.3 * np.sin(2 * np.pi * beep_freq * t[mask])
            return torch.tensor(beep_pattern).unsqueeze(0)

    MODEL = FallbackChatterboxTTS(DEVICE)
def set_seed(seed: int):
    """Seed the RNGs used here so repeated runs produce the same output.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch.
    When ``DEVICE`` is "cuda" it also seeds the current and all CUDA devices.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if DEVICE != "cuda":
        return
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
def generate_id():
    """Return a new random identifier formatted as a canonical UUID4 string."""
    fresh = uuid.uuid4()
    return f"{fresh}"
# Pydantic models for API
class TTSRequest(BaseModel):
    """Request body for the speech-synthesis API endpoint.

    ``audio_prompt_url`` may be null. The numeric generation parameters are
    non-nullable because they are passed directly into generation, which
    cannot handle None (e.g. ``"seed": null`` would fail on ``int(None)``).
    A null value is therefore rejected as invalid input instead.
    """
    text: str
    # Reference-voice audio (URL or path) used for voice cloning.
    audio_prompt_url: Optional[str] = "https://storage.googleapis.com/chatterbox-demo-samples/prompts/female_shadowheart4.flac"
    exaggeration: float = 0.5
    temperature: float = 0.8
    cfg_weight: float = 0.5
    # 0 leaves the random generators unseeded.
    seed: int = 0
class TTSResponse(BaseModel):
    """Response body returned by the synthesis endpoint."""
    # True when synthesis succeeded.
    success: bool
    # Key under which the generated WAV is recorded in audio_cache.
    audio_id: Optional[str] = None
    # Human-readable status; mentions fallback mode when the real model is absent.
    message: str
    # Sample rate of the generated audio in Hz.
    sample_rate: Optional[int] = None
    # Clip length in seconds (number of samples / sample_rate).
    duration: Optional[float] = None
# Load model at startup so the first request does not have to load it.
try:
    get_or_load_model()
    if CHATTERBOX_AVAILABLE:
        logger.info("β Real ChatterboxTTS model loaded successfully")
    else:
        logger.warning("β οΈ Using fallback model - Upload ChatterboxTTS package for real synthesis")
except Exception as e:
    # logger.exception keeps the traceback that logger.error discarded.
    logger.exception(f"Failed to load any model: {e}")
    MODEL = None
    # Keep the status flags consistent: with no model, nothing "real" is loaded.
    CHATTERBOX_AVAILABLE = False
def _is_http_url(value) -> bool:
    """Return True if ``value`` is a string starting with http:// or https://."""
    return isinstance(value, str) and value.strip().lower().startswith(("http://", "https://"))


def _download_audio_prompt(url: str, timeout: float = 30.0,
                           max_bytes: int = 20 * 1024 * 1024) -> str:
    """Download ``url`` into a temporary file and return that file's path.

    The caller owns the returned file and must delete it.

    Raises:
        requests.RequestException: on network or HTTP errors.
        ValueError: if the response body exceeds ``max_bytes``.
    """
    from urllib.parse import urlparse

    # NOTE(review): the URL comes from API/UI callers, so this is a
    # server-side fetch of arbitrary URLs (SSRF risk); consider an allow-list.
    url = url.strip()
    # Keep the original extension as a hint for the audio decoder.
    suffix = os.path.splitext(urlparse(url).path)[1] or ".wav"
    chunks = []
    total = 0
    with requests.get(url, stream=True, timeout=timeout) as response:
        response.raise_for_status()
        for chunk in response.iter_content(chunk_size=64 * 1024):
            total += len(chunk)
            if total > max_bytes:
                raise ValueError(f"Reference audio exceeds {max_bytes} bytes")
            chunks.append(chunk)
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
        tmp.write(b"".join(chunks))
    return tmp.name


def generate_tts_audio(
    text_input: str,
    audio_prompt_path_input: str,
    exaggeration_input: float,
    temperature_input: float,
    seed_num_input: int,
    cfgw_input: float
) -> tuple[int, np.ndarray]:
    """Generate speech for ``text_input`` with the current TTS model.

    Args:
        text_input: Text to synthesize; only the first 300 characters are used.
        audio_prompt_path_input: Reference-voice audio. An http(s) URL is
            downloaded to a temporary file before being handed to the real
            model. Anything else (local path, empty string, None) is passed
            through unchanged.
        exaggeration_input: Forwarded to ``generate`` as ``exaggeration``.
        temperature_input: Forwarded to ``generate`` as ``temperature``.
        seed_num_input: RNG seed; 0 leaves the generators unseeded.
        cfgw_input: Forwarded to ``generate`` as ``cfg_weight``.

    Returns:
        ``(sample_rate, samples)``, where ``samples`` is a NumPy array on the CPU.

    Raises:
        RuntimeError: if no model is available.
        Exception: any download or generation error is logged and re-raised.
    """
    current_model = get_or_load_model()
    if current_model is None:
        raise RuntimeError("No TTS model available")
    if seed_num_input != 0:
        set_seed(int(seed_num_input))
    logger.info(f"π΅ Generating audio for: '{text_input[:50]}...'")
    if not CHATTERBOX_AVAILABLE:
        logger.warning("π¨ USING FALLBACK - Real ChatterboxTTS not found!")
        logger.warning("π To fix: Upload your ChatterboxTTS package to this Space")

    downloaded_prompt = None
    try:
        prompt_path = audio_prompt_path_input
        # The fallback model ignores the prompt, so only fetch it for the real model.
        if CHATTERBOX_AVAILABLE and _is_http_url(prompt_path):
            downloaded_prompt = _download_audio_prompt(prompt_path)
            prompt_path = downloaded_prompt
        wav = current_model.generate(
            text_input[:300],  # Limit text length
            audio_prompt_path=prompt_path,
            exaggeration=exaggeration_input,
            temperature=temperature_input,
            cfg_weight=cfgw_input,
        )
        if CHATTERBOX_AVAILABLE:
            logger.info("β Real ChatterboxTTS audio generation complete")
        else:
            logger.warning("β οΈ Fallback audio generated - upload ChatterboxTTS for real synthesis")
        # detach()/cpu() so GPU-resident or grad-tracking output converts to NumPy.
        return (current_model.sr, wav.detach().cpu().squeeze(0).numpy())
    except Exception as e:
        logger.error(f"β Audio generation failed: {e}")
        raise
    finally:
        if downloaded_prompt is not None:
            try:
                os.remove(downloaded_prompt)
            except OSError as cleanup_error:
                logger.warning(f"Could not delete temporary prompt file {downloaded_prompt}: {cleanup_error}")
# FastAPI app for API endpoints
app = FastAPI(
    title="ChatterboxTTS API",
    description="High-quality text-to-speech synthesis using ChatterboxTTS",
    version="1.0.0"
)

# Allow browser clients from any origin. Credentials stay disabled: no endpoint
# here reads cookies or auth headers. The CORS spec (and FastAPI's docs)
# disallow combining a wildcard origin with allow_credentials=True.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.get("/")
async def root():
    """API status endpoint.

    Reports whether a model (real or fallback) is loaded, the device in use,
    and the paths of the main endpoints. The handler must be registered on
    ``app`` to be reachable; without the decorator it was dead code.
    """
    model_loaded = MODEL is not None
    return {
        "service": "ChatterboxTTS API",
        "version": "1.0.0",
        "status": "operational" if model_loaded else "model_loading",
        "model_loaded": model_loaded,
        "real_chatterbox": CHATTERBOX_AVAILABLE,
        "device": DEVICE,
        "message": "Real ChatterboxTTS loaded" if CHATTERBOX_AVAILABLE else "Using fallback - upload ChatterboxTTS package",
        "endpoints": {
            "synthesize": "/api/tts/synthesize",
            "audio": "/api/audio/{audio_id}",
            "health": "/health"
        }
    }
@app.get("/health")
async def health_check():
    """Health check endpoint.

    "healthy" means some model (possibly the beep fallback) is loaded. Check
    ``real_chatterbox`` and ``warning`` to distinguish the two.
    """
    model_loaded = MODEL is not None
    return {
        "status": "healthy" if model_loaded else "unhealthy",
        "model_loaded": model_loaded,
        "real_chatterbox": CHATTERBOX_AVAILABLE,
        "device": DEVICE,
        "timestamp": time.time(),
        "warning": None if CHATTERBOX_AVAILABLE else "Using fallback model - upload ChatterboxTTS for production"
    }
@app.post("/api/tts/synthesize", response_model=TTSResponse)
async def synthesize_speech(request: TTSRequest):
    """Synthesize speech for ``request.text`` and store it as a WAV file.

    The file is written to ``AUDIO_DIR`` and recorded in ``audio_cache`` under
    a new ``audio_id``, which the client can then download via
    ``/api/audio/{audio_id}``.

    Raises:
        HTTPException: 503 if no model is loaded; 400 for empty text or text
            longer than 300 characters; 500 if generation or saving fails.
    """
    # generate_tts_audio only synthesizes the first 300 characters, so reject
    # longer text instead of silently returning audio for part of it.
    max_chars = 300
    try:
        if MODEL is None:
            raise HTTPException(status_code=503, detail="Model not loaded")
        if not request.text.strip():
            raise HTTPException(status_code=400, detail="Text cannot be empty")
        if len(request.text) > max_chars:
            raise HTTPException(status_code=400, detail=f"Text too long (max {max_chars} characters)")

        start_time = time.time()
        # NOTE(review): generation runs synchronously inside this async
        # handler and blocks the event loop; offload it to a worker thread
        # once the model's thread-safety has been confirmed.
        sample_rate, audio_data = generate_tts_audio(
            request.text,
            request.audio_prompt_url,
            request.exaggeration,
            request.temperature,
            request.seed,
            request.cfg_weight
        )
        generation_time = time.time() - start_time
        duration = len(audio_data) / sample_rate

        # Save audio file
        audio_id = generate_id()
        audio_path = os.path.join(AUDIO_DIR, f"{audio_id}.wav")
        sf.write(audio_path, audio_data, sample_rate)

        # Cache audio info
        audio_cache[audio_id] = {
            "path": audio_path,
            "text": request.text,
            "sample_rate": sample_rate,
            "duration": duration,
            "generated_at": time.time(),
            "generation_time": generation_time,
            "real_chatterbox": CHATTERBOX_AVAILABLE
        }

        message = "Speech synthesized successfully"
        if not CHATTERBOX_AVAILABLE:
            message += " (using fallback - upload ChatterboxTTS for real synthesis)"
        logger.info(f"β Audio saved: {audio_id} ({generation_time:.2f}s)")
        return TTSResponse(
            success=True,
            audio_id=audio_id,
            message=message,
            sample_rate=sample_rate,
            duration=duration
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"β Synthesis failed: {e}")
        raise HTTPException(status_code=500, detail=f"Synthesis failed: {str(e)}") from e
@app.get("/api/audio/{audio_id}")
async def get_audio(audio_id: str):
    """Download a previously generated WAV file.

    Only IDs present in ``audio_cache`` are served, so arbitrary paths cannot
    be requested.

    Raises:
        HTTPException(404): if the ID is unknown or its file no longer exists.
    """
    audio_info = audio_cache.get(audio_id)
    if audio_info is None:
        raise HTTPException(status_code=404, detail="Audio not found")
    audio_path = audio_info["path"]
    if not os.path.exists(audio_path):
        raise HTTPException(status_code=404, detail="Audio file not found on disk")

    def iterfile(chunk_size: int = 64 * 1024):
        # Fixed-size reads: iterating a binary file splits on b"\n", which
        # produces arbitrarily sized pieces for audio data.
        with open(audio_path, "rb") as f:
            while chunk := f.read(chunk_size):
                yield chunk

    return StreamingResponse(
        iterfile(),
        media_type="audio/wav",
        headers={
            "Content-Disposition": f"attachment; filename=tts_{audio_id}.wav"
        }
    )
# NOTE(review): route path chosen to sit beside /api/audio/{audio_id};
# adjust if existing clients expect a different path.
@app.get("/api/audio/{audio_id}/info")
async def get_audio_info(audio_id: str):
    """Return the metadata recorded in ``audio_cache`` for one generated file.

    Raises:
        HTTPException(404): if ``audio_id`` is unknown.
    """
    audio_info = audio_cache.get(audio_id)
    if audio_info is None:
        raise HTTPException(status_code=404, detail="Audio not found")
    return audio_info
# NOTE(review): route path chosen to match the /api/audio/... family;
# adjust if existing clients expect a different path.
@app.get("/api/audio")
async def list_audio():
    """List all generated audio files recorded in ``audio_cache``.

    Each entry's text is shortened to 50 characters plus "..." when longer.
    """
    def _preview(text: str) -> str:
        return text[:50] + "..." if len(text) > 50 else text

    return {
        "audio_files": [
            {
                "audio_id": audio_id,
                "text": _preview(info["text"]),
                "duration": info["duration"],
                "generated_at": info["generated_at"],
                "real_chatterbox": info.get("real_chatterbox", False)
            }
            for audio_id, info in audio_cache.items()
        ],
        "total": len(audio_cache)
    }
# Gradio interface
def create_gradio_interface():
    """Build the Gradio Blocks UI for interactive speech synthesis.

    The status banner, the fallback-mode instructions and the "System Status"
    section are chosen from ``CHATTERBOX_AVAILABLE`` and ``len(audio_cache)``
    once, when this function runs; they do not refresh afterwards.
    Clicking "Generate Speech" calls ``generate_speech_ui`` with the text,
    reference-audio URL, exaggeration, temperature, seed and CFG weight.

    Returns:
        gr.Blocks: the assembled (not yet launched) interface.
    """
    # Components render in creation order, so statement order below is the layout.
    with gr.Blocks(title="ChatterboxTTS", theme=gr.themes.Soft()) as demo:
        # Status indicator at the top
        if CHATTERBOX_AVAILABLE:
            status_color = "green"
            status_message = "β Real ChatterboxTTS Loaded - Production Ready!"
        else:
            status_color = "orange"
            status_message = "β οΈ Fallback Mode - Upload ChatterboxTTS Package for Real Synthesis"
        gr.HTML(f"""
<div style="background-color: {status_color}; color: white; padding: 10px; border-radius: 5px; margin-bottom: 20px;">
<h3 style="margin: 0;">{status_message}</h3>
</div>
""")
        gr.Markdown("""
# π΅ ChatterboxTTS
High-quality text-to-speech synthesis with voice cloning capabilities.
""")
        # Troubleshooting instructions, shown only in fallback mode.
        if not CHATTERBOX_AVAILABLE:
            gr.Markdown("""
### π¨ Currently Using Fallback Model
You're hearing beep sounds because the real ChatterboxTTS isn't loaded.
**The Resemble AI ChatterboxTTS from GitHub should auto-install from requirements.txt.**
If you're still seeing this message:
1. **Check build logs** for any installation errors
2. **Verify requirements.txt** contains: `git+https://github.com/resemble-ai/chatterbox.git`
3. **Restart the Space** if needed
4. **Check logs** for import errors
π GitHub repo being used: https://github.com/resemble-ai/chatterbox.git
If the GitHub installation fails, you can alternatively upload the package manually.
""")
        with gr.Row():
            # Left column: inputs and generation controls.
            with gr.Column():
                text_input = gr.Textbox(
                    value="Hello, this is ChatterboxTTS. I can generate natural-sounding speech from any text you provide.",
                    label="Text to synthesize (max 300 characters)",
                    max_lines=5,
                    placeholder="Enter your text here..."
                )
                audio_prompt = gr.Textbox(
                    value="https://storage.googleapis.com/chatterbox-demo-samples/prompts/female_shadowheart4.flac",
                    label="Reference Audio URL",
                    placeholder="URL to reference audio file"
                )
                with gr.Row():
                    exaggeration = gr.Slider(
                        0.25, 2,
                        step=0.05,
                        label="Exaggeration",
                        value=0.5,
                        info="Controls expressiveness (0.5 = neutral)"
                    )
                    cfg_weight = gr.Slider(
                        0.2, 1,
                        step=0.05,
                        label="CFG Weight",
                        value=0.5,
                        info="Controls pace and clarity"
                    )
                with gr.Accordion("Advanced Settings", open=False):
                    temperature = gr.Slider(
                        0.05, 5,
                        step=0.05,
                        label="Temperature",
                        value=0.8,
                        info="Controls randomness"
                    )
                    seed = gr.Number(
                        value=0,
                        label="Seed (0 = random)",
                        info="Set to non-zero for reproducible results"
                    )
                generate_btn = gr.Button("π΅ Generate Speech", variant="primary")
            # Right column: outputs.
            with gr.Column():
                audio_output = gr.Audio(label="Generated Speech")
                status_text = gr.Textbox(
                    label="Status",
                    interactive=False,
                    placeholder="Click 'Generate Speech' to start..."
                )

        def generate_speech_ui(text, prompt_url, exag, temp, seed_val, cfg):
            """Click handler: validate the input, synthesize it, and report timing.

            Returns ``(audio, status)``. ``audio`` is ``(sample_rate, samples)``
            on success and None otherwise; ``status`` is the text shown in the
            Status box. Exceptions are logged and turned into an error message
            rather than raised. Unlike the API endpoint, the result is not
            saved to disk or recorded in ``audio_cache``.
            """
            try:
                if not text.strip():
                    return None, "β Please enter some text"
                if len(text) > 300:
                    return None, "β Text too long (max 300 characters)"
                start_time = time.time()
                # Generate audio
                sample_rate, audio_data = generate_tts_audio(
                    text, prompt_url, exag, temp, int(seed_val), cfg
                )
                generation_time = time.time() - start_time
                duration = len(audio_data) / sample_rate
                if CHATTERBOX_AVAILABLE:
                    status = f"""β Real ChatterboxTTS synthesis completed!
β±οΈ Generation time: {generation_time:.2f}s
π΅ Audio duration: {duration:.2f}s
π Sample rate: {sample_rate} Hz
π Audio samples: {len(audio_data):,}
"""
                else:
                    status = f"""β οΈ Fallback audio generated (beep sound)
π¨ This is NOT real speech synthesis!
π¦ Upload ChatterboxTTS package for real synthesis
β±οΈ Generation time: {generation_time:.2f}s
π΅ Audio duration: {duration:.2f}s
π‘ To fix: Upload your ChatterboxTTS files to this Space
"""
                return (sample_rate, audio_data), status
            except Exception as e:
                logger.error(f"UI generation failed: {e}")
                return None, f"β Generation failed: {str(e)}"

        # Input order must match generate_speech_ui's parameter order.
        generate_btn.click(
            fn=generate_speech_ui,
            inputs=[text_input, audio_prompt, exaggeration, temperature, seed, cfg_weight],
            outputs=[audio_output, status_text]
        )
        # System info with better warnings (evaluated once, at build time)
        model_status = "β Real ChatterboxTTS" if CHATTERBOX_AVAILABLE else "β οΈ Fallback Model (Beep Sounds)"
        chatterbox_status = "Available" if CHATTERBOX_AVAILABLE else "Missing - Upload Package"
        gr.Markdown(f"""
### π System Status
- **Model**: {model_status}
- **Device**: {DEVICE}
- **Generated Files**: {len(audio_cache)}
- **ChatterboxTTS**: {chatterbox_status}
{'''### π Production Ready!
Your ChatterboxTTS model is loaded and working correctly.''' if CHATTERBOX_AVAILABLE else '''### β οΈ Action Required
**You're hearing beep sounds because ChatterboxTTS isn't loaded.**
**To fix this:**
1. Upload your ChatterboxTTS package to this Space
2. Ensure proper directory structure with `__init__.py` files
3. Restart the Space
The current fallback generates beeps to indicate missing package.'''}
""")
    return demo
# Main execution
if __name__ == "__main__":
    logger.info("π Starting ChatterboxTTS Service...")

    # Summarise which kind of model ended up loaded.
    if not MODEL:
        model_status = "β No Model Loaded"
    elif CHATTERBOX_AVAILABLE:
        model_status = "β Real ChatterboxTTS Loaded"
    else:
        model_status = "β οΈ Fallback Model (Upload ChatterboxTTS package for real synthesis)"

    for status_line in (
        f"Model Status: {model_status}",
        f"Device: {DEVICE}",
        f"ChatterboxTTS Available: {CHATTERBOX_AVAILABLE}",
    ):
        logger.info(status_line)

    if not CHATTERBOX_AVAILABLE:
        logger.warning("π¨ IMPORTANT: Upload your ChatterboxTTS package to enable real synthesis!")
        logger.warning("π Expected location: ./chatterbox/src/chatterbox/tts.py")

    on_hf_spaces = bool(os.getenv("SPACE_ID"))
    if not on_hf_spaces:
        # Local development: serve the FastAPI app on :8000 from a daemon
        # thread, then run Gradio in the main thread.
        import uvicorn
        import threading

        def _serve_api():
            uvicorn.run(app, host="0.0.0.0", port=8000, log_level="info")

        threading.Thread(target=_serve_api, daemon=True).start()
        logger.info("π FastAPI: http://localhost:8000")
        logger.info("π API Docs: http://localhost:8000/docs")

        demo = create_gradio_interface()
        demo.launch(share=True)
    else:
        # Hugging Face Spaces: only the Gradio UI is served, on port 7860.
        logger.info("π Running in Hugging Face Spaces")
        demo = create_gradio_interface()
        demo.launch(
            server_name="0.0.0.0",
            server_port=7860,
            show_error=True
        )