import asyncio
import base64
import json
from pathlib import Path
import os
import numpy as np
import openai
from dotenv import load_dotenv
from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse, StreamingResponse
from fastrtc import (
AdditionalOutputs,
AsyncStreamHandler,
Stream,
get_twilio_turn_credentials,
wait_for_item,
)
from gradio.utils import get_space
from openai.types.beta.realtime import ResponseAudioTranscriptDoneEvent
import httpx
from typing import Optional, List, Dict
import gradio as gr
# Load environment variables (e.g. BSEARCH_API) from a .env file, if present.
load_dotenv()

# Sample rate used for both microphone input and model audio output.
SAMPLE_RATE = 24_000

# Page served by the "/" route. index() substitutes an "__RTC_CONFIGURATION__"
# placeholder, which this string does not currently contain.
HTML_CONTENT = """
MOUSE 음성 챗
"""
class BraveSearchClient:
    """Minimal async client for the Brave Web Search API.

    Every failure (missing key, blank query, network/HTTP/JSON error) is
    reported as an empty result list, so callers can treat search as
    best-effort.
    """

    # Brave's web search endpoint accepts at most 20 results per request.
    MAX_COUNT = 20

    def __init__(self, api_key: str, search_lang: str = "ko", timeout: float = 5.0):
        """Create a client.

        Args:
            api_key: Brave subscription token; an empty value disables searching.
            search_lang: Preferred result language (Brave ``search_lang`` parameter).
            timeout: Per-request timeout in seconds (same as httpx's default).
        """
        self.api_key = api_key
        self.base_url = "https://api.search.brave.com/res/v1/web/search"
        self.search_lang = search_lang
        self.timeout = timeout

    @classmethod
    def _clamp_count(cls, count: int) -> int:
        """Clamp a requested result count into Brave's accepted range [1, MAX_COUNT]."""
        return max(1, min(int(count), cls.MAX_COUNT))

    @staticmethod
    def _parse_results(data: Dict, count: int) -> List[Dict]:
        """Extract up to ``count`` ``{title, url, description}`` dicts from a response body.

        Missing or malformed ``web.results`` sections yield an empty list;
        missing fields on an individual hit default to ``""``.
        """
        web = data.get("web") if isinstance(data, dict) else None
        raw_results = web.get("results") if isinstance(web, dict) else None
        if not isinstance(raw_results, list):
            return []
        return [
            {
                "title": hit.get("title", ""),
                "url": hit.get("url", ""),
                "description": hit.get("description", ""),
            }
            for hit in raw_results[:count]
            if isinstance(hit, dict)
        ]

    async def search(self, query: str, count: int = 10) -> List[Dict]:
        """Perform a web search using Brave Search API.

        Args:
            query: Search text; blank queries return ``[]`` without a request.
            count: Desired number of results, clamped to Brave's limit of 20.

        Returns:
            A list of ``{"title", "url", "description"}`` dicts (possibly empty).
        """
        if not self.api_key or not (query and query.strip()):
            return []
        count = self._clamp_count(count)
        headers = {
            "Accept": "application/json",
            "X-Subscription-Token": self.api_key,
        }
        params = {
            "q": query,
            "count": count,
            # Brave's documented language parameter is `search_lang`, not `lang`.
            "search_lang": self.search_lang,
        }
        try:
            async with httpx.AsyncClient(timeout=self.timeout) as client:
                response = await client.get(self.base_url, headers=headers, params=params)
                response.raise_for_status()
                data = response.json()
        except Exception as e:
            # Deliberately best-effort: a failed search degrades to "no results".
            print(f"Brave Search error: {e}")
            return []
        return self._parse_results(data, count)
# Brave search client shared by every handler; stays None without an API key.
brave_api_key = os.getenv("BSEARCH_API")
search_client = None
if brave_api_key:
    search_client = BraveSearchClient(brave_api_key)
print(f"Search client initialized: {search_client is not None}, API key present: {bool(brave_api_key)}")

# Per-connection web-search preferences keyed by webrtc_id, each entry shaped
# {"enabled": <flag from the offer body>, "timestamp": <event-loop time>}.
web_search_settings: dict[str, dict] = {}
def update_chatbot(chatbot: list[dict], response: "ResponseAudioTranscriptDoneEvent | dict"):
    """Append an assistant transcript to the Gradio chat history.

    The stream's additional outputs include both transcript events (objects
    with a ``transcript`` attribute) and side-channel dicts such as
    ``{"type": "search", "query": ...}``. Only transcripts are appended; any
    other payload leaves the history unchanged instead of raising.

    Args:
        chatbot: Current chat history (mutated in place); ``None`` starts a new one.
        response: The additional output emitted by the handler.

    Returns:
        The (possibly updated) chat history.
    """
    if chatbot is None:
        chatbot = []
    transcript = getattr(response, "transcript", None)
    if transcript is None:
        return chatbot
    chatbot.append({"role": "assistant", "content": transcript})
    return chatbot
class OpenAIHandler(AsyncStreamHandler):
    """fastrtc audio handler backed by the OpenAI Realtime API, with optional web search.

    Microphone frames passed to ``receive`` are forwarded to a realtime session
    opened in ``start_up``. Audio deltas and finished transcripts from the model
    are queued and handed back through ``emit``. When web search is enabled and
    a Brave client is configured, a ``web_search`` tool is registered and the
    model's tool calls are answered with Brave results.
    """

    # Realtime model used for every session.
    MODEL = "gpt-4o-mini-realtime-preview-2024-12-17"

    def __init__(self, web_search_enabled: bool = False, webrtc_id: str | None = None) -> None:
        super().__init__(
            expected_layout="mono",
            output_sample_rate=SAMPLE_RATE,
            output_frame_size=480,
            input_sample_rate=SAMPLE_RATE,
        )
        self.connection = None  # realtime connection, set in start_up
        self.client = None  # openai.AsyncOpenAI client, created in start_up
        self.output_queue = asyncio.Queue()  # audio tuples / AdditionalOutputs for emit()
        self.search_client = search_client
        # State of the tool call currently being executed (kept for debugging).
        self.function_call_in_progress = False
        self.current_function_args = ""
        self.current_call_id = None
        self.webrtc_id = webrtc_id
        self.web_search_enabled = web_search_enabled
        print(f"Handler created with web_search_enabled={web_search_enabled}, webrtc_id={webrtc_id}")

    @staticmethod
    def _most_recent_settings() -> tuple[str | None, dict]:
        """Return ``(webrtc_id, settings)`` for the newest stored entry, or ``(None, {})``.

        Entries come from the module-level ``web_search_settings`` dict and are
        ranked by their ``timestamp`` value (a missing timestamp counts as 0).
        """
        if not web_search_settings:
            return None, {}
        recent_id = max(
            web_search_settings,
            key=lambda k: web_search_settings[k].get("timestamp", 0),
        )
        return recent_id, web_search_settings[recent_id]

    def copy(self):
        """Create a fresh handler for a new connection.

        Uses the most recently stored web-search settings when any exist;
        otherwise web search is disabled.
        """
        recent_id, settings = self._most_recent_settings()
        if recent_id is not None:
            web_search_enabled = settings.get("enabled", False)
            print(f"Handler.copy() using recent settings - webrtc_id={recent_id}, web_search_enabled={web_search_enabled}")
            return OpenAIHandler(web_search_enabled=web_search_enabled, webrtc_id=recent_id)
        print("Handler.copy() called - creating new handler with default settings")
        return OpenAIHandler(web_search_enabled=False)

    async def search_web(self, query: str) -> str:
        """Run a Brave web search and format the hits as numbered plain text.

        Returns a Korean status message instead when search is disabled or when
        no results were found.
        """
        if not self.search_client or not self.web_search_enabled:
            return "웹 검색이 비활성화되어 있습니다."
        print(f"Searching web for: {query}")
        results = await self.search_client.search(query)
        if not results:
            return f"'{query}'에 대한 검색 결과를 찾을 수 없습니다."
        formatted_results = [
            f"{i}. {result['title']}\n"
            f" URL: {result['url']}\n"
            f" {result['description']}\n"
            for i, result in enumerate(results, 1)
        ]
        return f"웹 검색 결과 '{query}':\n\n" + "\n".join(formatted_results)

    def _session_config(self) -> dict:
        """Build the ``session.update`` payload: instructions plus the optional web_search tool."""
        tools = []
        instructions = "You are a helpful assistant. Respond in Korean when the user speaks Korean."
        if self.web_search_enabled and self.search_client:
            # Realtime sessions take a *flat* tool definition (name/description/
            # parameters at the top level), not the Chat Completions shape that
            # nests them under a "function" key.
            tools = [{
                "type": "function",
                "name": "web_search",
                "description": "Search the web for current information. Use this for weather, news, prices, current events, or any time-sensitive topics.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "query": {
                            "type": "string",
                            "description": "The search query in Korean or English",
                        }
                    },
                    "required": ["query"],
                },
            }]
            print("Web search function added to tools")
            instructions = (
                "You are a helpful assistant with web search capabilities. "
                "IMPORTANT: You MUST use the web_search function for ANY of these topics:\n"
                "- Weather (날씨, 기온, 비, 눈)\n"
                "- News (뉴스, 소식)\n"
                "- Current events (현재, 최근, 오늘, 지금)\n"
                "- Prices (가격, 환율, 주가)\n"
                "- Sports scores or results\n"
                "- Any question about 2024 or 2025\n"
                "- Any time-sensitive information\n\n"
                "When in doubt, USE web_search. It's better to search and provide accurate information "
                "than to guess or use outdated information. Always respond in Korean when the user speaks Korean."
            )
        return {
            "turn_detection": {"type": "server_vad"},
            "instructions": instructions,
            "tools": tools,
            "tool_choice": "auto" if tools else "none",
        }

    async def start_up(self):
        """Connect to realtime API with function calling enabled, then pump its events."""
        # Adopt the newest settings captured by the offer endpoint, if any.
        recent_id, settings = self._most_recent_settings()
        if recent_id is not None:
            self.web_search_enabled = settings.get("enabled", False)
            self.webrtc_id = recent_id
            print(f"start_up: Updated settings from storage - webrtc_id={self.webrtc_id}, web_search_enabled={self.web_search_enabled}")
        print(f"Starting up handler with web_search_enabled={self.web_search_enabled}")
        self.client = openai.AsyncOpenAI()
        session_update = self._session_config()
        async with self.client.beta.realtime.connect(model=self.MODEL) as conn:
            await conn.session.update(session=session_update)
            self.connection = conn
            print(f"Connected with tools: {len(session_update['tools'])} functions")
            async for event in conn:
                # Debug logging for function-call argument streaming.
                if event.type.startswith("response.function_call"):
                    print(f"Function event: {event.type}")
                if event.type == "response.audio_transcript.done":
                    await self.output_queue.put(AdditionalOutputs(event))
                elif event.type == "response.audio.delta":
                    samples = np.frombuffer(base64.b64decode(event.delta), dtype=np.int16)
                    await self.output_queue.put((self.output_sample_rate, samples.reshape(1, -1)))
                elif event.type == "response.done":
                    # The Realtime API emits no "function_call_arguments.start"
                    # event; completed calls are listed in the finished response's
                    # output. Answering here also means response.create is only
                    # sent once the model's previous response has ended.
                    await self._handle_function_calls(event)

    async def _handle_function_calls(self, event) -> None:
        """Execute each function call in a finished response, then ask the model to continue."""
        response = getattr(event, "response", None)
        if response is None or getattr(response, "status", None) == "cancelled":
            # An interrupted response may carry partial tool arguments; skip it.
            return
        calls = [
            item
            for item in (getattr(response, "output", None) or [])
            if getattr(item, "type", None) == "function_call"
        ]
        conn = self.connection
        if not calls or conn is None:
            return
        answered = False
        for call in calls:
            self.function_call_in_progress = True
            self.current_call_id = getattr(call, "call_id", None)
            self.current_function_args = getattr(call, "arguments", None) or ""
            print(f"Function call done, args: {self.current_function_args}")
            try:
                output = await self._run_tool(getattr(call, "name", None), self.current_function_args)
                print(f"Search results length: {len(output)}")
                if self.current_call_id:
                    await conn.conversation.item.create(
                        item={
                            "type": "function_call_output",
                            "call_id": self.current_call_id,
                            "output": output,
                        }
                    )
                    answered = True
            except Exception as e:
                # Best-effort: a failed tool call must not tear down the event loop.
                print(f"Function call error: {e}")
            finally:
                self.function_call_in_progress = False
                self.current_function_args = ""
                self.current_call_id = None
        if answered:
            try:
                await conn.response.create()
            except Exception as e:
                print(f"Function call error: {e}")

    async def _run_tool(self, name: str | None, raw_args: str) -> str:
        """Dispatch one tool call and return the text sent back as its output.

        Malformed arguments or unknown tool names produce an error string, so
        the model always receives a ``function_call_output`` for its call.
        """
        if name not in (None, "web_search"):
            return f"Unknown function: {name}"
        try:
            args = json.loads(raw_args) if raw_args else {}
        except json.JSONDecodeError:
            return "Invalid function arguments: expected a JSON object with a 'query' field."
        query = args.get("query", "") if isinstance(args, dict) else ""
        # Let the client show that a search is running.
        await self.output_queue.put(AdditionalOutputs({
            "type": "search",
            "query": query,
        }))
        return await self.search_web(query)

    async def receive(self, frame: tuple[int, np.ndarray]) -> None:
        """Forward one microphone frame to the realtime input buffer as base64-encoded bytes."""
        if not self.connection:
            return
        try:
            _, array = frame
            array = array.squeeze()
            audio_message = base64.b64encode(array.tobytes()).decode("utf-8")
            await self.connection.input_audio_buffer.append(audio=audio_message)
        except Exception as e:
            # The connection might be closing; drop this frame.
            print(f"Error in receive: {e}")

    async def emit(self) -> tuple[int, np.ndarray] | AdditionalOutputs | None:
        """Return the next item from the output queue via fastrtc's ``wait_for_item``."""
        return await wait_for_item(self.output_queue)

    async def shutdown(self) -> None:
        """Close the realtime connection, if one is open."""
        # Detach first so receive() stops using a connection that is closing.
        conn, self.connection = self.connection, None
        if conn:
            await conn.close()
# Prototype handler; Stream receives this instance rather than a factory.
handler = OpenAIHandler(web_search_enabled=False)

# Gradio chatbot used both as an extra input and as the sink for transcripts.
chatbot = gr.Chatbot(type="messages")

# TURN credentials, concurrency cap and time limit only apply on a hosted Space.
_running_on_space = get_space()
stream = Stream(
    handler,
    mode="send-receive",
    modality="audio",
    additional_inputs=[chatbot],
    additional_outputs=[chatbot],
    additional_outputs_handler=update_chatbot,
    rtc_configuration=get_twilio_turn_credentials() if _running_on_space else None,
    concurrency_limit=5 if _running_on_space else None,
    time_limit=300 if _running_on_space else None,
)

# FastAPI app hosting the stream's routes alongside this module's own endpoints.
app = FastAPI()
stream.mount(app)
# Intercept offer to capture settings
@app.post("/webrtc/offer", include_in_schema=False)
async def custom_offer(request: Request):
    """Record the client's web-search preference, then forward the offer to the stream.

    The JSON body is read for ``webrtc_id`` and an optional ``web_search_enabled``
    flag. The flag is stored in ``web_search_settings`` with an event-loop
    timestamp, so handlers can pick the newest entry.

    NOTE(review): ``stream.mount(app)`` runs before this decorator. If mount
    registers its own ``POST /webrtc/offer``, Starlette matches that earlier
    route first and this handler never runs. Verify the route order, and verify
    that ``stream.offer`` accepts a plain dict body.
    """
    body = await request.json()
    webrtc_id = body.get("webrtc_id")
    web_search_enabled = body.get("web_search_enabled", False)
    print(f"Custom offer - webrtc_id: {webrtc_id}, web_search_enabled: {web_search_enabled}")
    # Store settings with timestamp
    if webrtc_id:
        web_search_settings[webrtc_id] = {
            'enabled': web_search_enabled,
            'timestamp': asyncio.get_running_loop().time(),
        }
    # stream.offer is called directly rather than through routing, so this
    # route does not need to be detached first. Mutating app.routes around an
    # await let concurrent requests see a half-modified routing table.
    return await stream.offer(body)
@app.get("/outputs")
async def outputs(webrtc_id: str):
    """Relay one connection's additional outputs to the browser as server-sent events.

    Search notifications become ``event: search`` messages. Assistant
    transcripts become ``event: output`` messages carrying a chat-message JSON
    payload. Anything else is skipped.
    """
    async def output_stream():
        async for item in stream.output_stream(webrtc_id):
            payload = getattr(item, "args", None)
            if not payload:
                continue
            first = payload[0]
            if isinstance(first, dict) and first.get('type') == 'search':
                yield f"event: search\ndata: {json.dumps(first)}\n\n"
            elif hasattr(first, 'transcript'):
                message = json.dumps({"role": "assistant", "content": first.transcript})
                yield f"event: output\ndata: {message}\n\n"

    return StreamingResponse(output_stream(), media_type="text/event-stream")
@app.get("/")
async def index():
    """Serve the embedded HTML page with the RTC configuration substituted in."""
    rtc_config = None
    if get_space():
        rtc_config = get_twilio_turn_credentials()
    page = HTML_CONTENT.replace("__RTC_CONFIGURATION__", json.dumps(rtc_config))
    return HTMLResponse(content=page)
if __name__ == "__main__":
    import uvicorn

    # MODE picks the front end: "UI" launches the stream's Gradio UI, "PHONE"
    # calls stream.fastphone, and anything else serves the FastAPI app.
    run_mode = os.getenv("MODE")
    if run_mode == "PHONE":
        stream.fastphone(host="0.0.0.0", port=7860)
    elif run_mode == "UI":
        stream.ui.launch(server_port=7860)
    else:
        uvicorn.run(app, host="0.0.0.0", port=7860)