import asyncio
import base64
import json
from pathlib import Path
import os

import numpy as np
import openai
from dotenv import load_dotenv
from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse, StreamingResponse
from fastrtc import (
    AdditionalOutputs,
    AsyncStreamHandler,
    Stream,
    get_twilio_turn_credentials,
    wait_for_item,
)
from gradio.utils import get_space
from openai.types.beta.realtime import ResponseAudioTranscriptDoneEvent
import httpx
from typing import Optional, List, Dict
import gradio as gr

load_dotenv()

SAMPLE_RATE = 24000

# HTML content embedded as a string.
# NOTE(review): only plain text is visible here (no markup, and no
# "__RTC_CONFIGURATION__" placeholder that index() substitutes) — confirm the
# full page template was not lost.
HTML_CONTENT = """ MOUSE 음성 챗
웹 검색
연결 대기 중
"""


class BraveSearchClient:
    """Brave Search API client"""

    def __init__(self, api_key: str):
        self.api_key = api_key
        self.base_url = "https://api.search.brave.com/res/v1/web/search"

    async def search(self, query: str, count: int = 10) -> List[Dict]:
        """Perform a web search using Brave Search API.

        Args:
            query: Search query text.
            count: Maximum number of results to request and return.

        Returns:
            A list of ``{"title", "url", "description"}`` dicts. The list is
            empty when no API key is configured or the request fails. Errors
            are logged, not raised, so a failed search never breaks the call.
        """
        if not self.api_key:
            return []

        headers = {
            "Accept": "application/json",
            "X-Subscription-Token": self.api_key,
        }
        params = {
            "q": query,
            "count": count,
            "lang": "ko",
        }

        async with httpx.AsyncClient() as client:
            try:
                response = await client.get(self.base_url, headers=headers, params=params)
                response.raise_for_status()
                data = response.json()

                if "web" in data and "results" in data["web"]:
                    return [
                        {
                            "title": result.get("title", ""),
                            "url": result.get("url", ""),
                            "description": result.get("description", ""),
                        }
                        for result in data["web"]["results"][:count]
                    ]
                return []
            except Exception as e:
                # Deliberate best-effort boundary: a search failure degrades to "no results".
                print(f"Brave Search error: {e}")
                return []


# Global state for web search settings.
# web_search_settings maps webrtc_id -> web_search_enabled (written by WebSearchMiddleware).
web_search_settings = {}
# Last offer seen by the middleware; read by OpenAIHandler.copy().
current_webrtc_id = None
current_web_search_enabled = False


class OpenAIHandler(AsyncStreamHandler):
    """Bridge one fastrtc audio stream to the OpenAI Realtime API, with optional web search."""

    def __init__(self) -> None:
        super().__init__(
            expected_layout="mono",
            output_sample_rate=SAMPLE_RATE,
            output_frame_size=480,
            input_sample_rate=SAMPLE_RATE,
        )
        self.connection = None
        self.output_queue = asyncio.Queue()
        self.search_client = search_client  # module-level client (defined below this class)
        self.function_call_in_progress = False
        self.current_function_args = ""
        self.current_call_id = None
        self.webrtc_id = None
        self.web_search_enabled = False
        self.startup_called = False

    def copy(self):
        """Create a new handler for a new connection, inheriting the search client.

        NOTE(review): webrtc_id / web_search_enabled come from module globals that
        the middleware overwrites on every offer. With concurrent connections,
        a handler may pick up another connection's id. start_up() re-reads
        web_search_settings by id, which only helps if the id is correct.
        """
        new_handler = OpenAIHandler()
        new_handler.search_client = self.search_client
        new_handler.webrtc_id = current_webrtc_id  # Use global variable
        new_handler.web_search_enabled = current_web_search_enabled  # Use global variable
        return new_handler

    async def search_web(self, query: str) -> str:
        """Perform a web search and return the results formatted as text for the model."""
        if not self.search_client or not self.web_search_enabled:
            return "웹 검색이 비활성화되어 있습니다."

        print(f"Searching web for: {query}")
        results = await self.search_client.search(query)
        if not results:
            return f"'{query}'에 대한 검색 결과를 찾을 수 없습니다."

        # Format search results
        formatted_results = [
            f"{i}. {result['title']}\n"
            f"   URL: {result['url']}\n"
            f"   {result['description']}\n"
            for i, result in enumerate(results, 1)
        ]
        return f"웹 검색 결과 '{query}':\n\n" + "\n".join(formatted_results)

    def _build_tools(self) -> list:
        """Return the Realtime-API tool list (empty unless web search is usable).

        The Realtime API uses a flat function schema
        (type/name/description/parameters at the top level). This differs from
        the nested ``{"function": {...}}`` shape used by Chat Completions.
        """
        if not (self.web_search_enabled and self.search_client):
            return []
        print("Web search function added to tools")
        return [{
            "type": "function",
            "name": "web_search",
            "description": "Search the web for information when the user asks questions about current events, news, weather, prices, or topics that require up-to-date information",
            "parameters": {
                "type": "object",
                "properties": {
                    "query": {
                        "type": "string",
                        "description": "The search query",
                    }
                },
                "required": ["query"],
            },
        }]

    async def _handle_function_call_done(self, event) -> None:
        """Run the requested web search and feed its output back to the model.

        Prefers the complete ``arguments`` / ``call_id`` carried by the done
        event. Falls back to whatever was accumulated from delta events.
        Errors are logged, and the per-call state is always reset.
        """
        raw_args = getattr(event, "arguments", None) or self.current_function_args
        call_id = getattr(event, "call_id", None) or self.current_call_id
        print(f"Function call done, args: {raw_args}")
        try:
            args = json.loads(raw_args) if raw_args else {}
            query = args.get("query", "")

            # Emit search event to client
            await self.output_queue.put(AdditionalOutputs({
                "type": "search",
                "query": query,
            }))

            # Perform the search
            search_results = await self.search_web(query)
            print(f"Search results length: {len(search_results)}")

            # Send function result back to the model, then ask it to respond.
            if self.connection and call_id:
                await self.connection.conversation.item.create(
                    item={
                        "type": "function_call_output",
                        "call_id": call_id,
                        "output": search_results,
                    }
                )
                await self.connection.response.create()
        except Exception as e:
            print(f"Function call error: {e}")
        finally:
            self.function_call_in_progress = False
            self.current_function_args = ""
            self.current_call_id = None

    async def start_up(self):
        """Connect to the Realtime API (with function calling when enabled) and pump events."""
        # Check web search setting before connecting
        if self.webrtc_id and self.webrtc_id in web_search_settings:
            self.web_search_enabled = web_search_settings[self.webrtc_id]
            print(f"Retrieved web_search_enabled={self.web_search_enabled} for webrtc_id={self.webrtc_id}")

        print(f"Starting up handler with web_search_enabled={self.web_search_enabled}")
        self.client = openai.AsyncOpenAI()

        tools = self._build_tools()

        async with self.client.beta.realtime.connect(
            model="gpt-4o-mini-realtime-preview-2024-12-17"
        ) as conn:
            # Update session with tools
            session_update = {
                "turn_detection": {"type": "server_vad"},
                "tools": tools,
                "tool_choice": "auto" if tools else "none",
            }

            # Add instructions to use web search
            if self.web_search_enabled:
                session_update["instructions"] = (
                    "You are a helpful assistant. When users ask about current events, "
                    "news, recent information, weather, or topics that might have changed recently, "
                    "use the web_search function to find the most up-to-date information. "
                    "Always search when the user asks about: 날씨(weather), 뉴스(news), 현재(current), "
                    "최근(recent), 최신(latest), 오늘(today), or any time-sensitive information."
                )

            await conn.session.update(session=session_update)
            self.connection = conn
            print(f"Connected with tools: {len(tools)} functions")
            self.startup_called = True

            async for event in self.connection:
                if event.type == "response.audio_transcript.done":
                    await self.output_queue.put(AdditionalOutputs(event))

                elif event.type == "response.audio.delta":
                    await self.output_queue.put(
                        (
                            self.output_sample_rate,
                            np.frombuffer(
                                base64.b64decode(event.delta), dtype=np.int16
                            ).reshape(1, -1),
                        ),
                    )

                # Handle function calls. The Realtime API emits no ".start" event,
                # so accumulation begins on the first delta of a call.
                elif event.type == "response.function_call_arguments.delta":
                    if not self.function_call_in_progress:
                        print("Function call started")
                        self.function_call_in_progress = True
                        self.current_function_args = ""
                        self.current_call_id = getattr(event, "call_id", None)
                    self.current_function_args += event.delta

                elif event.type == "response.function_call_arguments.done":
                    await self._handle_function_call_done(event)

    async def receive(self, frame: tuple[int, np.ndarray]) -> None:
        """Forward one input audio frame to the Realtime API (dropped if not connected)."""
        if not self.connection:
            return
        _, array = frame
        array = array.squeeze()
        audio_message = base64.b64encode(array.tobytes()).decode("utf-8")
        await self.connection.input_audio_buffer.append(audio=audio_message)

    async def emit(self) -> tuple[int, np.ndarray] | AdditionalOutputs | None:
        """Return the next queued audio frame or additional output."""
        return await wait_for_item(self.output_queue)

    async def shutdown(self) -> None:
        """Close the Realtime connection if one is open."""
        if self.connection:
            await self.connection.close()
            self.connection = None


# Initialize search client
brave_api_key = os.getenv("BSEARCH_API")
search_client = BraveSearchClient(brave_api_key) if brave_api_key else None
print(f"Search client initialized: {search_client is not None}, API key present: {bool(brave_api_key)}")


def update_chatbot(chatbot: list[dict], response: ResponseAudioTranscriptDoneEvent):
    """Append an assistant transcript to the chat history.

    The handler also emits plain ``{"type": "search", ...}`` dicts as
    AdditionalOutputs. Those have no ``transcript`` and leave the history
    unchanged.
    """
    transcript = getattr(response, "transcript", None)
    if transcript is None:
        return chatbot
    chatbot.append({"role": "assistant", "content": transcript})
    return chatbot


# Create chatbot component
chatbot = gr.Chatbot(type="messages")

# Create base handler and stream
handler = OpenAIHandler()
handler.search_client = search_client

stream = Stream(
    handler,
    mode="send-receive",
    modality="audio",
    additional_inputs=[chatbot],
    additional_outputs=[chatbot],
    additional_outputs_handler=update_chatbot,
    rtc_configuration=get_twilio_turn_credentials() if get_space() else None,
    concurrency_limit=5 if get_space() else None,
    time_limit=90 if get_space() else None,
)

app = FastAPI()

# Mount the stream
stream.mount(app)

# Intercept the POST body before it reaches the stream
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request as StarletteRequest
import io


class WebSearchMiddleware(BaseHTTPMiddleware):
    """Capture the web-search toggle from the WebRTC offer body before fastrtc handles it."""

    async def dispatch(self, request: StarletteRequest, call_next):
        global current_webrtc_id, current_web_search_enabled

        if request.url.path == "/webrtc/offer" and request.method == "POST":
            # Read the body
            body = await request.body()

            # Parse it (best effort: a malformed body must not block the offer)
            try:
                data = json.loads(body)
                webrtc_id = data.get("webrtc_id")
                web_search_enabled = data.get("web_search_enabled", False)

                print(f"Middleware intercepted - webrtc_id: {webrtc_id}, web_search_enabled: {web_search_enabled}")

                # Store settings globally
                if webrtc_id:
                    current_webrtc_id = webrtc_id
                    current_web_search_enabled = web_search_enabled
                    web_search_settings[webrtc_id] = web_search_enabled
            except (ValueError, AttributeError, TypeError) as e:
                # ValueError covers JSONDecodeError/UnicodeDecodeError; AttributeError a
                # non-object body; TypeError an unhashable webrtc_id.
                print(f"Middleware could not parse offer body: {e}")

            # Re-attach the consumed body so downstream handlers can read it again
            request._body = body

        response = await call_next(request)
        return response


# Add middleware
app.add_middleware(WebSearchMiddleware)


@app.get("/outputs")
async def outputs(webrtc_id: str):
    """Stream outputs including search events"""

    async def output_stream():
        async for output in stream.output_stream(webrtc_id):
            if hasattr(output, 'args') and output.args:
                # Check if it's a search event
                if isinstance(output.args[0], dict) and output.args[0].get('type') == 'search':
                    yield f"event: search\ndata: {json.dumps(output.args[0])}\n\n"
                # Regular transcript event
                elif hasattr(output.args[0], 'transcript'):
                    s = json.dumps({"role": "assistant", "content": output.args[0].transcript})
                    yield f"event: output\ndata: {s}\n\n"

    return StreamingResponse(output_stream(), media_type="text/event-stream")


@app.get("/")
async def index():
    """Serve the HTML page"""
    rtc_config = get_twilio_turn_credentials() if get_space() else None
    html_content = HTML_CONTENT.replace("__RTC_CONFIGURATION__", json.dumps(rtc_config))
    return HTMLResponse(content=html_content)


if __name__ == "__main__":
    import uvicorn

    mode = os.getenv("MODE")
    if mode == "UI":
        stream.ui.launch(server_port=7860)
    elif mode == "PHONE":
        stream.fastphone(host="0.0.0.0", port=7860)
    else:
        uvicorn.run(app, host="0.0.0.0", port=7860)