# miya3333's picture
# Upload 4 files
# aca07cd verified
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
import torch
import torchaudio
from speechbrain.pretrained import Tacotron2
from speechbrain.pretrained import HIFIGAN
import io
import numpy as np
import tempfile
import os
# ASGI application object; the route decorators in this file register handlers on it.
app = FastAPI(
    title="SpeechBrain TTS API",
    description="Text-to-Speech API using SpeechBrain",
)
class TTSRequest(BaseModel):
    """Request body accepted by the speech synthesis endpoints."""

    # Text to convert to speech.
    text: str

    # Requested sample rate of the returned audio, in Hz (defaults to 22050).
    sample_rate: int = 22050
class TTSService:
    """Lazily loaded SpeechBrain text-to-speech pipeline.

    Text is converted to a mel spectrogram by Tacotron2 and then to a
    waveform by the HiFi-GAN vocoder. Both models are loaded on first use
    and kept on the instance for later calls.
    """

    # Sample rate, in Hz, that synthesize() assumes for the vocoder output
    # before any resampling.
    NATIVE_SAMPLE_RATE = 22050

    def __init__(self):
        """Create the service without loading any model yet."""
        self.tacotron2 = None
        self.hifi_gan = None
        # Device the models are placed on when they are loaded.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def load_models(self):
        """Load whichever of the two models has not been loaded yet."""
        # Forward the detected device so a GPU is actually used when present;
        # without run_opts the selection in __init__ would have no effect.
        run_opts = {"device": self.device}

        if self.tacotron2 is None:
            print("Loading Tacotron2 model...")
            self.tacotron2 = Tacotron2.from_hparams(
                source="speechbrain/tts-tacotron2-ljspeech",
                savedir="tmpdir_tts",
                run_opts=run_opts,
            )

        if self.hifi_gan is None:
            print("Loading HiFi-GAN vocoder...")
            self.hifi_gan = HIFIGAN.from_hparams(
                source="speechbrain/tts-hifigan-ljspeech",
                savedir="tmpdir_vocoder",
                run_opts=run_opts,
            )

    def synthesize(self, text: str, sample_rate: int = 22050):
        """Synthesize speech for *text*.

        Args:
            text: Text to speak.
            sample_rate: Sample rate of the returned audio in Hz. Audio is
                resampled when this differs from ``NATIVE_SAMPLE_RATE``.

        Returns:
            A ``(audio, sample_rate)`` tuple where ``audio`` is a 1-D NumPy
            array of samples and ``sample_rate`` is the requested rate.

        Raises:
            ValueError: If ``sample_rate`` is not positive.
        """
        # Fail fast: a non-positive rate can never be resampled to, so do not
        # pay for model loading and inference first.
        if sample_rate <= 0:
            raise ValueError(
                f"sample_rate must be a positive integer, got {sample_rate}"
            )

        self.load_models()

        mel_output, _mel_length, _alignment = self.tacotron2.encode_text(text)
        waveforms = self.hifi_gan.decode_batch(mel_output)

        # Flatten the single-utterance batch to 1-D. reshape(-1) (unlike a
        # bare squeeze()) cannot collapse to a 0-d array for very short audio.
        audio_np = waveforms.detach().cpu().reshape(-1).numpy()

        if sample_rate != self.NATIVE_SAMPLE_RATE:
            # Imported lazily so librosa is only needed when resampling.
            import librosa

            audio_np = librosa.resample(
                audio_np,
                orig_sr=self.NATIVE_SAMPLE_RATE,
                target_sr=sample_rate,
            )

        return audio_np, sample_rate
# Module-level service instance used by the synthesis endpoints; its models
# are loaded on the first synthesize() call, not here.
tts_service = TTSService()
@app.get("/")
async def root():
    """Return a fixed message confirming that the API is running."""
    greeting = {"message": "SpeechBrain TTS API is running!"}
    return greeting
@app.get("/health")
async def health_check():
    """Return a fixed health status; no model state is inspected."""
    status_payload = {"status": "healthy"}
    return status_payload
# Longest request text, in characters, accepted by the synthesis endpoints.
_MAX_TEXT_LENGTH = 500

# Highest output sample rate accepted from a request. The value comes from
# untrusted input and drives the size of the resampled audio, so it is capped.
_MAX_SAMPLE_RATE = 192000


def _validate_tts_request(request: TTSRequest) -> None:
    """Raise an HTTP 400 error if *request* cannot be synthesized."""
    if not request.text or not request.text.strip():
        raise HTTPException(status_code=400, detail="Text cannot be empty")
    if len(request.text) > _MAX_TEXT_LENGTH:
        raise HTTPException(
            status_code=400,
            detail=f"Text too long (max {_MAX_TEXT_LENGTH} characters)",
        )
    if not 0 < request.sample_rate <= _MAX_SAMPLE_RATE:
        raise HTTPException(
            status_code=400,
            detail=f"Sample rate must be between 1 and {_MAX_SAMPLE_RATE} Hz",
        )


def _render_wav_bytes(text: str, sample_rate: int):
    """Synthesize *text* and return ``(wav_bytes, sample_rate)``."""
    audio_data, sample_rate = tts_service.synthesize(text, sample_rate)
    # torchaudio.save expects (channels, samples); reshape yields one channel.
    waveform = torch.tensor(audio_data).reshape(1, -1)

    # Reserve a path, then close the handle before writing so the file is not
    # held open while torchaudio writes to it by name.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
        wav_path = tmp_file.name
    try:
        torchaudio.save(wav_path, waveform, sample_rate)
        with open(wav_path, "rb") as audio_file:
            return audio_file.read(), sample_rate
    finally:
        # Always remove the temp file, including when saving or reading fails.
        os.unlink(wav_path)


def _wav_bytes_for_request(request: TTSRequest):
    """Validate *request* and return ``(wav_bytes, sample_rate)``.

    Raises:
        HTTPException: 400 for an invalid request, 500 if synthesis fails.
    """
    # Validation runs outside the try block so its 400 responses are not
    # caught below and re-reported as 500 errors.
    _validate_tts_request(request)
    try:
        return _render_wav_bytes(request.text, request.sample_rate)
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Synthesis failed: {str(e)}"
        ) from e


@app.post("/synthesize")
async def synthesize_speech(request: TTSRequest):
    """Synthesize the request text and return it as a WAV attachment.

    Raises:
        HTTPException: 400 for an invalid request, 500 if synthesis fails.
    """
    # NOTE(review): synthesis is a synchronous call inside an async handler,
    # so it likely blocks the event loop while running; confirm whether this
    # should be moved to a worker thread.
    audio_bytes, _sample_rate = _wav_bytes_for_request(request)
    return StreamingResponse(
        io.BytesIO(audio_bytes),
        media_type="audio/wav",
        headers={"Content-Disposition": "attachment; filename=synthesized_audio.wav"},
    )


@app.post("/synthesize_base64")
async def synthesize_speech_base64(request: TTSRequest):
    """Synthesize the request text and return the WAV file base64-encoded.

    Returns:
        A dict with ``audio_base64``, ``sample_rate`` and the original ``text``.

    Raises:
        HTTPException: 400 for an invalid request, 500 if synthesis fails.
    """
    import base64

    audio_bytes, sample_rate = _wav_bytes_for_request(request)
    return {
        "audio_base64": base64.b64encode(audio_bytes).decode("utf-8"),
        "sample_rate": sample_rate,
        "text": request.text,
    }
if __name__ == "__main__":
    # When run as a script, serve the app with uvicorn on all interfaces.
    import uvicorn

    server_options = {"host": "0.0.0.0", "port": 7860}
    uvicorn.run(app, **server_options)