"""Voice chat server bridging a fastrtc audio ``Stream`` and the Gemini live API.

The Gemini API key is always read from the ``GEMINI_API_KEY`` environment
variable; clients only choose the voice.
"""

import asyncio
import base64
import json
import os
import pathlib
from typing import AsyncGenerator, Literal

import gradio as gr
import numpy as np
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.responses import HTMLResponse
from fastrtc import (
    AsyncStreamHandler,
    Stream,
    get_cloudflare_turn_credentials_async,
    wait_for_item,
)
from google import genai
from google.genai.types import (
    LiveConnectConfig,
    PrebuiltVoiceConfig,
    SpeechConfig,
    VoiceConfig,
)
from gradio.utils import get_space
from pydantic import BaseModel

current_dir = pathlib.Path(__file__).parent

load_dotenv()


def encode_audio(data: np.ndarray) -> str:
    """Return the raw bytes of *data* as a base64 string for sending to the server."""
    return base64.b64encode(data.tobytes()).decode("UTF-8")


class GeminiHandler(AsyncStreamHandler):
    """Stream handler that relays audio to and from a Gemini live session."""

    def __init__(
        self,
        expected_layout: Literal["mono"] = "mono",
        output_sample_rate: int = 24000,
    ) -> None:
        """Initialise the handler's queues and stop flag.

        Args:
            expected_layout: Channel layout forwarded to ``AsyncStreamHandler``.
            output_sample_rate: Sample rate attached to every emitted chunk.
        """
        super().__init__(
            expected_layout,
            output_sample_rate,
            input_sample_rate=16000,
        )
        # Base64-encoded input audio waiting to be uploaded by ``stream()``.
        self.input_queue: asyncio.Queue = asyncio.Queue()
        # ``(sample_rate, int16 array)`` tuples waiting to be returned by ``emit()``.
        self.output_queue: asyncio.Queue = asyncio.Queue()
        # Set by ``shutdown()`` to end the upload generator.
        self.quit: asyncio.Event = asyncio.Event()

    def copy(self) -> "GeminiHandler":
        """Return a new handler configured with this handler's output sample rate."""
        return GeminiHandler(
            expected_layout="mono",
            output_sample_rate=self.output_sample_rate,
        )

    async def start_up(self):
        """Connect to Gemini and queue every audio chunk it sends back.

        Raises:
            ValueError: If ``GEMINI_API_KEY`` is not set in the environment.
        """
        if not self.phone_mode:
            await self.wait_for_args()
            # Only the voice name is taken from the inputs, never an API key.
            # NOTE(review): index 0 is presumably a placeholder supplied by
            # fastrtc ahead of the additional inputs -- confirm.
            voice_name = self.latest_args[1]
        else:
            voice_name = "Puck"

        # The API key always comes from the environment.
        api_key = os.getenv("GEMINI_API_KEY")
        if not api_key:
            raise ValueError("GEMINI_API_KEY environment variable is not set")

        client = genai.Client(
            api_key=api_key,
            http_options={"api_version": "v1alpha"},
        )
        config = LiveConnectConfig(
            response_modalities=["AUDIO"],  # type: ignore
            speech_config=SpeechConfig(
                voice_config=VoiceConfig(
                    prebuilt_voice_config=PrebuiltVoiceConfig(
                        voice_name=voice_name,
                    )
                )
            ),
        )
        async with client.aio.live.connect(
            model="gemini-2.0-flash-exp", config=config
        ) as session:
            async for audio in session.start_stream(
                stream=self.stream(), mime_type="audio/pcm"
            ):
                # Skip responses that carry no audio payload.
                if audio.data:
                    array = np.frombuffer(audio.data, dtype=np.int16)
                    self.output_queue.put_nowait((self.output_sample_rate, array))

    async def stream(self) -> AsyncGenerator[bytes, None]:
        """Yield queued input audio until ``shutdown()`` sets the quit flag."""
        while not self.quit.is_set():
            try:
                # Short timeout so the quit flag is re-checked even when no
                # input audio is arriving.
                audio = await asyncio.wait_for(self.input_queue.get(), 0.1)
                yield audio
            except (asyncio.TimeoutError, TimeoutError):
                pass

    async def receive(self, frame: tuple[int, np.ndarray]) -> None:
        """Base64-encode an incoming frame and queue it for upload.

        The first element of *frame* is ignored; only the sample array is used.
        """
        _, array = frame
        array = array.squeeze()
        audio_message = encode_audio(array)
        self.input_queue.put_nowait(audio_message)

    async def emit(self) -> tuple[int, np.ndarray] | None:
        """Return the result of ``wait_for_item`` on the output queue."""
        return await wait_for_item(self.output_queue)

    def shutdown(self) -> None:
        """Set the quit flag so the ``stream()`` generator stops."""
        self.quit.set()


# Fail fast at import time when the API key is missing.
if not os.getenv("GEMINI_API_KEY"):
    raise RuntimeError("GEMINI_API_KEY environment variable is required but not set")

stream = Stream(
    modality="audio",
    mode="send-receive",
    handler=GeminiHandler(),
    # TURN credentials and usage limits are only applied when ``get_space()``
    # is truthy.
    rtc_configuration=get_cloudflare_turn_credentials_async if get_space() else None,
    concurrency_limit=5 if get_space() else None,
    time_limit=90 if get_space() else None,
    additional_inputs=[
        # Voice selection is the only user-supplied input; there is no API key field.
        gr.Dropdown(
            label="Voice",
            choices=[
                "Puck",
                "Charon",
                "Kore",
                "Fenrir",
                "Aoede",
            ],
            value="Puck",
        ),
    ],
)


class InputData(BaseModel):
    """Request body for ``/input_hook``; deliberately has no API key field."""

    webrtc_id: str
    voice_name: str


app = FastAPI()
stream.mount(app)


@app.post("/input_hook")
async def _(body: InputData):
    """Pass the requested voice for ``body.webrtc_id`` to the stream.

    Raises:
        HTTPException: 500 when ``GEMINI_API_KEY`` is not set.
    """
    if not os.getenv("GEMINI_API_KEY"):
        raise HTTPException(
            status_code=500,
            detail="GEMINI_API_KEY environment variable is not set",
        )

    # Only the voice name is forwarded.
    stream.set_input(body.webrtc_id, body.voice_name)
    return {"status": "ok"}


@app.get("/")
async def index():
    """Serve ``index.html`` with the RTC configuration substituted in.

    Returns a 500 page instead when ``GEMINI_API_KEY`` is not set.
    """
    if not os.getenv("GEMINI_API_KEY"):
        return HTMLResponse(
            content="""

Configuration Error

GEMINI_API_KEY environment variable is not set.

Please set your Gemini API key as an environment variable and restart the application.

Example: export GEMINI_API_KEY=your_api_key_here

""",
            status_code=500,
        )

    # Best effort: if the TURN credentials cannot be fetched, the page is still
    # served with a null RTC configuration.
    rtc_config = None
    if get_space():
        try:
            rtc_config = await get_cloudflare_turn_credentials_async()
        except Exception as e:
            print(f"Warning: Failed to get Cloudflare TURN credentials: {e}")

    # Explicit encoding so the result does not depend on the host locale.
    # NOTE(review): assumes index.html is stored as UTF-8 -- confirm.
    html_content = (current_dir / "index.html").read_text(encoding="utf-8")
    html_content = html_content.replace("__RTC_CONFIGURATION__", json.dumps(rtc_config))
    return HTMLResponse(content=html_content)


if __name__ == "__main__":
    # MODE selects the front end: Gradio UI, phone, or the FastAPI app.
    if (mode := os.getenv("MODE")) == "UI":
        stream.ui.launch(server_port=7860)
    elif mode == "PHONE":
        stream.fastphone(host="0.0.0.0", port=7860)
    else:
        import uvicorn

        uvicorn.run(app, host="0.0.0.0", port=7860)