# medvisit / app.py
# Last change by siyah1: "Update app.py" (commit 266ffce, verified)
import asyncio
import base64
import json
import os
import pathlib
from typing import AsyncGenerator, Literal
import gradio as gr
import numpy as np
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.responses import HTMLResponse
from fastrtc import (
AsyncStreamHandler,
Stream,
get_cloudflare_turn_credentials_async,
wait_for_item,
)
from google import genai
from google.genai.types import (
LiveConnectConfig,
PrebuiltVoiceConfig,
SpeechConfig,
VoiceConfig,
)
from gradio.utils import get_space
from pydantic import BaseModel
# Populate os.environ from a .env file via python-dotenv; GEMINI_API_KEY is
# read from the environment further down.
load_dotenv()
# Folder holding this script; index.html is resolved relative to it.
current_dir = pathlib.Path(__file__).parent
def encode_audio(data: np.ndarray) -> str:
    """Return the raw buffer of ``data`` as base64 text.

    The bytes are taken exactly as the array stores them (its own dtype and
    byte order, C order), so the caller must pass audio already in the format
    the server expects.
    """
    raw_bytes = data.tobytes()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("UTF-8")
class GeminiHandler(AsyncStreamHandler):
    """Relay audio between a FastRTC connection and a Gemini Live session.

    ``receive`` base64-encodes incoming frames onto ``input_queue``;
    ``start_up`` streams that queue to Gemini and pushes the returned int16
    PCM chunks onto ``output_queue``, which ``emit`` drains.
    """

    # Voice used in phone mode, and when the client sends an empty voice name.
    DEFAULT_VOICE = "Puck"

    def __init__(
        self,
        expected_layout: Literal["mono"] = "mono",
        output_sample_rate: int = 24000,
    ) -> None:
        """Create a mono handler whose input is declared at 16000 Hz.

        Args:
            expected_layout: channel layout passed to the base handler.
            output_sample_rate: sample rate attached to every emitted chunk.
        """
        super().__init__(
            expected_layout,
            output_sample_rate,
            input_sample_rate=16000,
        )
        # Encoded microphone chunks waiting to be sent to Gemini.
        self.input_queue: asyncio.Queue = asyncio.Queue()
        # (sample_rate, int16 ndarray) tuples waiting to be emitted.
        self.output_queue: asyncio.Queue = asyncio.Queue()
        # Set by shutdown(); ends the outgoing audio generator in stream().
        self.quit: asyncio.Event = asyncio.Event()

    def copy(self) -> "GeminiHandler":
        """Return a fresh handler with the same output sample rate."""
        return GeminiHandler(
            expected_layout="mono",
            output_sample_rate=self.output_sample_rate,
        )

    async def start_up(self) -> None:
        """Open a Gemini Live session and forward its audio to ``output_queue``.

        Outside phone mode this first waits for the client's inputs and takes
        the selected voice from ``latest_args[1]``, falling back to
        ``DEFAULT_VOICE`` when that value is empty.

        Raises:
            ValueError: if ``GEMINI_API_KEY`` is unset or empty.
        """
        if self.phone_mode:
            voice_name = self.DEFAULT_VOICE
        else:
            await self.wait_for_args()
            # Index 1 holds the voice; the API key is never taken from the client.
            voice_name = self.latest_args[1] or self.DEFAULT_VOICE
        # The key always comes from the server environment.
        api_key = os.getenv("GEMINI_API_KEY")
        if not api_key:
            raise ValueError("GEMINI_API_KEY environment variable is not set")
        client = genai.Client(
            api_key=api_key,
            http_options={"api_version": "v1alpha"},
        )
        config = LiveConnectConfig(
            response_modalities=["AUDIO"],  # type: ignore
            speech_config=SpeechConfig(
                voice_config=VoiceConfig(
                    prebuilt_voice_config=PrebuiltVoiceConfig(
                        voice_name=voice_name,
                    )
                )
            ),
        )
        async with client.aio.live.connect(
            model="gemini-2.0-flash-exp", config=config
        ) as session:
            async for audio in session.start_stream(
                stream=self.stream(), mime_type="audio/pcm"
            ):
                # Skip messages that carry no audio payload.
                if audio.data:
                    # Model output is interpreted as 16-bit PCM samples.
                    array = np.frombuffer(audio.data, dtype=np.int16)
                    self.output_queue.put_nowait((self.output_sample_rate, array))

    async def stream(self) -> AsyncGenerator[str, None]:
        """Yield queued microphone chunks until ``shutdown`` is called.

        Items are the base64 strings produced by ``receive``. The 0.1 s
        ``wait_for`` timeout lets the loop re-check ``quit`` while the queue
        is idle.
        """
        # NOTE(review): start_stream is handed base64 text here, not raw bytes;
        # confirm this matches what the installed google-genai version expects.
        while not self.quit.is_set():
            try:
                audio = await asyncio.wait_for(self.input_queue.get(), 0.1)
                yield audio
            except (asyncio.TimeoutError, TimeoutError):
                pass

    async def receive(self, frame: tuple[int, np.ndarray]) -> None:
        """Encode one incoming ``(sample_rate, samples)`` frame and queue it.

        The sample rate is ignored; size-1 dimensions are squeezed away before
        encoding.
        """
        _, array = frame
        array = array.squeeze()
        audio_message = encode_audio(array)
        self.input_queue.put_nowait(audio_message)

    async def emit(self) -> tuple[int, np.ndarray] | None:
        """Return the next ``(sample_rate, samples)`` tuple from ``output_queue``.

        Delegates to fastrtc's ``wait_for_item``; presumably yields None when
        nothing arrives in time (confirm against the fastrtc docs).
        """
        return await wait_for_item(self.output_queue)

    def shutdown(self) -> None:
        """Signal ``stream`` to stop yielding audio to Gemini."""
        self.quit.set()
# Fail fast at import: the Gemini key is only ever read from the environment,
# so the app is not built at all without it.
if not os.getenv("GEMINI_API_KEY"):
    raise RuntimeError("GEMINI_API_KEY environment variable is required but not set")

# When gradio's get_space() is truthy, enable Cloudflare TURN credentials and
# the concurrency/time limits; otherwise all three stay None.
_on_space = get_space()
stream = Stream(
    modality="audio",
    mode="send-receive",
    handler=GeminiHandler(),
    rtc_configuration=(get_cloudflare_turn_credentials_async if _on_space else None),
    concurrency_limit=(5 if _on_space else None),
    time_limit=(90 if _on_space else None),
    # The voice picker is the only extra input; no API key is collected here.
    additional_inputs=[
        gr.Dropdown(
            label="Voice",
            choices=["Puck", "Charon", "Kore", "Fenrir", "Aoede"],
            value="Puck",
        ),
    ],
)
# Request body for POST /input_hook: identifies the WebRTC session and carries
# the voice the client selected. No API key is accepted from clients; the
# server reads GEMINI_API_KEY from its environment instead.
class InputData(BaseModel):
    webrtc_id: str  # id passed to stream.set_input to target the session
    voice_name: str  # prebuilt Gemini voice name, e.g. "Puck"
# ASGI application served by uvicorn (see the __main__ block). stream.mount
# presumably registers the stream's WebRTC endpoints on it; confirm in fastrtc docs.
app = FastAPI()
stream.mount(app)
@app.post("/input_hook")
async def _(body: InputData):
    # Hand the client's chosen voice to the session named by body.webrtc_id;
    # GeminiHandler.start_up presumably reads it back as latest_args[1].
    # Responds with HTTP 500 when GEMINI_API_KEY is missing.
    if os.getenv("GEMINI_API_KEY"):
        stream.set_input(body.webrtc_id, body.voice_name)
        return {"status": "ok"}
    raise HTTPException(
        status_code=500,
        detail="GEMINI_API_KEY environment variable is not set",
    )
@app.get("/")
async def index():
    # Serve index.html with the RTC configuration substituted for the
    # __RTC_CONFIGURATION__ placeholder, or an HTML error page (HTTP 500)
    # when GEMINI_API_KEY is missing.
    if not os.getenv("GEMINI_API_KEY"):
        return HTMLResponse(
            content="""
<html>
<body style="font-family: Arial, sans-serif; text-align: center; padding: 50px;">
<h1>Configuration Error</h1>
<p>GEMINI_API_KEY environment variable is not set.</p>
<p>Please set your Gemini API key as an environment variable and restart the application.</p>
<p>Example: <code>export GEMINI_API_KEY=your_api_key_here</code></p>
</body>
</html>
""",
            status_code=500,
        )
    # TURN credentials are only requested when get_space() is truthy. On
    # failure rtc_config simply stays None and is serialized as JSON null.
    rtc_config = None
    if get_space():
        try:
            rtc_config = await get_cloudflare_turn_credentials_async()
        except Exception as e:
            # Best-effort: serve the page without TURN rather than failing it.
            print(f"Warning: Failed to get Cloudflare TURN credentials: {e}")
    # Decode explicitly as UTF-8 so the result does not depend on the host locale.
    html_content = (current_dir / "index.html").read_text(encoding="utf-8")
    html_content = html_content.replace("__RTC_CONFIGURATION__", json.dumps(rtc_config))
    return HTMLResponse(content=html_content)
if __name__ == "__main__":
    # MODE selects the front end: "UI" launches the Gradio UI, "PHONE" starts
    # fastphone, anything else (including unset) serves the FastAPI app.
    # `os` is already imported at module level.
    port = 7860
    if (mode := os.getenv("MODE")) == "UI":
        stream.ui.launch(server_port=port)
    elif mode == "PHONE":
        stream.fastphone(host="0.0.0.0", port=port)
    else:
        # Imported lazily so uvicorn is only needed for this mode.
        import uvicorn

        uvicorn.run(app, host="0.0.0.0", port=port)