|
from fastapi import FastAPI, HTTPException |
|
from fastapi.responses import StreamingResponse |
|
from pydantic import BaseModel |
|
import torch |
|
import torchaudio |
|
from speechbrain.pretrained import Tacotron2 |
|
from speechbrain.pretrained import HIFIGAN |
|
import io |
|
import numpy as np |
|
import tempfile |
|
import os |
|
|
|
# ASGI application object; the route handlers below register on it.
app = FastAPI(
    title="SpeechBrain TTS API",
    description="Text-to-Speech API using SpeechBrain",
)
|
|
|
class TTSRequest(BaseModel):
    """JSON request body shared by the synthesis endpoints."""

    # Text to speak; emptiness and length limits are enforced by the endpoints.
    text: str
    # Requested output sample rate in Hz.
    sample_rate: int = 22_050
|
|
|
class TTSService:
    """Tacotron2 (text -> mel) + HiFi-GAN (mel -> waveform) TTS pipeline.

    Both models are created lazily by ``load_models`` on the first call to
    ``synthesize``, so constructing the service is cheap.
    """

    # Sample rate of the waveform produced by the vocoder; any other requested
    # rate is obtained by resampling from this one.
    NATIVE_SAMPLE_RATE = 22050

    def __init__(self):
        self.tacotron2 = None  # Tacotron2 model, created on first use
        self.hifi_gan = None  # HiFi-GAN vocoder, created on first use
        # Device the models are loaded onto (see load_models).
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def load_models(self):
        """Create whichever of the two models has not been loaded yet.

        Models are built with SpeechBrain's ``from_hparams`` and placed on
        ``self.device``.
        """
        # Without run_opts SpeechBrain uses its default device, which would
        # leave the CUDA detection in __init__ without effect.
        run_opts = {"device": self.device}

        if self.tacotron2 is None:
            print("Loading Tacotron2 model...")
            self.tacotron2 = Tacotron2.from_hparams(
                source="speechbrain/tts-tacotron2-ljspeech",
                savedir="tmpdir_tts",
                run_opts=run_opts,
            )

        if self.hifi_gan is None:
            print("Loading HiFi-GAN vocoder...")
            self.hifi_gan = HIFIGAN.from_hparams(
                source="speechbrain/tts-hifigan-ljspeech",
                savedir="tmpdir_vocoder",
                run_opts=run_opts,
            )

    def synthesize(self, text: str, sample_rate: int = NATIVE_SAMPLE_RATE):
        """Convert ``text`` to a mono waveform.

        Args:
            text: Input text, passed to Tacotron2's ``encode_text``.
            sample_rate: Desired output rate in Hz. Audio is resampled from
                ``NATIVE_SAMPLE_RATE`` when it differs.

        Returns:
            ``(audio, sample_rate)``, where ``audio`` is the squeezed vocoder
            output as a NumPy array.

        Raises:
            ValueError: If ``sample_rate`` is not positive.
        """
        # Fail fast, before possibly loading the models.
        if sample_rate <= 0:
            raise ValueError(f"sample_rate must be positive, got {sample_rate}")

        self.load_models()

        mel_output, _mel_length, _alignment = self.tacotron2.encode_text(text)
        waveforms = self.hifi_gan.decode_batch(mel_output)
        # Move to CPU so .numpy() works when the models run on CUDA.
        waveform = waveforms.squeeze().cpu()

        if sample_rate != self.NATIVE_SAMPLE_RATE:
            # torchaudio is already a hard dependency of this module, unlike
            # librosa, which would only be discovered missing at request time.
            waveform = torchaudio.functional.resample(
                waveform,
                orig_freq=self.NATIVE_SAMPLE_RATE,
                new_freq=sample_rate,
            )

        return waveform.numpy(), sample_rate
|
|
|
# Single module-level service shared by every endpoint; its models are
# created lazily when synthesis is first requested.
tts_service: TTSService = TTSService()
|
|
|
@app.get("/")
async def root():
    """Return a banner confirming the API process is up."""
    banner = {"message": "SpeechBrain TTS API is running!"}
    return banner
|
|
|
@app.get("/health")
async def health_check():
    """Report a static healthy status (does not check model state)."""
    return dict(status="healthy")
|
|
|
# Longest input text, in characters, that the synthesis endpoints accept.
_MAX_TEXT_LENGTH = 500


def _validate_tts_request(request: TTSRequest) -> None:
    """Reject requests that should never reach the model.

    Raises:
        HTTPException: 400 if the text is empty or whitespace-only, longer
            than ``_MAX_TEXT_LENGTH`` characters, or ``sample_rate`` is not
            positive.
    """
    if not request.text.strip():
        raise HTTPException(status_code=400, detail="Text cannot be empty")

    if len(request.text) > _MAX_TEXT_LENGTH:
        raise HTTPException(
            status_code=400,
            detail=f"Text too long (max {_MAX_TEXT_LENGTH} characters)",
        )

    if request.sample_rate <= 0:
        raise HTTPException(
            status_code=400, detail="Sample rate must be a positive integer"
        )


def _synthesize_wav_bytes(request: TTSRequest):
    """Run TTS for ``request`` and encode the result as WAV file bytes.

    Returns:
        ``(wav_bytes, sample_rate)``.

    NOTE(review): inference runs synchronously, so when called from the
    ``async def`` handlers below it blocks the event loop for its duration.
    Consider ``run_in_threadpool`` once concurrent model access is verified
    to be safe.
    """
    audio_data, sample_rate = tts_service.synthesize(
        request.text, request.sample_rate
    )
    # Mono waveform as a 2-D (1, num_samples) tensor. reshape, unlike
    # unsqueeze(0), also handles a 0-d array.
    waveform = torch.as_tensor(audio_data).reshape(1, -1)

    # Close the descriptor before saving so torchaudio can reopen the file by
    # name (an open NamedTemporaryFile cannot be reopened on Windows). The
    # finally guarantees cleanup even when saving or reading fails.
    fd, tmp_path = tempfile.mkstemp(suffix=".wav")
    os.close(fd)
    try:
        torchaudio.save(tmp_path, waveform, sample_rate)
        with open(tmp_path, "rb") as audio_file:
            audio_bytes = audio_file.read()
    finally:
        os.unlink(tmp_path)

    return audio_bytes, sample_rate


@app.post("/synthesize")
async def synthesize_speech(request: TTSRequest):
    """Synthesize ``request.text`` and return it as a downloadable WAV file.

    Raises:
        HTTPException: 400 for invalid input; 500 if synthesis or encoding fails.
    """
    # Validate outside the try so 400 errors are not rewrapped as 500s.
    _validate_tts_request(request)

    try:
        audio_bytes, _ = _synthesize_wav_bytes(request)
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Synthesis failed: {str(e)}"
        ) from e

    return StreamingResponse(
        io.BytesIO(audio_bytes),
        media_type="audio/wav",
        headers={"Content-Disposition": "attachment; filename=synthesized_audio.wav"},
    )


@app.post("/synthesize_base64")
async def synthesize_speech_base64(request: TTSRequest):
    """Synthesize ``request.text`` and return the WAV bytes base64-encoded in JSON.

    Returns:
        A dict with ``audio_base64`` (base64 of the WAV file), ``sample_rate``
        and the original ``text``.

    Raises:
        HTTPException: 400 for invalid input; 500 if synthesis or encoding fails.
    """
    import base64

    # Validate outside the try so 400 errors are not rewrapped as 500s.
    _validate_tts_request(request)

    try:
        audio_bytes, sample_rate = _synthesize_wav_bytes(request)
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Synthesis failed: {str(e)}"
        ) from e

    return {
        "audio_base64": base64.b64encode(audio_bytes).decode('utf-8'),
        "sample_rate": sample_rate,
        "text": request.text,
    }
|
|
|
if __name__ == "__main__":
    # Direct-execution entry point: serve the app on all interfaces, port 7860.
    import uvicorn as _uvicorn

    _uvicorn.run(app, port=7860, host="0.0.0.0")