Spaces:
Sleeping
Sleeping
import os | |
import time | |
import torch | |
import random | |
import numpy as np | |
import soundfile as sf | |
import tempfile | |
import uuid | |
import logging | |
import requests | |
import io | |
import json | |
import base64 | |
from typing import Optional, Dict, Any, List | |
from pathlib import Path | |
import gradio as gr | |
import spaces | |
from fastapi import FastAPI, HTTPException, UploadFile, File | |
from fastapi.responses import StreamingResponse | |
from fastapi.middleware.cors import CORSMiddleware | |
from pydantic import BaseModel | |
# Process-wide logging: INFO level, with a logger scoped to this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Run on the GPU whenever PyTorch can see one; otherwise use the CPU.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"
logger.info(f"π Running on device: {DEVICE}")

# The active TTS model; filled in by load_chatterbox_model() or, if no real
# implementation can be imported, by create_fallback_model().
MODEL = None
# True only once a real ChatterboxTTS implementation has been loaded.
CHATTERBOX_AVAILABLE = False
# Storage directories: prefer the Hugging Face Spaces persistent volume at
# /data when it exists; otherwise use directories relative to the working
# directory, which are treated as non-persistent (see the warning below).
if not os.path.exists("/data"):
    VOICES_DIR, AUDIO_DIR = "custom_voices", "generated_audio"
    logger.warning("β οΈ Using local storage (voices will not persist)")
else:
    VOICES_DIR, AUDIO_DIR = "/data/custom_voices", "/data/generated_audio"
    logger.info("β Using Hugging Face Spaces persistent storage (/data)")

# Both directories must exist before anything is written into them.
os.makedirs(AUDIO_DIR, exist_ok=True)
os.makedirs(VOICES_DIR, exist_ok=True)

# In-memory registries:
#   audio_cache   - audio_id -> metadata for each synthesized clip
#   voice_library - voice_id -> voice entry (builtin and custom voices)
audio_cache = dict()
voice_library = dict()
# Default/Built-in voices.
# Builtin voices are defined in code: their reference audio is downloaded from
# `audio_url` when needed, they cannot be deleted, and they are never written
# to voices.json.
BUILTIN_VOICES = dict(
    female_default=dict(
        voice_id="female_default",
        name="Female Default",
        description="Professional female voice",
        audio_url="https://storage.googleapis.com/chatterbox-demo-samples/prompts/female_shadowheart4.flac",
        type="builtin",
        created_at="2024-01-01T00:00:00Z",
    ),
    male_professional=dict(
        voice_id="male_professional",
        name="Male Professional",
        description="Confident male voice",
        audio_url="https://storage.googleapis.com/chatterbox-demo-samples/prompts/male_professional.flac",
        type="builtin",
        created_at="2024-01-01T00:00:00Z",
    ),
)
def encode_audio_to_base64(audio_data, sample_rate):
    """Encode audio samples as base64 text holding a complete WAV file.

    The WAV is built in memory, so no temporary file or open file handle can
    be leaked, even when encoding fails part-way.

    Args:
        audio_data: Sample array in any form accepted by ``soundfile.write``.
        sample_rate: Sampling rate in Hz.

    Returns:
        The base64-encoded WAV bytes as a ``str``, or ``None`` if encoding
        failed (the error is logged, not raised).
    """
    try:
        buffer = io.BytesIO()
        # A file-like target has no extension to infer the container from,
        # so the WAV format must be named explicitly.
        sf.write(buffer, audio_data, sample_rate, format="WAV")
        return base64.b64encode(buffer.getvalue()).decode('utf-8')
    except Exception as e:
        logger.error(f"Error encoding audio: {e}")
        return None
def decode_audio_from_base64(base64_string):
    """Write base64-encoded audio bytes to a new temporary ``.wav`` file.

    Args:
        base64_string: Base64 text, e.g. as produced by ``encode_audio_to_base64``.

    Returns:
        Path of the temporary file (the caller is responsible for deleting
        it), or ``None`` if decoding or writing failed. A partially written
        file is removed before ``None`` is returned.
    """
    temp_path = None
    try:
        audio_bytes = base64.b64decode(base64_string.encode('utf-8'))
        # delete=False: the file must outlive this function so the caller can
        # read it; the context manager still guarantees the handle is closed.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
            temp_path = temp_file.name
            temp_file.write(audio_bytes)
        return temp_path
    except Exception as e:
        logger.error(f"Error decoding audio: {e}")
        if temp_path is not None:
            try:
                os.unlink(temp_path)
            except OSError:
                pass  # best effort: the file may already be gone
        return None
def load_voice_library():
    """Rebuild the global ``voice_library`` from builtins plus saved voices.

    Starts from a shallow copy of ``BUILTIN_VOICES`` and merges in the
    contents of ``VOICES_DIR/voices.json`` if that file exists. Any error is
    logged and the library is left holding (at least) the builtin voices.
    """
    global voice_library
    voice_library = dict(BUILTIN_VOICES)
    library_file = os.path.join(VOICES_DIR, "voices.json")
    try:
        if not os.path.exists(library_file):
            logger.info("π No existing voice library found, starting fresh")
        else:
            with open(library_file, 'r', encoding='utf-8') as handle:
                stored_voices = json.load(handle)
            voice_library.update(stored_voices)
            logger.info(f"β Loaded {len(stored_voices)} custom voices from persistent storage")

        # Summarize what ended up in the library.
        voice_types = [entry.get("type") for entry in voice_library.values()]
        logger.info(
            f"π Voice Library: {len(voice_library)} total "
            f"({voice_types.count('builtin')} builtin, {voice_types.count('custom')} custom)"
        )
    except Exception as e:
        logger.error(f"β Error loading voice library: {e}")
        logger.info("π Starting with builtin voices only")
def save_voice_library():
    """Persist every non-builtin voice to ``VOICES_DIR/voices.json``.

    The JSON is first written to a sibling ``.tmp`` file and then moved over
    the real file with ``os.replace`` (an atomic rename on POSIX). A crash or
    error mid-write therefore cannot leave a truncated library behind, which
    would otherwise lose every saved custom voice on the next load. Errors are
    logged, not raised.
    """
    # Computed before the try block so the error handler can always report it;
    # otherwise an early failure would raise UnboundLocalError from the handler.
    voices_json_path = os.path.join(VOICES_DIR, "voices.json")
    tmp_path = voices_json_path + ".tmp"
    try:
        # Only save custom voices; builtin ones are re-created from
        # BUILTIN_VOICES on every load.
        custom_voices = {k: v for k, v in voice_library.items() if v.get("type") != "builtin"}
        os.makedirs(os.path.dirname(voices_json_path) or ".", exist_ok=True)
        with open(tmp_path, 'w', encoding='utf-8') as f:
            json.dump(custom_voices, f, ensure_ascii=False, indent=2)
        os.replace(tmp_path, voices_json_path)
        logger.info(f"β Saved {len(custom_voices)} custom voices to persistent storage")
        logger.info(f"π Storage location: {voices_json_path}")
        # Verify the save worked
        if os.path.exists(voices_json_path):
            file_size = os.path.getsize(voices_json_path)
            logger.info(f"π Voice library file size: {file_size} bytes")
    except Exception as e:
        logger.error(f"β Error saving voice library: {e}")
        logger.error(f"π Attempted path: {voices_json_path}")
        # Do not leave a half-written temporary file behind.
        if os.path.exists(tmp_path):
            try:
                os.remove(tmp_path)
            except OSError:
                pass  # best effort cleanup; the error is already logged
def create_voice_from_audio(audio_file, voice_name, voice_description="Custom voice"):
    """Register a new custom voice and persist it to the voice library.

    The reference audio is embedded in the entry as base64 WAV data, so the
    voice survives without any separate audio file.

    Args:
        audio_file: Either a Gradio ``(sample_rate, samples)`` tuple or a path
            readable by ``soundfile.read``.
        voice_name: Display name for the voice.
        voice_description: Free-text description shown in listings.

    Returns:
        ``(voice_id, voice_entry)`` on success, ``(None, None)`` on any
        failure (the error is logged, never raised).
    """
    try:
        voice_id = f"voice_{int(time.time())}_{uuid.uuid4().hex[:8]}"
        # Handle different audio input formats
        if isinstance(audio_file, tuple):
            # Gradio audio format (sample_rate, audio_data)
            sample_rate, audio_data = audio_file
        else:
            # File path - load the audio
            audio_data, sample_rate = sf.read(audio_file)

        # Encode audio to base64 for persistent storage
        audio_base64 = encode_audio_to_base64(audio_data, sample_rate)
        if audio_base64 is None:
            raise ValueError("Failed to encode audio")

        voice_entry = {
            "voice_id": voice_id,
            "name": voice_name,
            "description": voice_description,
            "audio_base64": audio_base64,  # Store audio as base64
            "sample_rate": int(sample_rate),
            "type": "custom",
            # The trailing "Z" declares UTC, so format the UTC time rather
            # than the server's local time.
            "created_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
            "audio_duration": len(audio_data) / sample_rate
        }

        voice_library[voice_id] = voice_entry
        save_voice_library()

        logger.info(f"β Created persistent voice: {voice_name} ({voice_id})")
        logger.info(f"π΅ Audio: {len(audio_data)} samples, {sample_rate}Hz, {voice_entry['audio_duration']:.2f}s")
        return voice_id, voice_entry
    except Exception as e:
        logger.error(f"β Error creating voice: {e}")
        return None, None
def download_audio_from_url(url):
    """Download reference audio from ``url`` into a new temporary file.

    The temporary file keeps the URL's extension when it is a known audio
    type (``.wav``, ``.flac``, ``.mp3``, ``.ogg``), so decoders that look at
    the extension see the right format. Anything else falls back to
    ``.flac``, the format of the builtin voice samples.

    Returns:
        Path of the temporary file (the caller is responsible for deleting
        it), or ``None`` on a non-200 response or any error (logged).
    """
    from urllib.parse import urlparse

    try:
        logger.info(f"π₯ Downloading reference audio from: {url}")
        response = requests.get(url, timeout=30, headers={
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
        })
        if response.status_code != 200:
            logger.error(f"β HTTP {response.status_code} when downloading audio")
            return None

        suffix = os.path.splitext(urlparse(url).path)[1].lower()
        if suffix not in (".wav", ".flac", ".mp3", ".ogg"):
            suffix = ".flac"
        # delete=False: the caller reads the file after this function returns.
        with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as temp_file:
            temp_file.write(response.content)
        logger.info(f"β Audio downloaded to: {temp_file.name}")
        return temp_file.name
    except Exception as e:
        logger.error(f"β Error downloading audio from URL: {e}")
        return None
def get_voice_audio_path(voice_id):
    """Resolve the reference-audio file for ``voice_id``.

    - custom voice with embedded base64 audio: decoded to a new temp file
    - legacy custom voice with ``audio_path``: that stored path, if it exists
    - builtin voice with ``audio_url``: downloaded to a new temp file

    Returns ``None`` for unknown voices or when the audio is unavailable.
    """
    if voice_id not in voice_library:
        return None
    info = voice_library[voice_id]
    kind = info.get("type")

    if kind == "custom":
        if "audio_base64" in info:
            decoded_path = decode_audio_from_base64(info["audio_base64"])
            if not decoded_path:
                logger.warning(f"β οΈ Failed to decode audio for voice {voice_id}")
                return None
            logger.info(f"β Decoded custom voice audio: {info['name']}")
            return decoded_path
        if "audio_path" in info:
            # Legacy entries (pre-base64) point at an audio file on disk.
            legacy_path = info["audio_path"]
            if not os.path.exists(legacy_path):
                logger.warning(f"β οΈ Voice audio file not found: {legacy_path}")
                return None
            return legacy_path
        return None

    if kind == "builtin" and "audio_url" in info:
        return download_audio_from_url(info["audio_url"])
    return None
def load_chatterbox_model():
    """Try multiple ways to load ChatterboxTTS from Resemble AI.

    Three import locations are attempted in order; the first one that both
    imports and constructs a model via ``ChatterboxTTS.from_pretrained(DEVICE)``
    wins:

    1. ``chatterbox.src.chatterbox.tts``
    2. ``chatterbox.tts``
    3. the top-level ``chatterbox`` package (``chatterbox.ChatterboxTTS`` or
       ``chatterbox.tts.ChatterboxTTS``)

    Side effects: on success sets the globals ``MODEL`` and
    ``CHATTERBOX_AVAILABLE = True``. Each failed attempt is logged as a
    warning, and the final failure is logged as errors.

    Returns:
        True if a model was loaded, False if all three attempts failed.
    """
    global MODEL, CHATTERBOX_AVAILABLE
    # Method 1: Try Resemble AI ChatterboxTTS (most likely)
    try:
        from chatterbox.src.chatterbox.tts import ChatterboxTTS
        # Logged as soon as the import succeeds, i.e. before from_pretrained
        # runs, so a "Found" line may still be followed by a failure.
        logger.info("β Found Resemble AI ChatterboxTTS in chatterbox.src.chatterbox.tts")
        MODEL = ChatterboxTTS.from_pretrained(DEVICE)
        CHATTERBOX_AVAILABLE = True
        return True
    except ImportError as e:
        logger.warning(f"Method 1 (Resemble AI standard path) failed: {e}")
    except Exception as e:
        # Non-import errors (e.g. raised by from_pretrained) are also
        # non-fatal: fall through to the next import path.
        logger.warning(f"Method 1 failed with error: {e}")
    # Method 2: Try alternative import path for Resemble AI repo
    try:
        from chatterbox.tts import ChatterboxTTS
        logger.info("β Found ChatterboxTTS in chatterbox.tts")
        MODEL = ChatterboxTTS.from_pretrained(DEVICE)
        CHATTERBOX_AVAILABLE = True
        return True
    except ImportError as e:
        logger.warning(f"Method 2 failed: {e}")
    except Exception as e:
        logger.warning(f"Method 2 failed with error: {e}")
    # Method 3: Try direct chatterbox import
    try:
        import chatterbox
        if hasattr(chatterbox, 'ChatterboxTTS'):
            MODEL = chatterbox.ChatterboxTTS.from_pretrained(DEVICE)
        elif hasattr(chatterbox, 'tts') and hasattr(chatterbox.tts, 'ChatterboxTTS'):
            MODEL = chatterbox.tts.ChatterboxTTS.from_pretrained(DEVICE)
        else:
            # Handled by the ImportError branch below, like a failed import.
            raise ImportError("ChatterboxTTS not found in chatterbox module")
        # Unlike methods 1-2, this success log comes after construction.
        logger.info("β Found ChatterboxTTS via direct import")
        CHATTERBOX_AVAILABLE = True
        return True
    except ImportError as e:
        logger.warning(f"Method 3 failed: {e}")
    except Exception as e:
        logger.warning(f"Method 3 failed with error: {e}")
    # If we get here, the GitHub repo might have a different structure
    logger.error("β Could not load ChatterboxTTS from Resemble AI repository")
    logger.error("π‘ The GitHub repo might have a different structure than expected")
    logger.error("π Repository: https://github.com/resemble-ai/chatterbox.git")
    logger.error("π Check the repo's README for correct import instructions")
    return False
def get_or_load_model():
    """Return the global ``MODEL``, loading it on first use.

    If no real ChatterboxTTS can be loaded, the beep-producing fallback model
    is installed instead, so after the first call ``MODEL`` is normally set.
    """
    global MODEL
    if MODEL is not None:
        return MODEL

    logger.info("Loading ChatterboxTTS model...")
    if not load_chatterbox_model():
        logger.error("β Failed to load ChatterboxTTS - using fallback")
        create_fallback_model()
        return MODEL

    # Move the freshly loaded model to the configured device when supported.
    if hasattr(MODEL, 'to'):
        MODEL.to(DEVICE)
    logger.info("β ChatterboxTTS model loaded successfully")
    return MODEL
def create_fallback_model():
    """Install a placeholder model into the global ``MODEL``.

    Used when no real ChatterboxTTS implementation could be loaded. The
    placeholder exposes the attributes the rest of this file relies on
    (``sr``, ``to``, ``generate``) but returns a 2-second clip of three short
    800 Hz beeps instead of speech, so users can hear that synthesis is fake.
    """
    global MODEL

    class FallbackChatterboxTTS:
        """Stand-in for ChatterboxTTS that produces an audible warning signal."""

        def __init__(self, device="cpu"):
            self.device = device
            self.sr = 24000  # sample rate (Hz) of the generated clip

        @classmethod
        def from_pretrained(cls, device):
            # Alternate constructor mirroring ChatterboxTTS.from_pretrained(device);
            # it must be a classmethod to be callable on the class itself.
            return cls(device)

        def to(self, device):
            self.device = device
            return self

        def generate(self, text, audio_prompt_path=None, exaggeration=0.5,
                     temperature=0.8, cfg_weight=0.5):
            """Return a (1, n_samples) float64 tensor of three beeps.

            Only ``text`` is used (for logging); the voice/tuning arguments
            exist for signature compatibility and are ignored.
            """
            logger.warning("π¨ USING FALLBACK MODEL - Real ChatterboxTTS not found!")
            logger.warning(f"π Text to synthesize: {text[:50]}...")
            duration = 2.0  # Fixed 2 seconds
            t = np.linspace(0, duration, int(self.sr * duration))
            # Three 0.2 s beeps starting at 0.0 s, 0.6 s and 1.2 s.
            beep_freq = 800  # Higher frequency beep
            beep_pattern = np.zeros_like(t)
            for i in range(3):
                start_time = i * 0.6
                end_time = start_time + 0.2
                mask = (t >= start_time) & (t < end_time)
                beep_pattern[mask] = 0.3 * np.sin(2 * np.pi * beep_freq * t[mask])
            return torch.tensor(beep_pattern).unsqueeze(0)

    MODEL = FallbackChatterboxTTS(DEVICE)
def set_seed(seed: int):
    """Seed torch, CUDA (when running on the GPU), ``random`` and NumPy.

    Args:
        seed: The value handed to every random number generator.
    """
    torch.manual_seed(seed)
    if DEVICE == "cuda":
        for seed_cuda in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
            seed_cuda(seed)
    random.seed(seed)
    np.random.seed(seed)
def generate_id():
    """Return a fresh random UUID4 rendered in canonical hyphenated form."""
    return f"{uuid.uuid4()}"
# Load voice library at startup (import time): fills the module-level
# `voice_library` with BUILTIN_VOICES plus any custom voices saved in
# VOICES_DIR/voices.json, before any API handler or UI code reads it.
load_voice_library()
# Pydantic models for API
class TTSRequest(BaseModel):
    """Request body for ``synthesize_speech``.

    The tuning fields are plain (non-Optional) types with defaults. Omitting
    a field uses its default, while an explicit ``null`` is rejected by
    validation. Under the previous Optional types a null reached the model and
    failed later, e.g. ``int(None)`` when reseeding, surfacing as a 500.
    """
    text: str
    voice_id: str = "female_default"
    exaggeration: float = 0.5
    temperature: float = 0.8
    cfg_weight: float = 0.5
    seed: int = 0  # 0 means "do not reseed" (see generate_tts_audio)


class VoiceCreateRequest(BaseModel):
    """Metadata for creating a custom voice."""
    voice_name: str
    voice_description: Optional[str] = "Custom voice"


class VoiceInfo(BaseModel):
    """Public summary of one voice-library entry (no audio payload)."""
    voice_id: str
    name: str
    description: str
    type: str  # "builtin" or "custom"
    created_at: str


class TTSResponse(BaseModel):
    """Result of a synthesis request; ``audio_id`` is used to fetch the WAV."""
    success: bool
    audio_id: Optional[str] = None
    message: str
    sample_rate: Optional[int] = None
    duration: Optional[float] = None  # seconds
# Load model at startup (import time). Any failure is logged and leaves MODEL
# as None, which the API reports as "Model not loaded" / "unhealthy".
try:
    get_or_load_model()
    if not CHATTERBOX_AVAILABLE:
        logger.warning("β οΈ Using fallback model - Upload ChatterboxTTS package for real synthesis")
    else:
        logger.info("β Real ChatterboxTTS model loaded successfully")
except Exception as e:
    logger.error(f"Failed to load any model: {e}")
    MODEL = None
def generate_tts_audio(
    text_input: str,
    voice_id: str,
    exaggeration_input: float,
    temperature_input: float,
    seed_num_input: int,
    cfgw_input: float,
    max_chars: int = 300,
) -> tuple[int, np.ndarray]:
    """Synthesize ``text_input`` using the reference audio of ``voice_id``.

    Args:
        text_input: Text to speak; anything beyond ``max_chars`` is dropped
            (a warning is logged when that happens).
        voice_id: Key into ``voice_library`` selecting the reference voice.
        exaggeration_input: Passed to the model as ``exaggeration``.
        temperature_input: Passed to the model as ``temperature``.
        seed_num_input: Non-zero values reseed all RNGs; 0 leaves them as-is.
        cfgw_input: Passed to the model as ``cfg_weight``.
        max_chars: Maximum number of characters sent to the model
            (default 300, the historical limit).

    Returns:
        ``(sample_rate, samples)``, where samples is the model output with
        its leading dimension squeezed, as a CPU NumPy array.

    Raises:
        RuntimeError: if no model (real or fallback) is available.
        Exception: any model error is logged and re-raised.
    """
    current_model = get_or_load_model()
    if current_model is None:
        raise RuntimeError("No TTS model available")
    if seed_num_input != 0:
        set_seed(int(seed_num_input))

    logger.info(f"π΅ Generating audio for: '{text_input[:50]}...'")
    logger.info(f"π Using voice: {voice_id}")
    if not CHATTERBOX_AVAILABLE:
        logger.warning("π¨ USING FALLBACK - Real ChatterboxTTS not found!")

    temp_audio_file = None
    try:
        # Resolve the reference audio exactly once: every call may download a
        # builtin sample or decode a custom one into a brand-new temp file.
        audio_prompt_path = get_voice_audio_path(voice_id)
        voice_info = voice_library.get(voice_id, {})
        # get_voice_audio_path hands back a fresh temporary file for every
        # voice kind except legacy custom voices, whose stored audio_path is
        # returned as-is and must never be deleted.
        if audio_prompt_path and audio_prompt_path != voice_info.get("audio_path"):
            temp_audio_file = audio_prompt_path

        if audio_prompt_path:
            voice_name = voice_info.get("name", voice_id)
            logger.info(f"β Using voice '{voice_name}' audio: {audio_prompt_path}")
        else:
            logger.warning(f"β οΈ Could not load audio for voice {voice_id}, using default")

        if len(text_input) > max_chars:
            logger.warning(f"Text truncated from {len(text_input)} to {max_chars} characters")

        wav = current_model.generate(
            text_input[:max_chars],
            audio_prompt_path=audio_prompt_path,
            exaggeration=exaggeration_input,
            temperature=temperature_input,
            cfg_weight=cfgw_input,
        )

        if CHATTERBOX_AVAILABLE:
            logger.info("β Real ChatterboxTTS audio generation complete")
        else:
            logger.warning("β οΈ Fallback audio generated - upload ChatterboxTTS for real synthesis")
        # detach()/cpu() keep the conversion safe if the model returns a GPU
        # tensor or one attached to an autograd graph.
        return (current_model.sr, wav.squeeze(0).detach().cpu().numpy())
    except Exception as e:
        logger.error(f"β Audio generation failed: {e}")
        raise
    finally:
        # Remove only the temporary reference file created for this call.
        if temp_audio_file and os.path.exists(temp_audio_file):
            try:
                os.unlink(temp_audio_file)
                logger.info(f"ποΈ Cleaned up temporary file: {temp_audio_file}")
            except OSError as cleanup_error:
                logger.warning(f"Could not remove temporary file {temp_audio_file}: {cleanup_error}")
# FastAPI application serving the JSON/WAV API.
app = FastAPI(
    version="2.0.0",
    title="ChatterboxTTS Voice Manager API",
    description="Advanced text-to-speech with voice cloning and management",
)

# Fully open CORS policy.
# NOTE(review): wildcard origins combined with allow_credentials=True may let
# any site make credentialed cross-origin requests; confirm this is intended.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
async def root():
    """API status endpoint: service metadata plus a map of the main routes.

    NOTE(review): no route decorator is visible on this function; presumably
    it is meant to be served at GET "/" — confirm it is registered on ``app``.
    """
    # One identity check drives both "status" and "model_loaded" so they can
    # never disagree (truthiness would consult the model's __bool__/__len__).
    model_loaded = MODEL is not None
    return {
        "service": "ChatterboxTTS Voice Manager API",
        "version": "2.0.0",
        "status": "operational" if model_loaded else "model_loading",
        "model_loaded": model_loaded,
        "real_chatterbox": CHATTERBOX_AVAILABLE,
        "device": DEVICE,
        "voices_available": len(voice_library),
        "message": "Real ChatterboxTTS loaded" if CHATTERBOX_AVAILABLE else "Using fallback - upload ChatterboxTTS package",
        "endpoints": {
            "synthesize": "/api/tts/synthesize",
            "voices": "/api/voices",
            "create_voice": "/api/voices/create",
            "audio": "/api/audio/{audio_id}",
            "health": "/health"
        }
    }
async def health_check():
    """Health check endpoint (listed as "/health" in root's endpoint map).

    Reports "healthy" whenever a model object is installed, whether it is the
    real ChatterboxTTS or the fallback; ``real_chatterbox`` and ``warning``
    distinguish the two.
    """
    # Same identity check for "status" and "model_loaded" so they always agree.
    model_loaded = MODEL is not None
    return {
        "status": "healthy" if model_loaded else "unhealthy",
        "model_loaded": model_loaded,
        "real_chatterbox": CHATTERBOX_AVAILABLE,
        "device": DEVICE,
        "voices_total": len(voice_library),
        "timestamp": time.time(),
        "warning": None if CHATTERBOX_AVAILABLE else "Using fallback model - upload ChatterboxTTS for production"
    }
async def get_voices():
    """List every voice in the library with per-type counts.

    Entries are summarized as ``VoiceInfo`` (the embedded audio is omitted).
    """
    voices = [
        VoiceInfo(
            voice_id=vid,
            name=entry["name"],
            description=entry["description"],
            type=entry["type"],
            created_at=entry["created_at"],
        )
        for vid, entry in voice_library.items()
    ]
    return {
        "voices": voices,
        "total": len(voices),
        "builtin": sum(1 for v in voices if v.type == "builtin"),
        "custom": sum(1 for v in voices if v.type == "custom"),
    }
async def create_voice_api(
    voice_name: str,
    voice_description: str = "Custom voice",
    audio_file: UploadFile = File(...)
):
    """Create a new custom voice from an uploaded reference recording.

    The upload is spooled to a temporary file, which is always removed
    afterwards, and handed to ``create_voice_from_audio``.

    Raises:
        HTTPException(400): ``voice_name`` is empty or only whitespace.
        HTTPException(500): the audio could not be processed or stored.
    """
    if not voice_name or not voice_name.strip():
        raise HTTPException(status_code=400, detail="Voice name cannot be empty")

    temp_path = None
    try:
        audio_data = await audio_file.read()
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
            temp_path = temp_file.name
            temp_file.write(audio_data)
        voice_id, voice_entry = create_voice_from_audio(
            temp_path,
            voice_name,
            voice_description
        )
    except Exception as e:
        logger.error(f"β Voice creation failed: {e}")
        raise HTTPException(status_code=500, detail=f"Voice creation failed: {str(e)}") from e
    finally:
        # Remove the spooled upload even when processing failed.
        if temp_path and os.path.exists(temp_path):
            try:
                os.unlink(temp_path)
            except OSError as cleanup_error:
                logger.warning(f"Could not remove temporary upload {temp_path}: {cleanup_error}")

    # Raised outside the try block so the client receives this error as-is
    # instead of a re-wrapped "Voice creation failed: 500: ..." message.
    if not voice_id:
        raise HTTPException(status_code=500, detail="Failed to create voice")
    return {
        "success": True,
        "voice_id": voice_id,
        "message": f"Voice '{voice_name}' created successfully",
        "voice_info": voice_entry
    }
async def delete_voice(voice_id: str):
    """Delete a custom voice and persist the updated library.

    Raises:
        HTTPException(404): the voice does not exist.
        HTTPException(400): the voice is builtin.
        HTTPException(500): removing the file or saving the library failed.
    """
    if voice_id not in voice_library:
        raise HTTPException(status_code=404, detail="Voice not found")
    entry = voice_library[voice_id]
    if entry.get("type") == "builtin":
        raise HTTPException(status_code=400, detail="Cannot delete builtin voices")

    try:
        # Legacy voices keep their audio in a separate file on disk.
        if "audio_path" in entry and os.path.exists(entry["audio_path"]):
            os.unlink(entry["audio_path"])

        display_name = entry["name"]
        voice_library.pop(voice_id)
        save_voice_library()

        logger.info(f"β Deleted voice: {display_name} ({voice_id})")
        return {
            "success": True,
            "message": f"Voice '{display_name}' deleted successfully"
        }
    except Exception as e:
        logger.error(f"β Voice deletion failed: {e}")
        raise HTTPException(status_code=500, detail=f"Voice deletion failed: {str(e)}")
async def synthesize_speech(request: TTSRequest):
    """Synthesize speech for ``request.text`` using ``request.voice_id``.

    The WAV is written to ``AUDIO_DIR`` and its metadata is recorded in
    ``audio_cache`` under a new ``audio_id`` that the client uses to download it.

    Raises:
        HTTPException(503): no model is loaded.
        HTTPException(400): empty text, or text longer than 500 characters.
        HTTPException(404): unknown voice.
        HTTPException(500): any other synthesis or storage failure.
    """
    try:
        if MODEL is None:
            raise HTTPException(status_code=503, detail="Model not loaded")
        text = request.text
        voice_id = request.voice_id
        if not text.strip():
            raise HTTPException(status_code=400, detail="Text cannot be empty")
        if len(text) > 500:
            raise HTTPException(status_code=400, detail="Text too long (max 500 characters)")
        if voice_id not in voice_library:
            raise HTTPException(status_code=404, detail=f"Voice '{voice_id}' not found")

        started_at = time.time()
        sample_rate, audio_data = generate_tts_audio(
            text,
            voice_id,
            request.exaggeration,
            request.temperature,
            request.seed,
            request.cfg_weight
        )
        generation_time = time.time() - started_at

        audio_id = generate_id()
        audio_path = os.path.join(AUDIO_DIR, f"{audio_id}.wav")
        sf.write(audio_path, audio_data, sample_rate)

        voice_name = voice_library[voice_id]["name"]
        duration = len(audio_data) / sample_rate
        audio_cache[audio_id] = {
            "path": audio_path,
            "text": text,
            "voice_id": voice_id,
            "voice_name": voice_name,
            "sample_rate": sample_rate,
            "duration": duration,
            "generated_at": time.time(),
            "generation_time": generation_time,
            "real_chatterbox": CHATTERBOX_AVAILABLE
        }

        fallback_note = "" if CHATTERBOX_AVAILABLE else " (using fallback - upload ChatterboxTTS for real synthesis)"
        message = f"Speech synthesized successfully using voice '{voice_name}'{fallback_note}"
        logger.info(f"β Audio saved: {audio_id} ({generation_time:.2f}s) with voice '{voice_name}'")
        return TTSResponse(
            success=True,
            audio_id=audio_id,
            message=message,
            sample_rate=sample_rate,
            duration=duration
        )
    except HTTPException:
        # Deliberate client-facing errors pass through untouched.
        raise
    except Exception as e:
        logger.error(f"β Synthesis failed: {e}")
        raise HTTPException(status_code=500, detail=f"Synthesis failed: {str(e)}")
async def get_audio(audio_id: str):
    """Stream a previously generated clip to the client as a WAV attachment.

    Raises:
        HTTPException(404): unknown ``audio_id``, or its file is missing on disk.
    """
    if audio_id not in audio_cache:
        raise HTTPException(status_code=404, detail="Audio not found")
    audio_path = audio_cache[audio_id]["path"]
    if not os.path.exists(audio_path):
        raise HTTPException(status_code=404, detail="Audio file not found on disk")

    def iterfile(chunk_size: int = 64 * 1024):
        # Read fixed-size binary chunks. Iterating a binary file object
        # directly splits it at arbitrary b"\n" bytes, giving uneven chunks.
        with open(audio_path, "rb") as f:
            while chunk := f.read(chunk_size):
                yield chunk

    return StreamingResponse(
        iterfile(),
        media_type="audio/wav",
        headers={
            "Content-Disposition": f"attachment; filename=tts_{audio_id}.wav"
        }
    )
async def get_audio_info(audio_id: str):
    """Return the cached metadata dict for ``audio_id`` (404 if unknown)."""
    try:
        return audio_cache[audio_id]
    except KeyError:
        raise HTTPException(status_code=404, detail="Audio not found") from None
async def list_audio():
    """Summarize every generated clip in ``audio_cache``.

    Texts longer than 50 characters are shortened to their first 50
    characters followed by "...".
    """
    def _preview(text):
        # Short texts are returned unchanged; long ones are clipped.
        return text if len(text) <= 50 else text[:50] + "..."

    summaries = []
    for clip_id, info in audio_cache.items():
        summaries.append({
            "audio_id": clip_id,
            "text": _preview(info["text"]),
            "voice_name": info.get("voice_name", "Unknown"),
            "duration": info["duration"],
            "generated_at": info["generated_at"],
            "real_chatterbox": info.get("real_chatterbox", False)
        })
    return {
        "audio_files": summaries,
        "total": len(audio_cache)
    }
# Gradio interface | |
def create_gradio_interface(): | |
"""Create Gradio interface with voice management""" | |
def get_voice_choices(): | |
"""Get voice choices for dropdown""" | |
choices = [] | |
for voice_id, voice_info in voice_library.items(): | |
voice_type = "π§" if voice_info["type"] == "builtin" else "π" | |
choices.append((f"{voice_type} {voice_info['name']} - {voice_info['description']}", voice_id)) | |
return choices | |
def refresh_voice_choices(): | |
"""Refresh voice dropdown""" | |
return gr.update(choices=get_voice_choices()) | |
def create_voice_ui(voice_name, voice_description, audio_file): | |
"""Create voice from UI""" | |
try: | |
if not voice_name.strip(): | |
return "β Please enter a voice name", gr.update() | |
if audio_file is None: | |
return "β Please upload an audio file", gr.update() | |
voice_id, voice_entry = create_voice_from_audio( | |
audio_file, | |
voice_name.strip(), | |
voice_description.strip() or "Custom voice" | |
) | |
if voice_id: | |
updated_choices = get_voice_choices() | |
return ( | |
f"β Voice '{voice_name}' created successfully!\n" | |
f"π Voice ID: {voice_id}\n" | |
f"π Audio saved and ready to use\n" | |
f"π Available in voice selection dropdown", | |
gr.update(choices=updated_choices, value=voice_id) | |
) | |
else: | |
return "β Failed to create voice", gr.update() | |
except Exception as e: | |
logger.error(f"UI voice creation failed: {e}") | |
return f"β Voice creation failed: {str(e)}", gr.update() | |
def generate_speech_ui(text, voice_id, exag, temp, seed_val, cfg): | |
"""Generate speech from UI using voice ID""" | |
try: | |
if not text.strip(): | |
return None, "β Please enter some text" | |
if len(text) > 300: | |
return None, "β Text too long (max 300 characters)" | |
if not voice_id or voice_id not in voice_library: | |
return None, "β Please select a valid voice" | |
start_time = time.time() | |
# Generate audio using voice ID | |
sample_rate, audio_data = generate_tts_audio( | |
text, voice_id, exag, temp, int(seed_val), cfg | |
) | |
generation_time = time.time() - start_time | |
duration = len(audio_data) / sample_rate | |
voice_name = voice_library[voice_id]["name"] | |
voice_type = voice_library[voice_id]["type"] | |
if CHATTERBOX_AVAILABLE: | |
status = f"""β Real ChatterboxTTS synthesis completed! | |
π Voice: {voice_name} ({voice_type}) | |
β±οΈ Generation time: {generation_time:.2f}s | |
π΅ Audio duration: {duration:.2f}s | |
π Sample rate: {sample_rate} Hz | |
π Audio samples: {len(audio_data):,} | |
""" | |
else: | |
status = f"""β οΈ Fallback audio generated (beep sound) | |
π¨ This is NOT real speech synthesis! | |
π Voice: {voice_name} ({voice_type}) | |
π¦ Upload ChatterboxTTS package for real synthesis | |
β±οΈ Generation time: {generation_time:.2f}s | |
π΅ Audio duration: {duration:.2f}s | |
π‘ To fix: Upload your ChatterboxTTS files to this Space | |
""" | |
return (sample_rate, audio_data), status | |
except Exception as e: | |
logger.error(f"UI generation failed: {e}") | |
return None, f"β Generation failed: {str(e)}" | |
def delete_voice_ui(voice_id): | |
"""Delete voice from UI""" | |
try: | |
if not voice_id or voice_id not in voice_library: | |
return "β Please select a voice to delete", gr.update() | |
voice_info = voice_library[voice_id] | |
if voice_info.get("type") == "builtin": | |
return "β Cannot delete builtin voices", gr.update() | |
voice_name = voice_info["name"] | |
# Delete legacy audio file if it exists | |
if "audio_path" in voice_info and os.path.exists(voice_info["audio_path"]): | |
os.unlink(voice_info["audio_path"]) | |
# Remove from library | |
del voice_library[voice_id] | |
save_voice_library() | |
updated_choices = get_voice_choices() | |
logger.info(f"β UI: Deleted voice {voice_name} ({voice_id})") | |
return ( | |
f"β Voice '{voice_name}' deleted successfully", | |
gr.update(choices=updated_choices, value=updated_choices[0][1] if updated_choices else None) | |
) | |
except Exception as e: | |
logger.error(f"UI voice deletion failed: {e}") | |
return f"β Voice deletion failed: {str(e)}", gr.update() | |
with gr.Blocks(title="ChatterboxTTS Voice Manager", theme=gr.themes.Soft()) as demo: | |
# Status indicator at the top | |
if CHATTERBOX_AVAILABLE: | |
status_color = "green" | |
status_message = "β Real ChatterboxTTS Loaded - Production Ready!" | |
else: | |
status_color = "orange" | |
status_message = "β οΈ Fallback Mode - Upload ChatterboxTTS Package for Real Synthesis" | |
gr.HTML(f""" | |
<div style="background-color: {status_color}; color: white; padding: 10px; border-radius: 5px; margin-bottom: 20px;"> | |
<h3 style="margin: 0;">{status_message}</h3> | |
</div> | |
""") | |
gr.Markdown(""" | |
# π΅ ChatterboxTTS Voice Manager | |
**Advanced text-to-speech with custom voice cloning and voice library management** | |
""") | |
with gr.Tabs(): | |
# Text-to-Speech Tab | |
with gr.TabItem("π΅ Generate Speech"): | |
with gr.Row(): | |
with gr.Column(): | |
text_input = gr.Textbox( | |
value="Hello, this is ChatterboxTTS with custom voice cloning. I can speak in any voice you train me with!", | |
label="Text to synthesize (max 300 characters)", | |
max_lines=5, | |
placeholder="Enter your text here..." | |
) | |
voice_selector = gr.Dropdown( | |
label="π Select Voice (π§ = builtin, π = custom)", | |
choices=get_voice_choices(), | |
value=list(voice_library.keys())[0] if voice_library else None, | |
interactive=True | |
) | |
with gr.Row(): | |
generate_btn = gr.Button("π΅ Generate Speech", variant="primary") | |
refresh_voices_btn = gr.Button("π Refresh Voices", size="sm") | |
with gr.Row(): | |
exaggeration = gr.Slider( | |
0.25, 2, | |
step=0.05, | |
label="Exaggeration (Controls expressiveness - 0.5 = neutral)", | |
value=0.5 | |
) | |
cfg_weight = gr.Slider( | |
0.2, 1, | |
step=0.05, | |
label="CFG Weight (Controls pace and clarity)", | |
value=0.5 | |
) | |
with gr.Accordion("Advanced Settings", open=False): | |
temperature = gr.Slider( | |
0.05, 5, | |
step=0.05, | |
label="Temperature (Controls randomness)", | |
value=0.8 | |
) | |
seed = gr.Number( | |
value=0, | |
label="Seed (0 = random, set to non-zero for reproducible results)" | |
) | |
with gr.Column(): | |
audio_output = gr.Audio(label="π Generated Speech") | |
status_text = gr.Textbox( | |
label="π Generation Status", | |
interactive=False, | |
lines=8, | |
placeholder="Select a voice and click 'Generate Speech' to start..." | |
) | |
# Voice Management Tab | |
with gr.TabItem("π Voice Library"): | |
with gr.Row(): | |
with gr.Column(): | |
gr.Markdown("### π Available Voices") | |
voices_display = gr.HTML( | |
value=f""" | |
<div style="max-height: 300px; overflow-y: auto; border: 1px solid #ddd; padding: 10px; border-radius: 5px;"> | |
{''.join([f"<p><strong>{voice_info['name']}</strong> ({voice_info['type']})<br><small>{voice_info['description']}</small></p>" for voice_info in voice_library.values()])} | |
</div> | |
""" | |
) | |
gr.Markdown("### ποΈ Delete Voice") | |
delete_voice_selector = gr.Dropdown( | |
label="Select voice to delete", | |
choices=[(f"{info['name']} ({info['type']})", vid) for vid, info in voice_library.items() if info['type'] == 'custom'], | |
value=None | |
) | |
delete_voice_btn = gr.Button("ποΈ Delete Selected Voice", variant="stop") | |
delete_status = gr.Textbox(label="Delete Status", interactive=False) | |
with gr.Column(): | |
gr.Markdown("### β Create New Voice") | |
new_voice_name = gr.Textbox( | |
label="Voice Name", | |
placeholder="e.g., 'John's Voice', 'Narrator Voice'", | |
value="" | |
) | |
new_voice_description = gr.Textbox( | |
label="Voice Description", | |
placeholder="e.g., 'Professional male voice', 'Warm female narrator'", | |
value="" | |
) | |
new_voice_audio = gr.Audio( | |
label="Upload Voice Sample (5-30 seconds of clear speech)", | |
type="numpy" | |
) | |
create_voice_btn = gr.Button("π― Create Voice", variant="primary") | |
create_status = gr.Textbox( | |
label="π Creation Status", | |
interactive=False, | |
lines=6 | |
) | |
# Voice Library Info Tab | |
with gr.TabItem("π Voice Guide"): | |
gr.Markdown(f""" | |
## π Voice Library Management | |
### π Current Library Status | |
- **Total Voices**: {len(voice_library)} | |
- **Builtin Voices**: {len([v for v in voice_library.values() if v['type'] == 'builtin'])} | |
- **Custom Voices**: {len([v for v in voice_library.values() if v['type'] == 'custom'])} | |
### π§ Builtin Voices | |
These are pre-configured voices that come with the system: | |
{chr(10).join([f"- **{voice_info['name']}**: {voice_info['description']}" for voice_info in voice_library.values() if voice_info['type'] == 'builtin'])} | |
### π― Creating Custom Voices | |
#### π Best Practices: | |
1. **Audio Quality**: Use clear, noise-free recordings | |
2. **Duration**: 5-30 seconds of natural speech | |
3. **Content**: Normal conversational speech works best | |
4. **Format**: WAV, MP3, or FLAC files supported | |
5. **Voice Consistency**: Use the same speaker throughout | |
#### π€ Recording Tips: | |
- Record in a quiet environment | |
- Speak naturally and clearly | |
- Avoid background noise | |
- Use a decent microphone if possible | |
- Read a paragraph of normal text | |
#### π Voice Management: | |
- **Create**: Upload audio + provide name and description | |
- **Use**: Select from dropdown in speech generation | |
- **Delete**: Remove custom voices you no longer need | |
- **Persistent**: Custom voices are saved permanently | |
### π Usage Workflow: | |
1. **Upload Voice Sample** β Create custom voice | |
2. **Select Voice** β Choose from library | |
3. **Generate Speech** β Use selected voice for TTS | |
4. **Manage Library** β Add, delete, organize voices | |
### π API Integration: | |
```python | |
# List voices | |
GET /api/voices | |
# Create voice | |
POST /api/voices/create | |
# Generate speech with voice | |
POST /api/tts/synthesize | |
{{ | |
"text": "Hello world", | |
"voice_id": "your_voice_id" | |
}} | |
# Delete voice | |
DELETE /api/voices/voice_id | |
``` | |
### π‘ Pro Tips: | |
- **Voice Naming**: Use descriptive names like "John_Professional" or "Sarah_Narrator" | |
- **Voice Testing**: Generate short test phrases after creating voices | |
- **Voice Backup**: Custom voices are saved to disk automatically | |
- **Voice Sharing**: Voice IDs can be shared via API | |
""") | |
# Event handlers | |
generate_btn.click( | |
fn=generate_speech_ui, | |
inputs=[text_input, voice_selector, exaggeration, temperature, seed, cfg_weight], | |
outputs=[audio_output, status_text] | |
) | |
refresh_voices_btn.click( | |
fn=refresh_voice_choices, | |
outputs=[voice_selector] | |
) | |
create_voice_btn.click( | |
fn=create_voice_ui, | |
inputs=[new_voice_name, new_voice_description, new_voice_audio], | |
outputs=[create_status, voice_selector] | |
) | |
delete_voice_btn.click( | |
fn=delete_voice_ui, | |
inputs=[delete_voice_selector], | |
outputs=[delete_status, voice_selector] | |
) | |
# System info with voice library status | |
model_status = "β Real ChatterboxTTS" if CHATTERBOX_AVAILABLE else "β οΈ Fallback Model (Beep Sounds)" | |
chatterbox_status = "Available" if CHATTERBOX_AVAILABLE else "Missing - Upload Package" | |
gr.Markdown(f""" | |
### π System Status | |
- **Model**: {model_status} | |
- **Device**: {DEVICE} | |
- **ChatterboxTTS**: {chatterbox_status} | |
- **Voice Library**: {len(voice_library)} voices loaded | |
- **Storage**: {"β Persistent (/data)" if VOICES_DIR.startswith("/data") else "β οΈ Temporary"} | |
- **Generated Files**: {len(audio_cache)} | |
{'''### π Production Ready! | |
Your ChatterboxTTS model is loaded with persistent voice management.''' if CHATTERBOX_AVAILABLE else '''### β οΈ Action Required | |
**You're hearing beep sounds because ChatterboxTTS isn't loaded.** | |
Voice management is working with persistent storage, but you need ChatterboxTTS for real synthesis.'''} | |
""") | |
return demo | |
# Main execution
def _model_status_label():
    """Return a human-readable description of which TTS model is loaded.

    Reads the module-level ``CHATTERBOX_AVAILABLE`` flag and ``MODEL`` object.
    """
    if CHATTERBOX_AVAILABLE and MODEL:
        return "β Real ChatterboxTTS Loaded"
    if MODEL:
        return "β οΈ Fallback Model (Upload ChatterboxTTS package for real synthesis)"
    return "β No Model Loaded"


def _start_fastapi_in_background(host="0.0.0.0", port=8000):
    """Serve the module-level FastAPI ``app`` with uvicorn on a daemon thread.

    Args:
        host: Interface uvicorn binds to.
        port: TCP port uvicorn listens on.

    Returns:
        The started ``threading.Thread``. It is a daemon thread, so it does not
        keep the process alive by itself; the caller's main thread must block
        (here, on the Gradio ``launch()`` call).
    """
    import threading

    import uvicorn

    api_thread = threading.Thread(
        target=uvicorn.run,
        args=(app,),
        kwargs={"host": host, "port": port, "log_level": "info"},
        daemon=True,
    )
    api_thread.start()
    return api_thread


if __name__ == "__main__":
    logger.info("π Starting ChatterboxTTS Voice Management Service...")

    # Model / library status summary
    model_status = _model_status_label()
    custom_voice_count = sum(1 for v in voice_library.values() if v['type'] == 'custom')
    logger.info(f"Model Status: {model_status}")
    logger.info(f"Device: {DEVICE}")
    logger.info(f"ChatterboxTTS Available: {CHATTERBOX_AVAILABLE}")
    logger.info(f"Voice Library: {len(voice_library)} voices loaded")
    logger.info(f"Custom Voices: {custom_voice_count}")
    if not CHATTERBOX_AVAILABLE:
        logger.warning("π¨ IMPORTANT: Upload your ChatterboxTTS package to enable real synthesis!")

    # Always start FastAPI in background (both local and HF Spaces).
    # API_PORT lets the port be changed when 8000 is already taken.
    api_port = int(os.getenv("API_PORT", "8000"))
    _start_fastapi_in_background(port=api_port)
    logger.info(f"π FastAPI Server: Starting on port {api_port}")
    logger.info(f"π API Documentation will be available at /docs on port {api_port}")
    logger.info("π API Endpoints:")
    logger.info(" - GET /api/voices")
    logger.info(" - POST /api/voices/create")
    logger.info(" - DELETE /api/voices/{voice_id}")
    logger.info(" - POST /api/tts/synthesize")
    logger.info(" - GET /api/audio/{audio_id}")

    if os.getenv("SPACE_ID"):
        # Running in Hugging Face Spaces
        logger.info("π Running in Hugging Face Spaces")
        # Spaces only route public traffic to the app port (7860), so the
        # separate uvicorn server is reachable only from inside the container.
        logger.warning(
            f"π FastAPI on port {api_port} is internal to the Space container and is "
            "NOT publicly reachable at https://[your-space-name].hf.space"
        )
        launch_kwargs = {"server_name": "0.0.0.0", "server_port": 7860, "show_error": True}
    else:
        # Local development
        logger.info("π» Running in Local Development")
        logger.info(f"π FastAPI: http://localhost:{api_port}")
        logger.info(f"π API Docs: http://localhost:{api_port}/docs")
        logger.info("π΅ Gradio UI: http://localhost:7861")
        # share=True additionally asks Gradio for a public share link.
        launch_kwargs = {"share": True, "server_name": "0.0.0.0", "server_port": 7861}

    # Start Gradio (blocks the main thread while the UI is served)
    demo = create_gradio_interface()
    demo.launch(**launch_kwargs)