Athspi committed on
Commit
6aa8d7a
·
verified ·
1 Parent(s): a71d68c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -35
app.py CHANGED
@@ -2,17 +2,17 @@ import os
2
  import requests
3
  import wave
4
  import base64
5
- from fastapi import FastAPI, Form
6
- from fastapi.responses import FileResponse, JSONResponse
 
7
  from dotenv import load_dotenv
8
 
9
  # Load API key
10
  load_dotenv()
11
  API_KEY = os.getenv("GEMINI_API_KEY")
12
  if not API_KEY:
13
- raise ValueError("Missing GEMINI_API_KEY in .env")
14
 
15
- # REST endpoint (API key as query param)
16
  BASE_URL = (
17
  "https://generativelanguage.googleapis.com/"
18
  "v1beta/models/gemini-2.5-flash-preview-tts:"
@@ -20,34 +20,40 @@ BASE_URL = (
20
  f"?key={API_KEY}"
21
  )
22
 
23
- app = FastAPI(title="Gemini TTS Space")
24
 
25
  def save_wav(path: str, pcm: bytes, channels=1, rate=24000, width=2):
26
- """Write raw PCM bytes to a WAV file."""
27
  with wave.open(path, "wb") as wf:
28
  wf.setnchannels(channels)
29
  wf.setsampwidth(width)
30
  wf.setframerate(rate)
31
  wf.writeframes(pcm)
32
 
 
 
 
 
 
 
 
 
 
 
 
33
  @app.get("/")
34
  def health():
35
- return {"status": "Gemini TTS Space is live!"}
36
 
37
  @app.post("/single_tts")
38
- def single_tts(
39
- prompt: str = Form(...),
40
- voice_name: str = Form(...)
41
- ):
42
- # Build payload with all TTS settings under `config`
43
  payload = {
44
  "model": "gemini-2.5-flash-preview-tts",
45
- "contents": [{"parts": [{"text": prompt}]}],
46
  "config": {
47
  "responseModalities": ["AUDIO"],
48
  "speechConfig": {
49
  "voiceConfig": {
50
- "prebuiltVoiceConfig": {"voiceName": voice_name}
51
  }
52
  }
53
  }
@@ -55,40 +61,31 @@ def single_tts(
55
 
56
  resp = requests.post(BASE_URL, json=payload)
57
  if resp.status_code != 200:
58
- return JSONResponse(status_code=resp.status_code, content=resp.json())
59
 
60
- # Decode and save
61
- data_b64 = resp.json()["candidates"][0]["content"]["parts"][0]["inlineData"]["data"]
62
- pcm = base64.b64decode(data_b64)
63
  out = "single_output.wav"
64
  save_wav(out, pcm)
65
  return FileResponse(out, media_type="audio/wav", filename=out)
66
 
67
  @app.post("/multi_tts")
68
- def multi_tts(
69
- prompt: str = Form(...),
70
- speaker1: str = Form(...), voice1: str = Form(...),
71
- speaker2: str = Form(...), voice2: str = Form(...)
72
- ):
73
  payload = {
74
  "model": "gemini-2.5-flash-preview-tts",
75
- "contents": [{"parts": [{"text": prompt}]}],
76
  "config": {
77
  "responseModalities": ["AUDIO"],
78
  "speechConfig": {
79
  "multiSpeakerVoiceConfig": {
80
  "speakerVoiceConfigs": [
81
  {
82
- "speaker": speaker1,
83
- "voiceConfig": {
84
- "prebuiltVoiceConfig": {"voiceName": voice1}
85
- }
86
  },
87
  {
88
- "speaker": speaker2,
89
- "voiceConfig": {
90
- "prebuiltVoiceConfig": {"voiceName": voice2}
91
- }
92
  }
93
  ]
94
  }
@@ -98,10 +95,10 @@ def multi_tts(
98
 
99
  resp = requests.post(BASE_URL, json=payload)
100
  if resp.status_code != 200:
101
- return JSONResponse(status_code=resp.status_code, content=resp.json())
102
 
103
- data_b64 = resp.json()["candidates"][0]["content"]["parts"][0]["inlineData"]["data"]
104
- pcm = base64.b64decode(data_b64)
105
  out = "multi_output.wav"
106
  save_wav(out, pcm)
107
  return FileResponse(out, media_type="audio/wav", filename=out)
 
2
  import requests
3
  import wave
4
  import base64
5
+ from fastapi import FastAPI, HTTPException
6
+ from fastapi.responses import FileResponse
7
+ from pydantic import BaseModel
8
  from dotenv import load_dotenv
9
 
10
  # Load API key
11
  load_dotenv()
12
  API_KEY = os.getenv("GEMINI_API_KEY")
13
  if not API_KEY:
14
+ raise RuntimeError("Missing GEMINI_API_KEY in environment")
15
 
 
16
  BASE_URL = (
17
  "https://generativelanguage.googleapis.com/"
18
  "v1beta/models/gemini-2.5-flash-preview-tts:"
 
20
  f"?key={API_KEY}"
21
  )
22
 
23
+ app = FastAPI(title="Gemini TTS JSON API")
24
 
25
  def save_wav(path: str, pcm: bytes, channels=1, rate=24000, width=2):
 
26
  with wave.open(path, "wb") as wf:
27
  wf.setnchannels(channels)
28
  wf.setsampwidth(width)
29
  wf.setframerate(rate)
30
  wf.writeframes(pcm)
31
 
32
+ class SingleTTSRequest(BaseModel):
33
+ prompt: str
34
+ voice_name: str
35
+
36
+ class MultiTTSRequest(BaseModel):
37
+ prompt: str
38
+ speaker1: str
39
+ voice1: str
40
+ speaker2: str
41
+ voice2: str
42
+
43
  @app.get("/")
44
  def health():
45
+ return {"status": "Gemini TTS JSON API up and running!"}
46
 
47
  @app.post("/single_tts")
48
+ def single_tts(req: SingleTTSRequest):
 
 
 
 
49
  payload = {
50
  "model": "gemini-2.5-flash-preview-tts",
51
+ "contents": [{"parts": [{"text": req.prompt}]}],
52
  "config": {
53
  "responseModalities": ["AUDIO"],
54
  "speechConfig": {
55
  "voiceConfig": {
56
+ "prebuiltVoiceConfig": {"voiceName": req.voice_name}
57
  }
58
  }
59
  }
 
61
 
62
  resp = requests.post(BASE_URL, json=payload)
63
  if resp.status_code != 200:
64
+ raise HTTPException(status_code=resp.status_code, detail=resp.json())
65
 
66
+ b64 = resp.json()["candidates"][0]["content"]["parts"][0]["inlineData"]["data"]
67
+ pcm = base64.b64decode(b64)
 
68
  out = "single_output.wav"
69
  save_wav(out, pcm)
70
  return FileResponse(out, media_type="audio/wav", filename=out)
71
 
72
  @app.post("/multi_tts")
73
+ def multi_tts(req: MultiTTSRequest):
 
 
 
 
74
  payload = {
75
  "model": "gemini-2.5-flash-preview-tts",
76
+ "contents": [{"parts": [{"text": req.prompt}]}],
77
  "config": {
78
  "responseModalities": ["AUDIO"],
79
  "speechConfig": {
80
  "multiSpeakerVoiceConfig": {
81
  "speakerVoiceConfigs": [
82
  {
83
+ "speaker": req.speaker1,
84
+ "voiceConfig": {"prebuiltVoiceConfig": {"voiceName": req.voice1}}
 
 
85
  },
86
  {
87
+ "speaker": req.speaker2,
88
+ "voiceConfig": {"prebuiltVoiceConfig": {"voiceName": req.voice2}}
 
 
89
  }
90
  ]
91
  }
 
95
 
96
  resp = requests.post(BASE_URL, json=payload)
97
  if resp.status_code != 200:
98
+ raise HTTPException(status_code=resp.status_code, detail=resp.json())
99
 
100
+ b64 = resp.json()["candidates"][0]["content"]["parts"][0]["inlineData"]["data"]
101
+ pcm = base64.b64decode(b64)
102
  out = "multi_output.wav"
103
  save_wav(out, pcm)
104
  return FileResponse(out, media_type="audio/wav", filename=out)