Athspi committed on
Commit
7d3c0d1
·
verified ·
1 Parent(s): 515f8f3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -40
app.py CHANGED
@@ -1,14 +1,16 @@
1
  from fastapi import FastAPI, HTTPException, Request
2
- from fastapi.responses import FileResponse, JSONResponse
3
- from fastapi.staticfiles import StaticFiles
4
  from google import genai
5
  from google.genai import types
6
  import wave
 
7
  import os
8
- import uuid
9
  from typing import Optional
10
  from pydantic import BaseModel
11
- from pathlib import Path
 
 
 
12
 
13
  app = FastAPI(
14
  title="Google GenAI TTS API",
@@ -18,13 +20,6 @@ app = FastAPI(
18
  redoc_url=None
19
  )
20
 
21
- # Configuration
22
- AUDIO_OUTPUT_DIR = "audio_output"
23
- Path(AUDIO_OUTPUT_DIR).mkdir(exist_ok=True)
24
-
25
- # Mount static files for Hugging Face Spaces
26
- app.mount("/static", StaticFiles(directory="static"), name="static")
27
-
28
  class TTSRequest(BaseModel):
29
  text: str
30
  voice_name: Optional[str] = "Kore"
@@ -40,13 +35,15 @@ def initialize_genai_client():
40
  raise ValueError("GEMINI_API_KEY environment variable not set")
41
  return genai.Client(api_key=api_key)
42
 
43
- def generate_wave_file(filename: str, pcm_data: bytes, channels: int, rate: int, sample_width: int):
44
- """Generate a WAV file from PCM data"""
45
- with wave.open(filename, "wb") as wf:
46
- wf.setnchannels(channels)
47
- wf.setsampwidth(sample_width)
48
- wf.setframerate(rate)
49
- wf.writeframes(pcm_data)
 
 
50
 
51
  @app.post("/api/generate-tts/")
52
  async def generate_tts(request: TTSRequest):
@@ -62,7 +59,7 @@ async def generate_tts(request: TTSRequest):
62
  - sample_width: Sample width in bytes (default: 2)
63
 
64
  Returns:
65
- - JSON with file URL or error message
66
  """
67
  try:
68
  client = initialize_genai_client()
@@ -89,26 +86,20 @@ async def generate_tts(request: TTSRequest):
89
 
90
  audio_data = response.candidates[0].content.parts[0].inline_data.data
91
 
92
- file_name = f"tts_{uuid.uuid4().hex}.wav"
93
- file_path = os.path.join(AUDIO_OUTPUT_DIR, file_name)
94
-
95
- generate_wave_file(
96
- file_path,
97
  audio_data,
98
  channels=request.channels,
99
  rate=request.sample_rate,
100
  sample_width=request.sample_width
101
  )
102
 
103
- # For Hugging Face Spaces, we need to return the URL where the file can be accessed
104
- file_url = f"/static/{file_name}"
105
- os.rename(file_path, f"static/{file_name}")
106
-
107
- return JSONResponse({
108
- "status": "success",
109
- "audio_url": file_url,
110
- "filename": file_name
111
- })
112
 
113
  except Exception as e:
114
  return JSONResponse(
@@ -120,13 +111,9 @@ async def generate_tts(request: TTSRequest):
120
  async def root():
121
  return {"message": "Google GenAI TTS API is running"}
122
 
123
- # Error handler
124
- @app.exception_handler(Exception)
125
- async def generic_exception_handler(request: Request, exc: Exception):
126
- return JSONResponse(
127
- status_code=500,
128
- content={"message": f"An error occurred: {str(exc)}"}
129
- )
130
 
131
  if __name__ == "__main__":
132
  import uvicorn
 
1
  from fastapi import FastAPI, HTTPException, Request
2
+ from fastapi.responses import JSONResponse, StreamingResponse
 
3
  from google import genai
4
  from google.genai import types
5
  import wave
6
+ import io
7
  import os
 
8
  from typing import Optional
9
  from pydantic import BaseModel
10
+ from dotenv import load_dotenv
11
+
12
+ # Load environment variables
13
+ load_dotenv()
14
 
15
  app = FastAPI(
16
  title="Google GenAI TTS API",
 
20
  redoc_url=None
21
  )
22
 
 
 
 
 
 
 
 
23
  class TTSRequest(BaseModel):
24
  text: str
25
  voice_name: Optional[str] = "Kore"
 
35
  raise ValueError("GEMINI_API_KEY environment variable not set")
36
  return genai.Client(api_key=api_key)
37
 
38
+ def generate_wave_bytes(pcm_data: bytes, channels: int, rate: int, sample_width: int) -> bytes:
39
+ """Generate WAV file bytes from PCM data"""
40
+ with io.BytesIO() as wav_buffer:
41
+ with wave.open(wav_buffer, "wb") as wf:
42
+ wf.setnchannels(channels)
43
+ wf.setsampwidth(sample_width)
44
+ wf.setframerate(rate)
45
+ wf.writeframes(pcm_data)
46
+ return wav_buffer.getvalue()
47
 
48
  @app.post("/api/generate-tts/")
49
  async def generate_tts(request: TTSRequest):
 
59
  - sample_width: Sample width in bytes (default: 2)
60
 
61
  Returns:
62
+ - StreamingResponse with the WAV audio file
63
  """
64
  try:
65
  client = initialize_genai_client()
 
86
 
87
  audio_data = response.candidates[0].content.parts[0].inline_data.data
88
 
89
+ wav_bytes = generate_wave_bytes(
 
 
 
 
90
  audio_data,
91
  channels=request.channels,
92
  rate=request.sample_rate,
93
  sample_width=request.sample_width
94
  )
95
 
96
+ return StreamingResponse(
97
+ io.BytesIO(wav_bytes),
98
+ media_type="audio/wav",
99
+ headers={
100
+ "Content-Disposition": f"attachment; filename=generated_audio.wav"
101
+ }
102
+ )
 
 
103
 
104
  except Exception as e:
105
  return JSONResponse(
 
111
  async def root():
112
  return {"message": "Google GenAI TTS API is running"}
113
 
114
+ @app.get("/health")
115
+ async def health_check():
116
+ return {"status": "healthy"}
 
 
 
 
117
 
118
  if __name__ == "__main__":
119
  import uvicorn