# Spaces: Running
# NOTE(review): the line above appears to be page-header residue from the hosting UI, not program text.
from fastapi import FastAPI, HTTPException, UploadFile, File, Form | |
from fastapi.responses import FileResponse | |
from pydantic import BaseModel | |
from TTS.api import TTS | |
import os | |
import tempfile | |
import uuid | |
import torch | |
from typing import Optional | |
import logging | |
# Module-wide logging: INFO and above, with a logger named after this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# The ASGI application object; title/description/version are its metadata.
app = FastAPI(
    title="TTS API",
    description="Text-to-Speech API using XTTS-v2",
    version="1.0.0",
)
class TTSRequest(BaseModel):
    """Request body carrying the text to synthesize and its language code."""

    # Text to convert to speech (required, no default).
    text: str
    # Language code handed to the TTS model; "en" when omitted.
    language: str = "en"
class TTSService:
    """Load an XTTS model once and synthesize speech with it.

    Construction first tries the custom checkpoint under ``XTTS-v2_C3PO/``.
    If its ``config.json`` is missing, or loading the custom model raises,
    the stock ``tts_models/multilingual/multi-dataset/xtts_v2`` model is
    loaded instead. A failure to load the stock model is logged and re-raised.

    Attributes:
        device: ``"cuda"`` when ``torch.cuda.is_available()``, else ``"cpu"``.
        tts: The loaded ``TTS`` instance used by :meth:`generate_speech`.
    """

    # Both custom-model paths are relative to the process working directory.
    _CUSTOM_MODEL_DIR = "XTTS-v2_C3PO/"
    _CUSTOM_CONFIG_PATH = "XTTS-v2_C3PO/config.json"
    _DEFAULT_MODEL_NAME = "tts_models/multilingual/multi-dataset/xtts_v2"

    def __init__(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Using device: {self.device}")

        # Without a config file the custom model cannot be loaded at all, so
        # log what is on disk and go straight to the stock model.
        if not os.path.exists(self._CUSTOM_CONFIG_PATH):
            logger.warning(f"Custom model config not found at {self._CUSTOM_CONFIG_PATH}")
            self._log_model_dir_contents()
            self.tts = self._load_default_model()
            return

        try:
            self.tts = TTS(
                model_path=self._CUSTOM_MODEL_DIR,
                config_path=self._CUSTOM_CONFIG_PATH,
                progress_bar=False,
                gpu=torch.cuda.is_available(),
            ).to(self.device)
            logger.info("Custom TTS model loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load custom TTS model: {e}")
            self.tts = self._load_default_model()

    def _log_model_dir_contents(self):
        """Log the contents of the custom model directory for debugging."""
        # Resolve the same directory the config check looked in, so the
        # diagnostic is accurate whatever the working directory is.
        model_dir = os.path.abspath(self._CUSTOM_MODEL_DIR)
        if os.path.exists(model_dir):
            logger.info(f"Contents of {model_dir}: {os.listdir(model_dir)}")
        else:
            logger.warning(f"Model directory {model_dir} does not exist")

    def _load_default_model(self):
        """Load and return the stock XTTS-v2 model on ``self.device``.

        Raises:
            Exception: Whatever ``TTS(...)`` or ``.to(...)`` raised, after it
                has been logged.
        """
        logger.info("Falling back to default XTTS model")
        try:
            model = TTS(self._DEFAULT_MODEL_NAME).to(self.device)
        except Exception as e:
            logger.error(f"Failed to load default TTS model: {e}")
            raise
        logger.info("Default TTS model loaded successfully")
        return model

    def generate_speech(self, text: str, speaker_wav_path: str, language: str = "en") -> str:
        """Generate speech and return the path to the output file.

        Args:
            text: Text to synthesize.
            speaker_wav_path: Path to the reference speaker audio file.
            language: Language code passed to the model (default ``"en"``).

        Returns:
            Path of a newly written, uniquely named ``.wav`` file in the
            system temp directory. The caller owns the file and is
            responsible for deleting it.

        Raises:
            HTTPException: Status 500 when synthesis fails for any reason.
        """
        # Unique name per call so concurrent requests never share an output.
        output_filename = f"output_{uuid.uuid4().hex}.wav"
        output_path = os.path.join(tempfile.gettempdir(), output_filename)

        try:
            self.tts.tts_to_file(
                text=text,
                file_path=output_path,
                speaker_wav=speaker_wav_path,
                language=language,
            )
        except Exception as e:
            logger.error(f"Error generating speech: {e}")
            # Do not leave a partially written file behind on failure.
            if os.path.exists(output_path):
                os.remove(output_path)
            raise HTTPException(
                status_code=500, detail=f"Failed to generate speech: {str(e)}"
            ) from e
        return output_path
# Shared service instance, built once at import time; constructing it loads
# the TTS model, so importing this module is slow and can raise.
tts_service: TTSService = TTSService()
async def root():
    """Return a fixed payload stating that the API is running."""
    # NOTE(review): no route decorator is visible in this view — confirm this
    # coroutine is actually registered on `app`.
    payload = dict(message="TTS API is running", status="healthy")
    return payload
async def health_check():
    """Return a health payload that includes the device the TTS service uses."""
    active_device = tts_service.device
    return {"status": "healthy", "device": active_device}
async def text_to_speech(
    text: str = Form(...),
    language: str = Form("en"),
    speaker_file: UploadFile = File(...)
):
    """
    Convert text to speech using a reference speaker voice
    - **text**: The text to convert to speech
    - **language**: Language code (default: "en")
    - **speaker_file**: Audio file containing the reference speaker voice
    """
    # Local import keeps this fix self-contained; fastapi is already a
    # dependency of this module.
    from fastapi import BackgroundTasks

    if not text.strip():
        raise HTTPException(status_code=400, detail="Text cannot be empty")

    # Validate file type. content_type is None when the client sends no
    # Content-Type for the part; treat that as "not audio" rather than
    # crashing on None.startswith.
    if not (speaker_file.content_type or "").startswith("audio/"):
        raise HTTPException(status_code=400, detail="Speaker file must be an audio file")

    # One uniquely named temp file per request: concurrent requests cannot
    # overwrite each other's reference audio, and the write does not depend
    # on the custom model directory existing.
    speaker_temp_path = os.path.join(
        tempfile.gettempdir(), f"speaker_{uuid.uuid4().hex}.wav"
    )

    try:
        # Save uploaded speaker file temporarily
        with open(speaker_temp_path, "wb") as buffer:
            buffer.write(await speaker_file.read())

        # NOTE(review): generate_speech is a synchronous call and blocks the
        # event loop while it runs — consider offloading it to a worker
        # thread once the model's thread-safety is confirmed.
        output_path = tts_service.generate_speech(text, speaker_temp_path, language)

        # Delete the generated audio once the response has been sent so the
        # temp directory does not grow with every request.
        cleanup = BackgroundTasks()
        cleanup.add_task(os.remove, output_path)

        # `filename` already yields `Content-Disposition: attachment;
        # filename="..."`. Passing an explicit Content-Disposition header
        # would take precedence and drop the filename, so none is set here.
        return FileResponse(
            output_path,
            media_type="audio/wav",
            filename=f"tts_output_{uuid.uuid4().hex}.wav",
            background=cleanup,
        )
    except HTTPException:
        # Already a well-formed HTTP error (e.g. from generate_speech);
        # re-wrapping it would mangle its detail message.
        raise
    except Exception as e:
        logger.error(f"Error in TTS endpoint: {e}")
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # The reference audio is only needed while synthesis runs; remove it
        # on success and on failure alike.
        if os.path.exists(speaker_temp_path):
            os.remove(speaker_temp_path)
async def text_to_speech_with_url(request: TTSRequest, speaker_wav_url: str):
    """
    Convert text to speech using a reference speaker voice from URL
    - **request**: TTSRequest containing text and language
    - **speaker_wav_url**: URL to the reference speaker audio file
    """
    if not request.text.strip():
        raise HTTPException(status_code=400, detail="Text cannot be empty")

    # Downloading the reference audio from `speaker_wav_url` (with URL
    # validation) is not implemented yet. Raise the 501 directly: inside a
    # `try` with a blanket `except Exception` it would be caught and
    # reported to the client as a 500 instead.
    raise HTTPException(status_code=501, detail="URL-based speaker input not implemented yet")