Aittsg / app.py
Athspi's picture
Update app.py
6aa8d7a verified
raw
history blame
3.13 kB
import base64
import os
import tempfile
import wave

import requests
from dotenv import load_dotenv
from fastapi import BackgroundTasks, FastAPI, HTTPException
from fastapi.responses import FileResponse
from pydantic import BaseModel
# Pull settings from a .env file (if any) into the process environment, then
# read the Gemini credential from it.
load_dotenv()
API_KEY = os.environ.get("GEMINI_API_KEY")
if not API_KEY:
    # Fail at import time: the key is baked into BASE_URL below, which every
    # upstream request uses.
    raise RuntimeError("Missing GEMINI_API_KEY in environment")

# generateContent endpoint of the TTS preview model; the key travels as a
# query parameter.
BASE_URL = "".join(
    (
        "https://generativelanguage.googleapis.com/",
        "v1beta/models/gemini-2.5-flash-preview-tts:",
        "generateContent",
        f"?key={API_KEY}",
    )
)

app = FastAPI(title="Gemini TTS JSON API")
def save_wav(path: str, pcm: bytes, channels=1, rate=24000, width=2):
    """Write raw PCM frames to *path* as an uncompressed WAV file.

    Args:
        path: Destination file; created or overwritten.
        pcm: Raw audio frames, written unchanged.
        channels: Number of audio channels recorded in the header.
        rate: Frame rate in Hz recorded in the header.
        width: Sample width in bytes recorded in the header.
    """
    # (nchannels, sampwidth, framerate, nframes, comptype, compname); the last
    # three are the writer's own defaults, so only the first three matter.
    header = (channels, width, rate, 0, "NONE", "not compressed")
    with wave.open(path, "wb") as writer:
        writer.setparams(header)
        writer.writeframes(pcm)
# JSON body accepted by POST /single_tts.
class SingleTTSRequest(BaseModel):
    # Text forwarded to the model as the content to speak.
    prompt: str
    # Forwarded upstream as the prebuilt voice name.
    voice_name: str
# JSON body accepted by POST /multi_tts: one prompt plus two
# (speaker label, voice name) pairs.
class MultiTTSRequest(BaseModel):
    # Text forwarded to the model as the content to speak.
    prompt: str

    # First speaker: label and the prebuilt voice assigned to it.
    speaker1: str
    voice1: str

    # Second speaker: label and the prebuilt voice assigned to it.
    speaker2: str
    voice2: str
# Liveness probe: GET / answers with a fixed JSON status message.
@app.get("/")
def health():
    status_message = "Gemini TTS JSON API up and running!"
    return {"status": status_message}
# (connect, read) timeouts in seconds for calls to the Gemini endpoint, so a
# stalled upstream cannot block a worker indefinitely.
_GEMINI_TIMEOUT = (10, 120)


def _remove_quietly(path: str) -> None:
    """Delete *path*, tolerating the case where it is already gone."""
    try:
        os.remove(path)
    except FileNotFoundError:
        pass


def _synthesize(prompt: str, speech_config: dict, download_name: str) -> FileResponse:
    """Ask Gemini to speak *prompt* and return the audio as a WAV download.

    Args:
        prompt: Text sent to the model as the single content part.
        speech_config: Value placed under ``generationConfig.speechConfig``.
        download_name: File name advertised to the client for the download.

    Returns:
        A ``FileResponse`` streaming a per-request temporary WAV file, which
        is deleted by a background task attached to the response.

    Raises:
        HTTPException: 504 if the upstream call times out; 502 if it cannot
            be reached or its reply carries no decodable audio; the upstream
            status code if Gemini itself answers with a non-200 status.
    """
    payload = {
        "model": "gemini-2.5-flash-preview-tts",
        "contents": [{"parts": [{"text": prompt}]}],
        # The REST body names this field "generationConfig"; "config" is the
        # SDK-side spelling and is not a field of the REST request.
        "generationConfig": {
            "responseModalities": ["AUDIO"],
            "speechConfig": speech_config,
        },
    }

    # Error details below are deliberately generic and the original exception
    # is not chained: its message embeds BASE_URL, which carries the API key.
    try:
        resp = requests.post(BASE_URL, json=payload, timeout=_GEMINI_TIMEOUT)
    except requests.Timeout:
        raise HTTPException(
            status_code=504, detail="Gemini API request timed out"
        ) from None
    except requests.RequestException:
        raise HTTPException(
            status_code=502, detail="Could not reach the Gemini API"
        ) from None

    try:
        body = resp.json()
    except ValueError:
        # Non-JSON reply (e.g. an HTML error page from a proxy).
        body = None

    if resp.status_code != 200:
        detail = body if body is not None else resp.text
        raise HTTPException(status_code=resp.status_code, detail=detail)

    try:
        b64 = body["candidates"][0]["content"]["parts"][0]["inlineData"]["data"]
        pcm = base64.b64decode(b64)
    except (KeyError, IndexError, TypeError, ValueError):
        # Missing/empty candidates, a non-audio first part, or bad base64.
        raise HTTPException(
            status_code=502,
            detail="Gemini API response did not contain audio data",
        ) from None

    # One temp file per request: a shared fixed name would let concurrent
    # requests overwrite each other's audio while it is still being streamed.
    fd, wav_path = tempfile.mkstemp(suffix=".wav")
    os.close(fd)
    try:
        save_wav(wav_path, pcm)
    except Exception:
        _remove_quietly(wav_path)
        raise

    cleanup = BackgroundTasks()
    cleanup.add_task(_remove_quietly, wav_path)
    return FileResponse(
        wav_path,
        media_type="audio/wav",
        filename=download_name,
        background=cleanup,
    )


@app.post("/single_tts")
def single_tts(req: SingleTTSRequest):
    """Synthesize the prompt with one prebuilt voice and return a WAV file."""
    speech_config = {
        "voiceConfig": {
            "prebuiltVoiceConfig": {"voiceName": req.voice_name},
        },
    }
    return _synthesize(req.prompt, speech_config, "single_output.wav")


@app.post("/multi_tts")
def multi_tts(req: MultiTTSRequest):
    """Synthesize a two-speaker prompt and return a WAV file.

    Each speaker label in the request is paired with its prebuilt voice.
    """
    speakers = (
        (req.speaker1, req.voice1),
        (req.speaker2, req.voice2),
    )
    speech_config = {
        "multiSpeakerVoiceConfig": {
            "speakerVoiceConfigs": [
                {
                    "speaker": speaker,
                    "voiceConfig": {"prebuiltVoiceConfig": {"voiceName": voice}},
                }
                for speaker, voice in speakers
            ],
        },
    }
    return _synthesize(req.prompt, speech_config, "multi_output.wav")