"""FastAPI Text-to-Speech service backed by SpeechBrain (Tacotron2 + HiFi-GAN)."""

import base64
import io
import os
import tempfile

import numpy as np
import torch
import torchaudio
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from speechbrain.pretrained import HIFIGAN, Tacotron2

app = FastAPI(title="SpeechBrain TTS API", description="Text-to-Speech API using SpeechBrain")


class TTSRequest(BaseModel):
    # Text to synthesize; endpoints enforce non-empty and <= 500 characters.
    text: str
    # Desired output rate; audio is resampled when it differs from the
    # models' native 22050 Hz.
    sample_rate: int = 22050


class TTSService:
    """Lazily loads the SpeechBrain TTS models and converts text to audio."""

    def __init__(self):
        self.tacotron2 = None
        self.hifi_gan = None
        # NOTE(review): from_hparams below is not passed run_opts, so the
        # models load with their defaults; self.device is informational only.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def load_models(self):
        """Load Tacotron2 and the HiFi-GAN vocoder once; later calls are no-ops."""
        if self.tacotron2 is None:
            print("Loading Tacotron2 model...")
            self.tacotron2 = Tacotron2.from_hparams(
                source="speechbrain/tts-tacotron2-ljspeech",
                savedir="tmpdir_tts",
            )
        if self.hifi_gan is None:
            print("Loading HiFi-GAN vocoder...")
            self.hifi_gan = HIFIGAN.from_hparams(
                source="speechbrain/tts-hifigan-ljspeech",
                savedir="tmpdir_vocoder",
            )

    def synthesize(self, text: str, sample_rate: int = 22050):
        """Synthesize *text* and return ``(audio_np, sample_rate)``.

        The LJSpeech models emit 22050 Hz audio; any other requested rate is
        produced by resampling with librosa (imported lazily since it is only
        needed on that path).
        """
        self.load_models()
        mel_output, mel_length, alignment = self.tacotron2.encode_text(text)
        waveforms = self.hifi_gan.decode_batch(mel_output)
        audio_np = waveforms.squeeze().cpu().numpy()
        if sample_rate != 22050:
            import librosa

            audio_np = librosa.resample(audio_np, orig_sr=22050, target_sr=sample_rate)
        return audio_np, sample_rate


tts_service = TTSService()


def _validate_text(text: str) -> None:
    """Raise HTTPException(400) for empty or over-long input text.

    Called BEFORE the endpoints' try blocks so these 400s are not re-wrapped
    as 500s by the generic error handler (a bug in the original code).
    """
    if not text or len(text.strip()) == 0:
        raise HTTPException(status_code=400, detail="Text cannot be empty")
    if len(text) > 500:
        raise HTTPException(status_code=400, detail="Text too long (max 500 characters)")


def _synthesize_wav_bytes(text: str, sample_rate: int):
    """Synthesize *text* and return ``(wav_bytes, actual_sample_rate)``.

    Shared by both endpoints so the temp-file round-trip lives in one place.
    The temp file is always removed, even if torchaudio.save or the read
    fails (the original leaked it on error).
    """
    audio_data, actual_rate = tts_service.synthesize(text, sample_rate)
    # Create-and-close first so torchaudio writes to a closed path (also
    # avoids double-open issues on platforms with exclusive file locks).
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
        tmp_path = tmp_file.name
    try:
        torchaudio.save(tmp_path, torch.tensor(audio_data).unsqueeze(0), actual_rate)
        with open(tmp_path, "rb") as audio_file:
            audio_bytes = audio_file.read()
    finally:
        os.unlink(tmp_path)
    return audio_bytes, actual_rate


@app.get("/")
async def root():
    return {"message": "SpeechBrain TTS API is running!"}


@app.get("/health")
async def health_check():
    return {"status": "healthy"}


@app.post("/synthesize")
async def synthesize_speech(request: TTSRequest):
    """Return the synthesized speech as a streamed WAV attachment."""
    _validate_text(request.text)
    try:
        audio_bytes, _ = _synthesize_wav_bytes(request.text, request.sample_rate)
        return StreamingResponse(
            io.BytesIO(audio_bytes),
            media_type="audio/wav",
            headers={"Content-Disposition": "attachment; filename=synthesized_audio.wav"},
        )
    except HTTPException:
        # Preserve deliberate HTTP errors instead of converting them to 500s.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Synthesis failed: {str(e)}")


@app.post("/synthesize_base64")
async def synthesize_speech_base64(request: TTSRequest):
    """Return the synthesized speech as base64-encoded WAV inside a JSON body."""
    _validate_text(request.text)
    try:
        audio_bytes, actual_rate = _synthesize_wav_bytes(request.text, request.sample_rate)
        return {
            "audio_base64": base64.b64encode(audio_bytes).decode("utf-8"),
            "sample_rate": actual_rate,
            "text": request.text,
        }
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Synthesis failed: {str(e)}")


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)