"""Korean voice-chat server: fastrtc audio <-> OpenAI Realtime API, with optional Brave web search."""
import asyncio
import base64
import json
import os
from pathlib import Path
from typing import Dict, List, Optional, get_type_hints

import gradio as gr
import httpx
import numpy as np
import openai
from dotenv import load_dotenv
from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse, StreamingResponse
from fastrtc import (
    AdditionalOutputs,
    AsyncStreamHandler,
    Stream,
    get_twilio_turn_credentials,
    wait_for_item,
)
from gradio.utils import get_space
from openai.types.beta.realtime import ResponseAudioTranscriptDoneEvent

load_dotenv()

SAMPLE_RATE = 24000

# HTML content embedded as a string.
# NOTE(review): the markup of this page appears to have been stripped from this
# copy of the file -- only the visible labels remain, and the
# "__RTC_CONFIGURATION__" placeholder that index() substitutes is absent.
# Restore the full page (including its client script) before deploying.
HTML_CONTENT = """ MOUSE 음성 챗
웹 검색
연결 대기 중
"""


class BraveSearchClient:
    """Brave Search API client"""

    def __init__(self, api_key: str):
        self.api_key = api_key
        self.base_url = "https://api.search.brave.com/res/v1/web/search"

    async def search(self, query: str, count: int = 10) -> List[Dict]:
        """Perform a web search using the Brave Search API.

        Args:
            query: Search text.
            count: Maximum number of results to return.

        Returns:
            Up to ``count`` dicts with ``title``, ``url`` and ``description``
            keys (empty strings when the API omits a field). Searching is
            best-effort: a missing API key or any request/parse failure is
            logged and yields an empty list.
        """
        if not self.api_key:
            return []

        headers = {
            "Accept": "application/json",
            "X-Subscription-Token": self.api_key,
        }
        # Brave's language filter parameter is `search_lang` (a plain `lang`
        # key is not part of the Web Search API).
        params = {
            "q": query,
            "count": count,
            "search_lang": "ko",
        }

        async with httpx.AsyncClient() as client:
            try:
                response = await client.get(self.base_url, headers=headers, params=params)
                response.raise_for_status()
                data = response.json()

                results = []
                if "web" in data and "results" in data["web"]:
                    for result in data["web"]["results"][:count]:
                        results.append({
                            "title": result.get("title", ""),
                            "url": result.get("url", ""),
                            "description": result.get("description", ""),
                        })
                return results
            except Exception as e:
                # Deliberately best-effort: a failed search must not abort the voice session.
                print(f"Brave Search error: {e}")
                return []


# Initialize the search client once; it is None when BSEARCH_API is not set.
brave_api_key = os.getenv("BSEARCH_API")
search_client = BraveSearchClient(brave_api_key) if brave_api_key else None
print(f"Search client initialized: {search_client is not None}, API key present: {bool(brave_api_key)}")

# Web-search settings captured in custom_offer, keyed by webrtc_id.
# Each value is {"enabled": <value from the offer body>, "timestamp": <loop time>}.
web_search_settings = {}
# Only the newest entry is ever consulted, so older ones are pruned beyond this size.
_MAX_STORED_SETTINGS = 100


def _latest_search_settings() -> Optional[tuple]:
    """Return ``(webrtc_id, enabled)`` for the most recently stored offer, or None.

    NOTE(review): with several concurrent users this picks whoever sent an offer
    last, not necessarily the connection being set up.
    """
    if not web_search_settings:
        return None
    recent_id = max(
        web_search_settings,
        key=lambda k: web_search_settings[k].get("timestamp", 0),
    )
    return recent_id, web_search_settings[recent_id].get("enabled", False)


def update_chatbot(chatbot: list[dict], response: ResponseAudioTranscriptDoneEvent):
    """Append an assistant transcript to the chat history used by the Gradio UI.

    The handler also emits search-notification dicts as AdditionalOutputs; those
    have no ``transcript`` and are skipped instead of raising AttributeError.
    """
    transcript = getattr(response, "transcript", None)
    if transcript is None:
        return chatbot
    chatbot.append({"role": "assistant", "content": transcript})
    return chatbot


class OpenAIHandler(AsyncStreamHandler):
    """Stream mono 24 kHz audio to/from the OpenAI Realtime API, optionally exposing a web_search tool."""

    def __init__(self, web_search_enabled: bool = False, webrtc_id: Optional[str] = None) -> None:
        super().__init__(
            expected_layout="mono",
            output_sample_rate=SAMPLE_RATE,
            output_frame_size=480,
            input_sample_rate=SAMPLE_RATE,
        )
        self.connection = None
        self.output_queue = asyncio.Queue()
        self.search_client = search_client
        self.webrtc_id = webrtc_id
        self.web_search_enabled = web_search_enabled
        print(f"Handler created with web_search_enabled={web_search_enabled}, webrtc_id={webrtc_id}")

    def copy(self):
        """Return a fresh handler configured from the most recently stored offer settings.

        Presumably invoked by fastrtc once per connection, since the Stream is
        given a handler instance rather than a factory.
        """
        latest = _latest_search_settings()
        if latest is None:
            print("Handler.copy() called - creating new handler with default settings")
            return OpenAIHandler(web_search_enabled=False)
        recent_id, web_search_enabled = latest
        print(f"Handler.copy() using recent settings - webrtc_id={recent_id}, web_search_enabled={web_search_enabled}")
        return OpenAIHandler(web_search_enabled=web_search_enabled, webrtc_id=recent_id)

    async def search_web(self, query: str) -> str:
        """Perform a web search and return the results formatted as plain text for the model."""
        if not self.search_client or not self.web_search_enabled:
            return "웹 검색이 비활성화되어 있습니다."

        print(f"Searching web for: {query}")
        results = await self.search_client.search(query)
        if not results:
            return f"'{query}'에 대한 검색 결과를 찾을 수 없습니다."

        formatted_results = [
            f"{i}. {result['title']}\n"
            f" URL: {result['url']}\n"
            f" {result['description']}\n"
            for i, result in enumerate(results, 1)
        ]
        return f"웹 검색 결과 '{query}':\n\n" + "\n".join(formatted_results)

    def _session_config(self) -> tuple:
        """Return ``(tools, instructions)`` for the Realtime session update."""
        if not (self.web_search_enabled and self.search_client):
            return [], "You are a helpful assistant. Respond in Korean when the user speaks Korean."

        # The Realtime API takes tools in the flat shape (name/description/parameters
        # at top level), not the Chat Completions {"function": {...}} wrapper.
        tools = [{
            "type": "function",
            "name": "web_search",
            "description": "Search the web for current information. Use this for weather, news, prices, current events, or any time-sensitive topics.",
            "parameters": {
                "type": "object",
                "properties": {
                    "query": {
                        "type": "string",
                        "description": "The search query in Korean or English",
                    }
                },
                "required": ["query"],
            },
        }]
        print("Web search function added to tools")
        instructions = (
            "You are a helpful assistant with web search capabilities. "
            "IMPORTANT: You MUST use the web_search function for ANY of these topics:\n"
            "- Weather (날씨, 기온, 비, 눈)\n"
            "- News (뉴스, 소식)\n"
            "- Current events (현재, 최근, 오늘, 지금)\n"
            "- Prices (가격, 환율, 주가)\n"
            "- Sports scores or results\n"
            "- Any question about 2024 or 2025\n"
            "- Any time-sensitive information\n\n"
            "When in doubt, USE web_search. It's better to search and provide accurate information "
            "than to guess or use outdated information. Always respond in Korean when the user speaks Korean."
        )
        return tools, instructions

    async def start_up(self):
        """Connect to the Realtime API and pump its events until the connection ends."""
        # Refresh settings in case an offer arrived after this handler was created.
        latest = _latest_search_settings()
        if latest is not None:
            self.webrtc_id, self.web_search_enabled = latest
            print(f"start_up: Updated settings from storage - webrtc_id={self.webrtc_id}, web_search_enabled={self.web_search_enabled}")

        print(f"Starting up handler with web_search_enabled={self.web_search_enabled}")
        self.client = openai.AsyncOpenAI()
        tools, instructions = self._session_config()

        async with self.client.beta.realtime.connect(
            model="gpt-4o-mini-realtime-preview-2024-12-17"
        ) as conn:
            session_update = {
                "turn_detection": {"type": "server_vad"},
                "instructions": instructions,
                "tools": tools,
                "tool_choice": "auto" if tools else "none",
            }
            await conn.session.update(session=session_update)
            self.connection = conn
            print(f"Connected with tools: {len(tools)} functions")

            try:
                async for event in conn:
                    await self._handle_event(event)
            finally:
                # Don't let receive() keep writing to a connection that has ended.
                self.connection = None

    async def _handle_event(self, event) -> None:
        """Route one Realtime server event to the output queue or the tool handler."""
        if event.type.startswith("response.function_call"):
            print(f"Function event: {event.type}")

        if event.type == "response.audio_transcript.done":
            await self.output_queue.put(AdditionalOutputs(event))
        elif event.type == "response.audio.delta":
            audio = np.frombuffer(base64.b64decode(event.delta), dtype=np.int16).reshape(1, -1)
            await self.output_queue.put((self.output_sample_rate, audio))
        elif event.type == "response.function_call_arguments.done":
            await self._handle_function_call(event)

    async def _handle_function_call(self, event) -> None:
        """Run the web_search tool call in ``event`` and send its output back to the model.

        The ``.done`` event carries the complete JSON argument string and the
        call_id, so no delta accumulation is needed (the Realtime API does not
        emit a ``response.function_call_arguments.start`` event).
        """
        arguments = getattr(event, "arguments", "") or "{}"
        print(f"Function call done, args: {arguments}")
        try:
            args = json.loads(arguments)
            query = args.get("query", "")

            # Let the client show that a search is in progress.
            await self.output_queue.put(AdditionalOutputs({
                "type": "search",
                "query": query,
            }))

            search_results = await self.search_web(query)
            print(f"Search results length: {len(search_results)}")

            call_id = getattr(event, "call_id", None)
            if self.connection and call_id:
                await self.connection.conversation.item.create(
                    item={
                        "type": "function_call_output",
                        "call_id": call_id,
                        "output": search_results,
                    }
                )
                # Ask the model to answer now that the tool output is in the conversation.
                await self.connection.response.create()
        except Exception as e:
            print(f"Function call error: {e}")

    async def receive(self, frame: tuple[int, np.ndarray]) -> None:
        """Forward one input audio frame to the Realtime input buffer as base64."""
        if not self.connection:
            return
        try:
            _, array = frame
            array = array.squeeze()
            audio_message = base64.b64encode(array.tobytes()).decode("utf-8")
            await self.connection.input_audio_buffer.append(audio=audio_message)
        except Exception as e:
            # The connection may have closed between the check above and the send.
            print(f"Error in receive: {e}")

    async def emit(self) -> tuple[int, np.ndarray] | AdditionalOutputs | None:
        """Return the next queued audio chunk or additional output, if any."""
        return await wait_for_item(self.output_queue)

    async def shutdown(self) -> None:
        """Close the Realtime connection if one is open."""
        conn, self.connection = self.connection, None
        if conn:
            await conn.close()


# Template handler; per-connection handlers come from its copy().
handler = OpenAIHandler(web_search_enabled=False)

chatbot = gr.Chatbot(type="messages")

stream = Stream(
    handler,  # Pass instance, not factory
    mode="send-receive",
    modality="audio",
    additional_inputs=[chatbot],
    additional_outputs=[chatbot],
    additional_outputs_handler=update_chatbot,
    rtc_configuration=get_twilio_turn_credentials() if get_space() else None,
    concurrency_limit=5 if get_space() else None,
    time_limit=300 if get_space() else None,
)

app = FastAPI()


def _offer_body_for_stream(body: dict):
    """Convert the raw offer JSON into the argument type ``stream.offer`` declares.

    NOTE(review): depending on the fastrtc version, ``Stream.offer`` may be
    annotated with a pydantic request model rather than a dict. When its
    ``body`` annotation exposes ``model_validate``, the dict is validated into
    that model; otherwise it is passed through unchanged. Confirm against the
    installed fastrtc.
    """
    try:
        body_type = get_type_hints(stream.offer).get("body")
    except Exception:
        body_type = None
    if body_type is not None and hasattr(body_type, "model_validate"):
        return body_type.model_validate(body)
    return body


# Registered BEFORE stream.mount(app): Starlette dispatches to the first route
# whose path and method match, so a route added after mounting is never reached.
@app.post("/webrtc/offer", include_in_schema=False)
async def custom_offer(request: Request):
    """Record the web-search setting for this offer, then delegate to the stream's offer handler."""
    body = await request.json()
    webrtc_id = body.get("webrtc_id")
    web_search_enabled = body.get("web_search_enabled", False)

    print(f"Custom offer - webrtc_id: {webrtc_id}, web_search_enabled: {web_search_enabled}")

    if webrtc_id:
        web_search_settings[webrtc_id] = {
            "enabled": web_search_enabled,
            "timestamp": asyncio.get_event_loop().time(),
        }
        # Only the newest entry is read, so drop the oldest ones beyond the cap.
        while len(web_search_settings) > _MAX_STORED_SETTINGS:
            oldest = min(web_search_settings, key=lambda k: web_search_settings[k]["timestamp"])
            del web_search_settings[oldest]

    # A direct call: no need to touch app.routes to avoid recursion.
    return await stream.offer(_offer_body_for_stream(body))


# Mount the stream's own routes after the custom offer route (see above).
stream.mount(app)


@app.get("/outputs")
async def outputs(webrtc_id: str):
    """Stream outputs (search notifications and transcripts) as server-sent events."""
    async def output_stream():
        async for output in stream.output_stream(webrtc_id):
            if hasattr(output, "args") and output.args:
                payload = output.args[0]
                if isinstance(payload, dict) and payload.get("type") == "search":
                    yield f"event: search\ndata: {json.dumps(payload)}\n\n"
                elif hasattr(payload, "transcript"):
                    s = json.dumps({"role": "assistant", "content": payload.transcript})
                    yield f"event: output\ndata: {s}\n\n"

    return StreamingResponse(output_stream(), media_type="text/event-stream")


@app.get("/")
async def index():
    """Serve the HTML page with the RTC configuration substituted in."""
    rtc_config = get_twilio_turn_credentials() if get_space() else None
    html_content = HTML_CONTENT.replace("__RTC_CONFIGURATION__", json.dumps(rtc_config))
    return HTMLResponse(content=html_content)


if __name__ == "__main__":
    import uvicorn

    mode = os.getenv("MODE")
    if mode == "UI":
        stream.ui.launch(server_port=7860)
    elif mode == "PHONE":
        stream.fastphone(host="0.0.0.0", port=7860)
    else:
        uvicorn.run(app, host="0.0.0.0", port=7860)