import asyncio
import base64
import json
import os
from typing import Dict, List

import gradio as gr
import httpx
import numpy as np
import openai
from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.responses import HTMLResponse, StreamingResponse
from fastrtc import (
    AdditionalOutputs,
    AsyncStreamHandler,
    Stream,
    get_twilio_turn_credentials,
    wait_for_item,
)
from gradio.utils import get_space
from openai.types.beta.realtime import ResponseAudioTranscriptDoneEvent
load_dotenv()
SAMPLE_RATE = 24000
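# The OpenAI Realtime API streams 24 kHz PCM16 mono audio, so both the input and
# output sample rates of the handler below are pinned to this value.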
# HTML content embedded as a string
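# NOTE: the full page markup is omitted here; the real template must include the
# __RTC_CONFIGURATION__ placeholder that index() substitutes below.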
HTML_CONTENT = """
MOUSE 음성 챗
"""
class BraveSearchClient:
"""Brave Search API client"""
def __init__(self, api_key: str):
self.api_key = api_key
self.base_url = "https://api.search.brave.com/res/v1/web/search"
async def search(self, query: str, count: int = 10) -> List[Dict]:
"""Perform a web search using Brave Search API"""
if not self.api_key:
return []
headers = {
"Accept": "application/json",
"X-Subscription-Token": self.api_key
}
        params = {
            "q": query,
            "count": count,
            "search_lang": "ko"  # Brave Search expects `search_lang`; `lang` is not a recognized parameter
        }
async with httpx.AsyncClient() as client:
try:
response = await client.get(self.base_url, headers=headers, params=params)
response.raise_for_status()
data = response.json()
results = []
if "web" in data and "results" in data["web"]:
for result in data["web"]["results"][:count]:
results.append({
"title": result.get("title", ""),
"url": result.get("url", ""),
"description": result.get("description", "")
})
return results
except Exception as e:
print(f"Brave Search error: {e}")
return []
# Global state for web search settings
web_search_settings = {}
current_webrtc_id = None
current_web_search_enabled = False
class OpenAIHandler(AsyncStreamHandler):
def __init__(self) -> None:
super().__init__(
expected_layout="mono",
output_sample_rate=SAMPLE_RATE,
output_frame_size=480,
input_sample_rate=SAMPLE_RATE,
)
self.connection = None
self.output_queue = asyncio.Queue()
self.search_client = search_client # Use global search client
self.function_call_in_progress = False
self.current_function_args = ""
self.current_call_id = None
self.webrtc_id = None
self.web_search_enabled = False
self.startup_called = False
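    # fastrtc clones the handler per connection via copy(), so per-connection settings
    # are re-read from module-level state instead of relying on this instance's fields.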
def copy(self):
# Create a new instance that inherits settings
new_handler = OpenAIHandler()
new_handler.search_client = self.search_client
new_handler.webrtc_id = current_webrtc_id # Use global variable
new_handler.web_search_enabled = current_web_search_enabled # Use global variable
return new_handler
async def search_web(self, query: str) -> str:
"""Perform web search and return formatted results"""
if not self.search_client or not self.web_search_enabled:
return "웹 검색이 비활성화되어 있습니다."
print(f"Searching web for: {query}")
results = await self.search_client.search(query)
if not results:
return f"'{query}'에 대한 검색 결과를 찾을 수 없습니다."
# Format search results
formatted_results = []
for i, result in enumerate(results, 1):
formatted_results.append(
f"{i}. {result['title']}\n"
f" URL: {result['url']}\n"
f" {result['description']}\n"
)
return f"웹 검색 결과 '{query}':\n\n" + "\n".join(formatted_results)
async def start_up(self):
"""Connect to realtime API with function calling enabled"""
# Check web search setting before connecting
if self.webrtc_id and self.webrtc_id in web_search_settings:
self.web_search_enabled = web_search_settings[self.webrtc_id]
print(f"Retrieved web_search_enabled={self.web_search_enabled} for webrtc_id={self.webrtc_id}")
print(f"Starting up handler with web_search_enabled={self.web_search_enabled}")
self.client = openai.AsyncOpenAI()
# Define the web search function
tools = []
        if self.web_search_enabled and self.search_client:
            # The Realtime API takes a flat tool schema (no nested "function" wrapper
            # like the Chat Completions API uses)
            tools = [{
                "type": "function",
                "name": "web_search",
                "description": (
                    "Search the web for information when the user asks questions about "
                    "current events, news, weather, prices, or topics that require "
                    "up-to-date information"
                ),
                "parameters": {
                    "type": "object",
                    "properties": {
                        "query": {
                            "type": "string",
                            "description": "The search query"
                        }
                    },
                    "required": ["query"]
                }
            }]
            print("Web search function added to tools")
async with self.client.beta.realtime.connect(
model="gpt-4o-mini-realtime-preview-2024-12-17"
) as conn:
# Update session with tools
session_update = {
"turn_detection": {"type": "server_vad"},
"tools": tools,
"tool_choice": "auto" if tools else "none"
}
# Add instructions to use web search
if self.web_search_enabled:
session_update["instructions"] = (
"You are a helpful assistant. When users ask about current events, "
"news, recent information, weather, or topics that might have changed recently, "
"use the web_search function to find the most up-to-date information. "
"Always search when the user asks about: 날씨(weather), 뉴스(news), 현재(current), "
"최근(recent), 최신(latest), 오늘(today), or any time-sensitive information."
)
await conn.session.update(session=session_update)
self.connection = conn
print(f"Connected with tools: {len(tools)} functions")
self.startup_called = True
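            # Relay realtime events: transcripts and audio deltas go to the output queue,
            # while function calls are executed inline and their results fed back to the model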
async for event in self.connection:
if event.type == "response.audio_transcript.done":
await self.output_queue.put(AdditionalOutputs(event))
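                    # Decode base64 PCM16 audio deltas into a (1, n_samples) int16 frame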
elif event.type == "response.audio.delta":
await self.output_queue.put(
(
self.output_sample_rate,
np.frombuffer(
base64.b64decode(event.delta), dtype=np.int16
).reshape(1, -1),
),
)
                    # Handle streamed function-call arguments
                    elif event.type == "response.function_call_arguments.delta":
                        if not self.function_call_in_progress:
                            print("Function call started")
                            self.function_call_in_progress = True
                            self.current_function_args = ""
                            self.current_call_id = getattr(event, "call_id", None)
                        self.current_function_args += event.delta
                    elif event.type == "response.function_call_arguments.done":
                        print(f"Function call done, args: {self.current_function_args}")
                        try:
                            # Prefer the fields on the done event, falling back to accumulated deltas
                            raw_args = getattr(event, "arguments", None) or self.current_function_args
                            call_id = getattr(event, "call_id", None) or self.current_call_id
                            args = json.loads(raw_args)
                            query = args.get("query", "")
                            # Emit a search event so the client can show what is being looked up
                            await self.output_queue.put(AdditionalOutputs({
                                "type": "search",
                                "query": query
                            }))
                            # Perform the search
                            search_results = await self.search_web(query)
                            print(f"Search results length: {len(search_results)}")
                            # Send the function result back to the model and request a follow-up response
                            if self.connection and call_id:
                                await self.connection.conversation.item.create(
                                    item={
                                        "type": "function_call_output",
                                        "call_id": call_id,
                                        "output": search_results
                                    }
                                )
                                await self.connection.response.create()
                        except Exception as e:
                            print(f"Function call error: {e}")
                        finally:
                            self.function_call_in_progress = False
                            self.current_function_args = ""
                            self.current_call_id = None
async def receive(self, frame: tuple[int, np.ndarray]) -> None:
if not self.connection:
return
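        # Mic audio arrives as (1, n) int16; flatten it and base64-encode the PCM16 bytes
        # for the realtime input audio buffer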
_, array = frame
array = array.squeeze()
audio_message = base64.b64encode(array.tobytes()).decode("utf-8")
await self.connection.input_audio_buffer.append(audio=audio_message)
async def emit(self) -> tuple[int, np.ndarray] | AdditionalOutputs | None:
return await wait_for_item(self.output_queue)
async def shutdown(self) -> None:
if self.connection:
await self.connection.close()
self.connection = None
# Initialize search client
brave_api_key = os.getenv("BSEARCH_API")
search_client = BraveSearchClient(brave_api_key) if brave_api_key else None
print(f"Search client initialized: {search_client is not None}, API key present: {bool(brave_api_key)}")
def update_chatbot(chatbot: list[dict], response: ResponseAudioTranscriptDoneEvent):
chatbot.append({"role": "assistant", "content": response.transcript})
return chatbot
# Create chatbot component
chatbot = gr.Chatbot(type="messages")
# Create base handler and stream
handler = OpenAIHandler()
handler.search_client = search_client
stream = Stream(
handler,
mode="send-receive",
modality="audio",
additional_inputs=[chatbot],
additional_outputs=[chatbot],
additional_outputs_handler=update_chatbot,
rtc_configuration=get_twilio_turn_credentials() if get_space() else None,
concurrency_limit=5 if get_space() else None,
time_limit=90 if get_space() else None,
)
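# Stream wires the handler into fastrtc's WebRTC plumbing; mounting it below exposes its
# endpoints (including the /webrtc/offer route intercepted by the middleware further down).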
app = FastAPI()
# Mount the stream
stream.mount(app)
# Intercept the POST body before it reaches the stream
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request as StarletteRequest
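# The offer body carries the client's web_search_enabled flag; stash it in module-level
# state so the handler copy created for this connection can pick it up in start_up().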
class WebSearchMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: StarletteRequest, call_next):
if request.url.path == "/webrtc/offer" and request.method == "POST":
# Read the body
body = await request.body()
# Parse it
try:
data = json.loads(body)
webrtc_id = data.get("webrtc_id")
web_search_enabled = data.get("web_search_enabled", False)
print(f"Middleware intercepted - webrtc_id: {webrtc_id}, web_search_enabled: {web_search_enabled}")
# Store settings globally
global current_webrtc_id, current_web_search_enabled
if webrtc_id:
current_webrtc_id = webrtc_id
current_web_search_enabled = web_search_enabled
web_search_settings[webrtc_id] = web_search_enabled
            except Exception as e:
                print(f"Middleware body parse error: {e}")
# Recreate the request with the body
request._body = body
response = await call_next(request)
return response
# Add middleware
app.add_middleware(WebSearchMiddleware)
@app.get("/outputs")
async def outputs(webrtc_id: str):
"""Stream outputs including search events"""
async def output_stream():
async for output in stream.output_stream(webrtc_id):
if hasattr(output, 'args') and output.args:
# Check if it's a search event
if isinstance(output.args[0], dict) and output.args[0].get('type') == 'search':
yield f"event: search\ndata: {json.dumps(output.args[0])}\n\n"
# Regular transcript event
elif hasattr(output.args[0], 'transcript'):
s = json.dumps({"role": "assistant", "content": output.args[0].transcript})
yield f"event: output\ndata: {s}\n\n"
return StreamingResponse(output_stream(), media_type="text/event-stream")
@app.get("/")
async def index():
"""Serve the HTML page"""
rtc_config = get_twilio_turn_credentials() if get_space() else None
html_content = HTML_CONTENT.replace("__RTC_CONFIGURATION__", json.dumps(rtc_config))
return HTMLResponse(content=html_content)
if __name__ == "__main__":
import uvicorn
mode = os.getenv("MODE")
if mode == "UI":
stream.ui.launch(server_port=7860)
elif mode == "PHONE":
stream.fastphone(host="0.0.0.0", port=7860)
else:
uvicorn.run(app, host="0.0.0.0", port=7860)
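# A minimal .env for local runs (BSEARCH_API and MODE are read above; OPENAI_API_KEY is
# what the OpenAI client expects by default; values below are placeholders):
#   OPENAI_API_KEY=sk-...
#   BSEARCH_API=...   # Brave Search API key; web search stays disabled when absent
#   MODE=UI           # optional: UI -> Gradio UI, PHONE -> fastphone, unset -> FastAPI on :7860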