File size: 4,479 Bytes
aca07cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
import torch
import torchaudio
from speechbrain.pretrained import Tacotron2
from speechbrain.pretrained import HIFIGAN
import io
import numpy as np
import tempfile
import os

# ASGI application object; the title and description are passed to FastAPI as metadata.
app = FastAPI(
    title="SpeechBrain TTS API",
    description="Text-to-Speech API using SpeechBrain",
)

class TTSRequest(BaseModel):
    """Request body accepted by the synthesis endpoints."""

    # Text to synthesize; the endpoints reject blank text and text longer
    # than 500 characters.
    text: str
    # Sample rate requested for the returned audio; defaults to 22050.
    sample_rate: int = 22050

class TTSService:
    """Lazily loaded SpeechBrain text-to-speech pipeline.

    Text is encoded with Tacotron2 and the result is decoded to a waveform
    with the HiFi-GAN vocoder. Both models are loaded on first use and then
    kept on the instance for later calls.
    """

    # Sample rate the decoded waveform is treated as having before resampling.
    NATIVE_SAMPLE_RATE = 22050

    def __init__(self):
        """Create the service without loading any model yet."""
        self.tacotron2 = None
        self.hifi_gan = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def load_models(self):
        """Load whichever of the two models is not loaded yet.

        Both models are placed on ``self.device`` so that a detected GPU is
        actually used instead of silently falling back to the default device.
        """
        if self.tacotron2 is None:
            print("Loading Tacotron2 model...")
            self.tacotron2 = Tacotron2.from_hparams(
                source="speechbrain/tts-tacotron2-ljspeech",
                savedir="tmpdir_tts",
                run_opts={"device": self.device},
            )

        if self.hifi_gan is None:
            print("Loading HiFi-GAN vocoder...")
            self.hifi_gan = HIFIGAN.from_hparams(
                source="speechbrain/tts-hifigan-ljspeech",
                savedir="tmpdir_vocoder",
                run_opts={"device": self.device},
            )

    def synthesize(self, text: str, sample_rate: int = 22050):
        """Synthesize ``text`` and return ``(audio, sample_rate)``.

        Args:
            text: Text to convert to speech.
            sample_rate: Desired output sample rate. When it differs from
                ``NATIVE_SAMPLE_RATE`` the audio is resampled with librosa,
                which is imported only in that case.

        Returns:
            A tuple of the waveform as a NumPy array and the sample rate it
            is expressed in.

        Raises:
            ValueError: If ``sample_rate`` is not a positive number.
        """
        # Reject unusable rates before paying for model loading or inference.
        if sample_rate <= 0:
            raise ValueError(f"sample_rate must be positive, got {sample_rate}")

        self.load_models()

        # Only the first element of the encoder result is needed here.
        mel_output, _mel_length, _alignment = self.tacotron2.encode_text(text)
        waveforms = self.hifi_gan.decode_batch(mel_output)

        # Drop singleton dimensions and move to host memory for NumPy.
        audio_np = waveforms.squeeze().cpu().numpy()

        if sample_rate != self.NATIVE_SAMPLE_RATE:
            import librosa

            audio_np = librosa.resample(
                audio_np,
                orig_sr=self.NATIVE_SAMPLE_RATE,
                target_sr=sample_rate,
            )

        return audio_np, sample_rate

# Single module-level service instance used by the synthesis endpoints below.
tts_service = TTSService()

@app.get("/")
async def root():
    """Return a static banner confirming that the API is running."""
    banner = {"message": "SpeechBrain TTS API is running!"}
    return banner

@app.get("/health")
async def health_check():
    """Return a fixed status payload for health probes."""
    payload = {"status": "healthy"}
    return payload

@app.post("/synthesize")
async def synthesize_speech(request: TTSRequest):
    """Synthesize the requested text and return it as a WAV attachment.

    Args:
        request: Text to speak and the desired output sample rate.

    Returns:
        A ``StreamingResponse`` with media type ``audio/wav``.

    Raises:
        HTTPException: 400 when the text is blank or longer than 500
            characters; 500 when synthesis or WAV encoding fails.
    """
    # Validate before entering the try block: HTTPException is an Exception,
    # so raising it inside would be caught below and turned into a 500.
    if not request.text or not request.text.strip():
        raise HTTPException(status_code=400, detail="Text cannot be empty")

    if len(request.text) > 500:
        raise HTTPException(status_code=400, detail="Text too long (max 500 characters)")

    try:
        audio_data, sample_rate = tts_service.synthesize(request.text, request.sample_rate)

        # Only the path is needed; close our own handle before torchaudio
        # writes to it and before the file is removed.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
            tmp_path = tmp_file.name

        try:
            torchaudio.save(
                tmp_path,
                torch.tensor(audio_data).unsqueeze(0),
                sample_rate
            )

            with open(tmp_path, "rb") as audio_file:
                audio_bytes = audio_file.read()
        finally:
            # Remove the file even when saving or reading fails.
            os.unlink(tmp_path)

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Synthesis failed: {str(e)}") from e

    return StreamingResponse(
        io.BytesIO(audio_bytes),
        media_type="audio/wav",
        headers={"Content-Disposition": "attachment; filename=synthesized_audio.wav"}
    )

@app.post("/synthesize_base64")
async def synthesize_speech_base64(request: TTSRequest):
    """Synthesize the requested text and return the WAV as base64 in JSON.

    Args:
        request: Text to speak and the desired output sample rate.

    Returns:
        A dict with ``audio_base64`` (the base64-encoded WAV file),
        ``sample_rate`` and the original ``text``.

    Raises:
        HTTPException: 400 when the text is blank or longer than 500
            characters; 500 when synthesis or WAV encoding fails.
    """
    import base64

    # Validate before entering the try block: HTTPException is an Exception,
    # so raising it inside would be caught below and turned into a 500.
    if not request.text or not request.text.strip():
        raise HTTPException(status_code=400, detail="Text cannot be empty")

    if len(request.text) > 500:
        raise HTTPException(status_code=400, detail="Text too long (max 500 characters)")

    try:
        audio_data, sample_rate = tts_service.synthesize(request.text, request.sample_rate)

        # Only the path is needed; close our own handle before torchaudio
        # writes to it and before the file is removed.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
            tmp_path = tmp_file.name

        try:
            torchaudio.save(
                tmp_path,
                torch.tensor(audio_data).unsqueeze(0),
                sample_rate
            )

            with open(tmp_path, "rb") as audio_file:
                audio_bytes = audio_file.read()
        finally:
            # Remove the file even when saving or reading fails.
            os.unlink(tmp_path)

        audio_base64 = base64.b64encode(audio_bytes).decode('utf-8')

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Synthesis failed: {str(e)}") from e

    return {
        "audio_base64": audio_base64,
        "sample_rate": sample_rate,
        "text": request.text
    }

# Script entry point: serve the app with uvicorn when this file is run directly.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
    )