# tts-api / app.py — author: Avinyaa, commit f5097fd ("new"), 6.26 kB
import contextlib
import logging
import os
import tempfile
import uuid
from typing import Optional

import torch
from fastapi import BackgroundTasks, FastAPI, File, Form, HTTPException, UploadFile
from fastapi.responses import FileResponse
from pydantic import BaseModel
from TTS.api import TTS
# Emit INFO-level records so model loading and fallback decisions are visible.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# FastAPI application; title/description/version populate the OpenAPI metadata.
app = FastAPI(
    title="TTS API",
    description="Text-to-Speech API using XTTS-v2",
    version="1.0.0",
)
class TTSRequest(BaseModel):
    """Request body for the URL-based text-to-speech endpoint."""

    # Text to synthesize.
    text: str
    # Language code for synthesis, e.g. "en" (the default).
    language: str = "en"
class TTSService:
    """Hold a loaded Coqui TTS model and synthesize speech with it.

    On construction the service tries to load a custom XTTS-v2 checkpoint
    from ``model_dir``. If its ``config.json`` is missing, or loading the
    custom model raises, it falls back to the stock multilingual XTTS-v2
    model (loaded by its Coqui model name).

    Attributes:
        device: ``"cuda"`` when ``torch.cuda.is_available()``, else ``"cpu"``.
        tts: The loaded ``TTS`` model, moved to ``device``.
    """

    # Coqui model name used when the custom checkpoint cannot be loaded.
    DEFAULT_MODEL_NAME = "tts_models/multilingual/multi-dataset/xtts_v2"

    def __init__(self, model_dir: str = "XTTS-v2_C3PO/"):
        """Pick a device and load the custom model, falling back to the default.

        Args:
            model_dir: Directory containing the custom checkpoint and its
                ``config.json``. A relative path resolves against the
                process's current working directory.

        Raises:
            Exception: Whatever the TTS library raises if even the default
                model fails to load.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Using device: {self.device}")

        # Both the existence check and the debug listing derive from the same
        # model_dir, so the log always describes the directory actually checked.
        config_path = os.path.join(model_dir, "config.json")

        if not os.path.exists(config_path):
            logger.warning(f"Custom model config not found at {config_path}")
            self._log_model_dir_contents(model_dir)
            self.tts = self._load_default_model()
            return

        try:
            self.tts = TTS(
                model_path=model_dir,
                config_path=config_path,
                progress_bar=False,
                gpu=torch.cuda.is_available(),
            ).to(self.device)
            logger.info("Custom TTS model loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load custom TTS model: {e}")
            self.tts = self._load_default_model()

    @staticmethod
    def _log_model_dir_contents(model_dir: str) -> None:
        """Log the model directory's contents (or absence) to debug a missing config."""
        abs_dir = os.path.abspath(model_dir)
        # isdir (not exists) so a stray file at this path doesn't crash listdir.
        if os.path.isdir(abs_dir):
            logger.info(f"Contents of {abs_dir}: {os.listdir(abs_dir)}")
        else:
            logger.warning(f"Model directory {abs_dir} does not exist")

    def _load_default_model(self):
        """Load the stock XTTS-v2 model onto ``self.device``; log and re-raise on failure."""
        logger.info("Falling back to default XTTS model")
        try:
            tts = TTS(self.DEFAULT_MODEL_NAME).to(self.device)
        except Exception as e:
            logger.error(f"Failed to load default TTS model: {e}")
            raise
        logger.info("Default TTS model loaded successfully")
        return tts

    def generate_speech(self, text: str, speaker_wav_path: str, language: str = "en") -> str:
        """Synthesize ``text`` in the voice of ``speaker_wav_path`` to a temp WAV file.

        Args:
            text: Text to synthesize.
            speaker_wav_path: Path to the reference speaker audio.
            language: Language code forwarded to ``tts_to_file``.

        Returns:
            Path of the generated file in the system temp directory. The
            caller owns the file and is responsible for deleting it.

        Raises:
            HTTPException: Status 500 if synthesis fails. Any partially
                written output file is removed first.
        """
        # A unique name keeps concurrent or repeated calls from clobbering each other.
        output_path = os.path.join(tempfile.gettempdir(), f"output_{uuid.uuid4().hex}.wav")
        try:
            self.tts.tts_to_file(
                text=text,
                file_path=output_path,
                speaker_wav=speaker_wav_path,
                language=language,
            )
        except Exception as e:
            logger.exception(f"Error generating speech: {e}")
            # Don't leak a half-written (or empty) file into the temp directory.
            with contextlib.suppress(OSError):
                os.remove(output_path)
            raise HTTPException(status_code=500, detail=f"Failed to generate speech: {str(e)}") from e
        return output_path
# Build the shared service once at import time; every endpoint uses this instance.
# NOTE(review): application startup blocks until a model (custom or fallback) is loaded.
tts_service: TTSService = TTSService()
@app.get("/")
async def root():
    """Report that the API process is up and serving requests."""
    payload = {"message": "TTS API is running", "status": "healthy"}
    return payload
@app.get("/health")
async def health_check():
    """Return a health status plus the device ("cuda" or "cpu") the model uses."""
    device = tts_service.device
    return {"status": "healthy", "device": device}
@app.post("/tts")
async def text_to_speech(
    text: str = Form(...),
    language: str = Form("en"),
    speaker_file: UploadFile = File(...),
):
    """
    Convert text to speech using a reference speaker voice

    - **text**: The text to convert to speech
    - **language**: Language code (default: "en")
    - **speaker_file**: Audio file containing the reference speaker voice

    Returns the generated WAV as an attachment. The uploaded reference audio
    is stored in a per-request temp file that is removed before responding.
    The generated file is removed after the response has been sent.
    """
    if not text.strip():
        raise HTTPException(status_code=400, detail="Text cannot be empty")
    # content_type is None when the client omits it; treat that as "not audio"
    # instead of crashing with AttributeError.
    if not (speaker_file.content_type or "").startswith("audio/"):
        raise HTTPException(status_code=400, detail="Speaker file must be an audio file")

    speaker_temp_path = None
    try:
        content = await speaker_file.read()
        if not content:
            raise HTTPException(status_code=400, detail="Speaker file is empty")
        # Per-request temp file: a fixed shared path would let concurrent
        # requests overwrite each other's reference voice.
        fd, speaker_temp_path = tempfile.mkstemp(prefix="speaker_", suffix=".wav")
        with os.fdopen(fd, "wb") as buffer:
            buffer.write(content)
        # NOTE(review): generate_speech is synchronous and runs on the event
        # loop thread, blocking other requests (including /health) while it
        # synthesizes. Move it to a worker thread only if the model is
        # confirmed safe to call concurrently.
        output_path = tts_service.generate_speech(text, speaker_temp_path, language)
    except HTTPException:
        # Already carries the intended status/detail; don't re-wrap as a 500.
        raise
    except Exception as e:
        logger.exception(f"Error in TTS endpoint: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        if speaker_temp_path is not None:
            with contextlib.suppress(OSError):
                os.remove(speaker_temp_path)

    # Remove the generated file once it has been streamed to the client.
    cleanup = BackgroundTasks()
    cleanup.add_task(os.remove, output_path)
    # No explicit Content-Disposition header: FileResponse builds
    # "attachment; filename=..." itself, and an explicit header would
    # override it and drop the filename.
    return FileResponse(
        output_path,
        media_type="audio/wav",
        filename=f"tts_output_{uuid.uuid4().hex}.wav",
        background=cleanup,
    )
@app.post("/tts-with-url")
async def text_to_speech_with_url(request: TTSRequest, speaker_wav_url: str):
    """
    Convert text to speech using a reference speaker voice from URL

    - **request**: TTSRequest containing text and language
    - **speaker_wav_url**: URL to the reference speaker audio file

    Not implemented yet: after validating the text, this always responds
    with HTTP 501.
    """
    if not request.text.strip():
        raise HTTPException(status_code=400, detail="Text cannot be empty")
    # TODO: download speaker_wav_url to a temp file and call
    # tts_service.generate_speech. Validate the URL scheme/host first, since
    # fetching arbitrary client-supplied URLs server-side is an SSRF risk.
    # Raised outside any blanket `except Exception`, so the client actually
    # receives 501 rather than a re-wrapped 500.
    raise HTTPException(status_code=501, detail="URL-based speaker input not implemented yet")