import asyncio
import base64
import json
from pathlib import Path
import os
import numpy as np
import openai
from dotenv import load_dotenv
from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse, StreamingResponse
from fastrtc import (
AdditionalOutputs,
AsyncStreamHandler,
Stream,
get_twilio_turn_credentials,
wait_for_item,
)
from gradio.utils import get_space
from openai.types.beta.realtime import ResponseAudioTranscriptDoneEvent
import httpx
from typing import Optional, List, Dict
import gradio as gr
import io
from scipy import signal
import wave
import aiosqlite
from langdetect import detect, LangDetectException
from datetime import datetime
import uuid
# Populate os.environ from a local .env file (e.g. BSEARCH_API, MODE) if present.
load_dotenv()

# Sample rate used for both the input and output side of the realtime stream.
SAMPLE_RATE: int = 24_000

# SQLite file holding the conversations and messages tables.
DB_PATH: str = "chat_history.db"

# Page served at "/". NOTE(review): index() substitutes "__RTC_CONFIGURATION__"
# into this text, but no such placeholder appears here — confirm the full HTML
# page was meant to be embedded.
HTML_CONTENT = """
Mouth of 'MOUSE'
"""
class BraveSearchClient:
    """Brave Search API client.

    Each search opens a short-lived ``httpx.AsyncClient``. Any failure is logged
    and reported as an empty list, so callers can treat search as best-effort.
    """

    # Brave documents a maximum of 20 results per web-search request.
    MAX_COUNT = 20

    def __init__(self, api_key: str, search_lang: str = "ko"):
        """Store the subscription token and request defaults.

        Args:
            api_key: Brave ``X-Subscription-Token``; a falsy value disables search.
            search_lang: Language code sent as Brave's ``search_lang`` parameter.
        """
        self.api_key = api_key
        self.search_lang = search_lang
        self.base_url = "https://api.search.brave.com/res/v1/web/search"

    def _build_params(self, query: str, count: int) -> Dict:
        """Return the query-string parameters, capping ``count`` at ``MAX_COUNT``."""
        return {
            "q": query,
            "count": min(count, self.MAX_COUNT),
            # ``search_lang`` is Brave's documented language parameter for web
            # search (``lang`` is not one of its parameters).
            "search_lang": self.search_lang,
        }

    @staticmethod
    def _parse_results(data, count: int) -> List[Dict]:
        """Extract up to ``count`` ``{title, url, description}`` dicts from a response body.

        Missing or malformed ``web.results`` sections yield an empty list.
        """
        web = data.get("web") if isinstance(data, dict) else None
        raw_results = web.get("results") if isinstance(web, dict) else None
        if not isinstance(raw_results, list):
            return []
        return [
            {
                "title": item.get("title", ""),
                "url": item.get("url", ""),
                "description": item.get("description", ""),
            }
            for item in raw_results[:count]
            if isinstance(item, dict)
        ]

    async def search(self, query: str, count: int = 10) -> List[Dict]:
        """Perform a web search using the Brave Search API.

        Args:
            query: Search terms. An empty query returns ``[]`` without a request.
            count: Maximum number of results. It is capped at ``MAX_COUNT``;
                values below 1 return ``[]``.

        Returns:
            A list of ``{"title", "url", "description"}`` dicts. The list is
            empty when no API key is configured or the request/parsing fails.
        """
        if not self.api_key or not query or count < 1:
            return []
        headers = {
            "Accept": "application/json",
            "X-Subscription-Token": self.api_key,
        }
        params = self._build_params(query, count)
        async with httpx.AsyncClient() as client:
            try:
                response = await client.get(self.base_url, headers=headers, params=params)
                response.raise_for_status()
                return self._parse_results(response.json(), params["count"])
            except Exception as e:
                # Deliberately best-effort: search failures must not break chat.
                print(f"Brave Search error: {e}")
                return []
# Database helper class
class ChatDatabase:
    """Database manager for chat history.

    Every method is a ``staticmethod`` that opens its own ``aiosqlite``
    connection to ``DB_PATH`` for the duration of the call.
    """

    @staticmethod
    async def init():
        """Create the ``conversations`` and ``messages`` tables if they do not exist."""
        async with aiosqlite.connect(DB_PATH) as db:
            await db.execute("""
                CREATE TABLE IF NOT EXISTS conversations (
                    id TEXT PRIMARY KEY,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    summary TEXT
                )
            """)
            await db.execute("""
                CREATE TABLE IF NOT EXISTS messages (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    session_id TEXT NOT NULL,
                    role TEXT NOT NULL,
                    content TEXT NOT NULL,
                    detected_language TEXT,
                    timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    FOREIGN KEY (session_id) REFERENCES conversations(id)
                )
            """)
            await db.commit()

    @staticmethod
    async def create_session(session_id: str):
        """Insert a new conversation row whose primary key is ``session_id``.

        Fails with a primary-key violation if the id already exists.
        """
        async with aiosqlite.connect(DB_PATH) as db:
            await db.execute(
                "INSERT INTO conversations (id) VALUES (?)",
                (session_id,)
            )
            await db.commit()

    @staticmethod
    async def save_message(session_id: str, role: str, content: str):
        """Insert one message and refresh its conversation's metadata.

        Language detection is best-effort and only runs for content longer than
        10 characters. The conversation's ``updated_at`` is bumped. The first
        user message becomes the conversation summary, truncated to 100
        characters plus "...".
        """
        detected_language = None
        try:
            if content and len(content) > 10:  # Only detect for substantial content
                detected_language = detect(content)
        except LangDetectException:
            pass
        async with aiosqlite.connect(DB_PATH) as db:
            await db.execute(
                """INSERT INTO messages (session_id, role, content, detected_language)
                   VALUES (?, ?, ?, ?)""",
                (session_id, role, content, detected_language)
            )
            await db.execute(
                "UPDATE conversations SET updated_at = CURRENT_TIMESTAMP WHERE id = ?",
                (session_id,)
            )
            # Use the first user message as the conversation summary.
            if role == "user":
                cursor = await db.execute(
                    "SELECT summary FROM conversations WHERE id = ?",
                    (session_id,)
                )
                row = await cursor.fetchone()
                if row and not row[0]:  # No summary yet
                    summary = content[:100] + "..." if len(content) > 100 else content
                    await db.execute(
                        "UPDATE conversations SET summary = ? WHERE id = ?",
                        (summary, session_id)
                    )
            await db.commit()

    @staticmethod
    async def get_recent_conversations(limit: int = 10):
        """Return up to ``limit`` conversations, most recently updated first.

        Each item is ``{"id", "created_at", "summary"}``. A missing summary is
        reported as "새 대화".
        """
        async with aiosqlite.connect(DB_PATH) as db:
            # updated_at has one-second resolution; rowid (insertion order)
            # makes ties deterministic.
            cursor = await db.execute(
                """SELECT id, created_at, summary
                   FROM conversations
                   ORDER BY updated_at DESC, rowid DESC
                   LIMIT ?""",
                (limit,)
            )
            rows = await cursor.fetchall()
            return [
                {
                    "id": row[0],
                    "created_at": row[1],
                    "summary": row[2] or "새 대화"
                }
                for row in rows
            ]

    @staticmethod
    async def get_conversation_messages(session_id: str):
        """Return all messages of ``session_id`` in the order they were saved.

        Each item is ``{"role", "content", "detected_language", "timestamp"}``.
        """
        async with aiosqlite.connect(DB_PATH) as db:
            # CURRENT_TIMESTAMP only has one-second resolution, so a user message
            # and its reply often share a timestamp. The autoincrement id breaks
            # such ties in insertion order.
            cursor = await db.execute(
                """SELECT role, content, detected_language, timestamp
                   FROM messages
                   WHERE session_id = ?
                   ORDER BY timestamp ASC, id ASC""",
                (session_id,)
            )
            rows = await cursor.fetchall()
            return [
                {
                    "role": row[0],
                    "content": row[1],
                    "detected_language": row[2],
                    "timestamp": row[3]
                }
                for row in rows
            ]
# Web search is available only when a Brave API key is configured.
brave_api_key = os.getenv("BSEARCH_API")
if brave_api_key:
    search_client = BraveSearchClient(brave_api_key)
else:
    search_client = None
print(f"Search client initialized: {search_client is not None}, API key present: {bool(brave_api_key)}")

# webrtc_id -> settings dict captured from offer requests by custom_offer().
connection_settings = {}

# Shared async OpenAI client used by process_text_chat().
client = openai.AsyncOpenAI()
def update_chatbot(chatbot: list[dict], response: "ResponseAudioTranscriptDoneEvent | dict") -> list[dict]:
    """Append an assistant transcript from the realtime handler to the chat history.

    The handler in this module emits ``AdditionalOutputs`` whose payload is either
    ``{"event": <transcript-done event>, "detected_language": ...}`` or a search
    notice ``{"type": "search", "query": ...}``. A bare event object exposing
    ``.transcript`` is also accepted. Payloads without a non-empty transcript
    (such as search notices) leave the history unchanged.

    Args:
        chatbot: Current message list, mutated in place. ``None`` is treated as empty.
        response: One output payload (presumably passed through by fastrtc's
            ``additional_outputs_handler``; confirm against fastrtc).

    Returns:
        The updated message list.
    """
    if chatbot is None:
        chatbot = []
    event = response.get("event") if isinstance(response, dict) else response
    transcript = getattr(event, "transcript", None)
    if transcript:
        chatbot.append({"role": "assistant", "content": transcript})
    return chatbot
async def process_text_chat(message: str, web_search_enabled: bool, system_prompt: str, session_id: str) -> Dict[str, str]:
    """Answer one text chat message with the ``gpt-4.1-mini`` chat-completions model.

    Web search runs only when it is enabled, a Brave client exists, and the
    message contains one of a fixed set of time-sensitive keywords. In that case
    the top five results are added as an extra system message. When
    ``session_id`` is truthy, both the user message and the reply are saved.

    Args:
        message: The user's text.
        web_search_enabled: Whether keyword-triggered web search may run.
        system_prompt: System instructions; a generic prompt is used when empty.
        session_id: Conversation id for persistence; a falsy value skips saving.

    Returns:
        ``{"response": ..., "detected_language": ...}`` on success, or
        ``{"error": ...}`` if any step raises.
    """
    try:
        messages = [
            {"role": "system", "content": system_prompt or "You are a helpful assistant."}
        ]
        if web_search_enabled and search_client:
            # Cheap keyword heuristic for deciding whether the question is time-sensitive.
            search_keywords = ["날씨", "기온", "비", "눈", "뉴스", "소식", "현재", "최근",
                               "오늘", "지금", "가격", "환율", "주가", "weather", "news",
                               "current", "today", "price", "2024", "2025"]
            lowered = message.lower()
            if any(keyword in lowered for keyword in search_keywords):
                search_results = await search_client.search(message)
                if search_results:
                    search_context = "웹 검색 결과:\n\n" + "".join(
                        f"{i}. {result['title']}\n{result['description']}\n\n"
                        for i, result in enumerate(search_results[:5], 1)
                    )
                    messages.append({
                        "role": "system",
                        "content": "다음 웹 검색 결과를 참고하여 답변하세요:\n\n" + search_context
                    })
        messages.append({"role": "user", "content": message})
        response = await client.chat.completions.create(
            model="gpt-4.1-mini",
            messages=messages,
            temperature=0.7,
            max_tokens=2000
        )
        # ``content`` may be None; normalise to "" so the NOT NULL
        # messages.content column accepts it.
        response_text = response.choices[0].message.content or ""
        # Best-effort language tag, same rule as ChatDatabase.save_message.
        detected_language = None
        if len(response_text) > 10:
            try:
                detected_language = detect(response_text)
            except LangDetectException:
                pass
        if session_id:
            await ChatDatabase.save_message(session_id, "user", message)
            await ChatDatabase.save_message(session_id, "assistant", response_text)
        return {
            "response": response_text,
            "detected_language": detected_language
        }
    except Exception as e:
        print(f"Error in text chat: {e}")
        return {"error": str(e)}
class OpenAIHandler(AsyncStreamHandler):
    """fastrtc stream handler that relays audio to and from the OpenAI Realtime API.

    Microphone frames are appended to the realtime input audio buffer. Server
    events become items on ``output_queue``:

    - audio deltas become ``(sample_rate, int16 array)`` tuples;
    - finished transcripts and web-search notices become ``AdditionalOutputs``.

    When web search is enabled, the model may call a ``web_search`` tool, which
    is answered with Brave results.
    """

    def __init__(self, web_search_enabled: bool = False, system_prompt: str = "",
                 webrtc_id: Optional[str] = None, session_id: Optional[str] = None) -> None:
        """Configure the mono 24 kHz audio format and per-connection settings.

        Args:
            web_search_enabled: Expose the ``web_search`` tool to the model.
            system_prompt: Base instructions; a generic prompt is used when empty.
            webrtc_id: Connection id used to look up ``connection_settings``.
            session_id: Conversation id under which assistant transcripts are saved.
        """
        super().__init__(
            expected_layout="mono",
            output_sample_rate=SAMPLE_RATE,
            output_frame_size=480,
            input_sample_rate=SAMPLE_RATE,
        )
        self.connection = None
        # Carries audio tuples, AdditionalOutputs, and queued text-message dicts to emit().
        self.output_queue = asyncio.Queue()
        self.search_client = search_client
        # Bookkeeping for the function call the model is currently streaming.
        self.function_call_in_progress = False
        self.current_function_args = ""
        self.current_call_id = None
        self.webrtc_id = webrtc_id
        self.web_search_enabled = web_search_enabled
        self.system_prompt = system_prompt
        self.session_id = session_id
        print(f"[INIT] Handler created with web_search={web_search_enabled}, session_id={session_id}")

    def copy(self):
        """Build a new handler from the most recently stored offer settings.

        Falls back to a handler with web search disabled when nothing is stored.
        NOTE(review): the newest entry across *all* connections wins, so offers
        arriving close together may pick up each other's settings.
        """
        if connection_settings:
            # Newest timestamp wins; ties resolve to the earliest-inserted entry.
            recent_id = max(
                connection_settings,
                key=lambda k: connection_settings[k].get('timestamp', 0),
            )
            settings = connection_settings[recent_id]
            print(f"[COPY] Copying settings from {recent_id}:")
            return OpenAIHandler(
                web_search_enabled=settings.get('web_search_enabled', False),
                system_prompt=settings.get('system_prompt', ''),
                webrtc_id=recent_id,
                session_id=settings.get('session_id'),
            )
        print(f"[COPY] No settings found, creating default handler")
        return OpenAIHandler(web_search_enabled=False)

    async def search_web(self, query: str) -> str:
        """Run a Brave search and format the results as numbered text for the model."""
        if not self.search_client or not self.web_search_enabled:
            return "웹 검색이 비활성화되어 있습니다."
        print(f"Searching web for: {query}")
        results = await self.search_client.search(query)
        if not results:
            return f"'{query}'에 대한 검색 결과를 찾을 수 없습니다."
        formatted_results = []
        for i, result in enumerate(results, 1):
            formatted_results.append(
                f"{i}. {result['title']}\n"
                f" URL: {result['url']}\n"
                f" {result['description']}\n"
            )
        return f"웹 검색 결과 '{query}':\n\n" + "\n".join(formatted_results)

    async def process_text_message(self, message: str):
        """Add a typed user message to the realtime conversation and request a response."""
        if self.connection:
            await self.connection.conversation.item.create(
                item={
                    "type": "message",
                    "role": "user",
                    "content": [{"type": "input_text", "text": message}]
                }
            )
            await self.connection.response.create()

    def _build_session_config(self) -> Dict:
        """Return the ``session.update`` payload (instructions, tools, voice, VAD)."""
        base_instructions = self.system_prompt or "You are a helpful assistant."
        tools = []
        instructions = base_instructions
        if self.web_search_enabled and self.search_client:
            # Realtime session tools use the flat schema (name/description/
            # parameters at top level), unlike the nested chat-completions form.
            tools = [{
                "type": "function",
                "name": "web_search",
                "description": "Search the web for current information. Use this for weather, news, prices, current events, or any time-sensitive topics.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "query": {
                            "type": "string",
                            "description": "The search query"
                        }
                    },
                    "required": ["query"]
                }
            }]
            print("Web search function added to tools")
            search_instructions = (
                "\n\nYou have web search capabilities. "
                "IMPORTANT: You MUST use the web_search function for ANY of these topics:\n"
                "- Weather (날씨, 기온, 비, 눈)\n"
                "- News (뉴스, 소식)\n"
                "- Current events (현재, 최근, 오늘, 지금)\n"
                "- Prices (가격, 환율, 주가)\n"
                "- Sports scores or results\n"
                "- Any question about 2024 or 2025\n"
                "- Any time-sensitive information\n\n"
                "When in doubt, USE web_search. It's better to search and provide accurate information "
                "than to guess or use outdated information."
            )
            instructions = base_instructions + search_instructions
        return {
            "turn_detection": {"type": "server_vad"},
            "instructions": instructions,
            "tools": tools,
            "tool_choice": "auto" if tools else "none",
            "temperature": 0.7,
            "max_response_output_tokens": 4096,
            "modalities": ["text", "audio"],
            "voice": "alloy"
        }

    async def start_up(self):
        """Open the realtime connection and dispatch its events until it closes."""
        # Prefer settings captured for this exact connection, if any.
        if connection_settings and self.webrtc_id:
            if self.webrtc_id in connection_settings:
                settings = connection_settings[self.webrtc_id]
                self.web_search_enabled = settings.get('web_search_enabled', False)
                self.system_prompt = settings.get('system_prompt', '')
                self.session_id = settings.get('session_id')
                print(f"[START_UP] Updated settings from storage for {self.webrtc_id}")
        self.client = openai.AsyncOpenAI()
        print(f"[REALTIME API] Connecting...")
        session_update = self._build_session_config()
        async with self.client.beta.realtime.connect(
            model="gpt-4o-mini-realtime-preview-2024-12-17"
        ) as conn:
            await conn.session.update(session=session_update)
            self.connection = conn
            print(f"Connected with tools: {len(session_update['tools'])} functions")
            async for event in self.connection:
                if event.type.startswith("response.function_call"):
                    print(f"Function event: {event.type}")
                if event.type == "response.audio_transcript.done":
                    await self._handle_transcript_done(event)
                elif event.type == "response.audio.delta":
                    await self.output_queue.put(
                        (
                            self.output_sample_rate,
                            np.frombuffer(
                                base64.b64decode(event.delta), dtype=np.int16
                            ).reshape(1, -1),
                        ),
                    )
                elif event.type == "response.function_call_arguments.start":
                    self._begin_function_call(getattr(event, 'call_id', None))
                elif event.type == "response.function_call_arguments.delta":
                    # The Realtime API sends no ".start" event: begin tracking
                    # on the first argument delta of a call.
                    if not self.function_call_in_progress:
                        self._begin_function_call(getattr(event, 'call_id', None))
                    self.current_function_args += event.delta
                elif event.type == "response.function_call_arguments.done":
                    await self._handle_function_call_done(event)
                elif event.type == "error":
                    # Surface server-side rejections (e.g. an invalid session.update).
                    print(f"[REALTIME API] Error event: {getattr(event, 'error', event)}")

    async def _handle_transcript_done(self, event) -> None:
        """Persist a finished assistant transcript and publish it with its detected language."""
        print(f"[RESPONSE] Transcript: {event.transcript[:100]}...")
        detected_language = None
        try:
            if event.transcript and len(event.transcript) > 10:
                detected_language = detect(event.transcript)
        except LangDetectException:
            pass
        if self.session_id:
            await ChatDatabase.save_message(self.session_id, "assistant", event.transcript)
        output_data = {
            "event": event,
            "detected_language": detected_language
        }
        await self.output_queue.put(AdditionalOutputs(output_data))

    def _begin_function_call(self, call_id) -> None:
        """Start accumulating arguments for a new function call."""
        print(f"Function call started")
        self.function_call_in_progress = True
        self.current_function_args = ""
        self.current_call_id = call_id

    def _reset_function_call(self) -> None:
        """Clear per-call bookkeeping after a call completes or fails."""
        self.function_call_in_progress = False
        self.current_function_args = ""
        self.current_call_id = None

    async def _handle_function_call_done(self, event) -> None:
        """Execute the ``web_search`` call and send its output back to the model."""
        # The done event carries the complete arguments and call_id; fall back
        # to the values accumulated from delta events if they are absent.
        raw_args = getattr(event, "arguments", None) or self.current_function_args
        call_id = getattr(event, "call_id", None) or self.current_call_id
        print(f"Function call done, args: {raw_args}")
        try:
            args = json.loads(raw_args) if raw_args else {}
            query = args.get("query", "")
            # Let the client show that a search is happening.
            await self.output_queue.put(AdditionalOutputs({
                "type": "search",
                "query": query
            }))
            search_results = await self.search_web(query)
            print(f"Search results length: {len(search_results)}")
            if self.connection and call_id:
                await self.connection.conversation.item.create(
                    item={
                        "type": "function_call_output",
                        "call_id": call_id,
                        "output": search_results
                    }
                )
                await self.connection.response.create()
        except Exception as e:
            print(f"Function call error: {e}")
        finally:
            self._reset_function_call()

    async def receive(self, frame: tuple[int, np.ndarray]) -> None:
        """Forward one microphone frame to the realtime input audio buffer.

        The array is squeezed and sent as base64 of its raw bytes. Frames that
        arrive before the connection is open are dropped.
        """
        if not self.connection:
            print(f"[RECEIVE] No connection, skipping")
            return
        try:
            _, array = frame
            array = array.squeeze()
            audio_message = base64.b64encode(array.tobytes()).decode("utf-8")
            await self.connection.input_audio_buffer.append(audio=audio_message)
        except Exception as e:
            print(f"Error in receive: {e}")

    async def emit(self) -> tuple[int, np.ndarray] | AdditionalOutputs | None:
        """Return the next queued output.

        ``text_message`` dicts (queued by the text endpoint) are forwarded to
        the model instead of being emitted; in that case ``None`` is returned.
        """
        item = await wait_for_item(self.output_queue)
        if isinstance(item, dict) and item.get('type') == 'text_message':
            await self.process_text_message(item['content'])
            return None
        return item

    async def shutdown(self) -> None:
        """Close the realtime connection, if one is open."""
        print(f"[SHUTDOWN] Called")
        if self.connection:
            try:
                await self.connection.close()
            finally:
                # Never keep a reference to a connection whose close() failed.
                self.connection = None
            print("[REALTIME API] Connection closed")
# Template handler; fastrtc presumably derives per-connection handlers from it
# via OpenAIHandler.copy().
handler = OpenAIHandler(web_search_enabled=False)

# Chat transcript component used as both an extra input and an extra output.
chatbot = gr.Chatbot(type="messages")

# When running on a Hugging Face Space (per gradio's get_space), relay media
# through Twilio TURN and apply concurrency and session-time limits.
_on_space = get_space()
stream = Stream(
    handler,  # Pass instance, not factory
    mode="send-receive",
    modality="audio",
    additional_inputs=[chatbot],
    additional_outputs=[chatbot],
    additional_outputs_handler=update_chatbot,
    rtc_configuration=get_twilio_turn_credentials() if _on_space else None,
    concurrency_limit=5 if _on_space else None,
    time_limit=300 if _on_space else None,
)
app: FastAPI = FastAPI()

# Attach fastrtc's routes to the app.
stream.mount(app)


@app.on_event("startup")
async def startup_event():
    """Ensure the SQLite schema exists before requests are served."""
    await ChatDatabase.init()
    print("Database initialized")
# Intercept offer to capture settings
@app.post("/webrtc/offer", include_in_schema=False)
async def custom_offer(request: Request):
    """Record per-connection settings from an offer, then delegate to ``stream.offer``.

    Reads ``webrtc_id``, ``web_search_enabled``, ``system_prompt`` and
    ``session_id`` from the JSON body. When a ``webrtc_id`` is present, they are
    stored, with a loop timestamp, in ``connection_settings`` for
    ``OpenAIHandler.copy`` / ``start_up``.

    NOTE(review): ``stream.mount(app)`` runs before this route is registered, so
    fastrtc's own POST /webrtc/offer route may match first and shadow this one.
    Also confirm that ``stream.offer`` accepts a plain dict body.
    """
    body = await request.json()
    webrtc_id = body.get("webrtc_id")
    web_search_enabled = body.get("web_search_enabled", False)
    system_prompt = body.get("system_prompt", "")
    session_id = body.get("session_id")
    print(f"[OFFER] Received offer with webrtc_id: {webrtc_id}")
    print(f"[OFFER] web_search_enabled: {web_search_enabled}")
    print(f"[OFFER] session_id: {session_id}")
    if webrtc_id:
        connection_settings[webrtc_id] = {
            'web_search_enabled': web_search_enabled,
            'system_prompt': system_prompt,
            'session_id': session_id,
            'timestamp': asyncio.get_running_loop().time()
        }
        print(f"[OFFER] Stored settings for {webrtc_id}:")
        print(f"[OFFER] {connection_settings[webrtc_id]}")
    # stream.offer() is called directly rather than through the router, so this
    # route does not need to be detached first. Detaching it raced with
    # concurrent offers and lost the route for good whenever offer() raised.
    print(f"[OFFER] Forwarding to stream.offer()")
    response = await stream.offer(body)
    print(f"[OFFER] Response status: {response.get('status', 'unknown') if isinstance(response, dict) else 'OK'}")
    return response
@app.post("/session/new")
async def create_new_session():
    """Register a conversation keyed by a fresh random UUID and return its id."""
    new_id = str(uuid.uuid4())
    await ChatDatabase.create_session(new_id)
    return dict(session_id=new_id)
@app.post("/message/save")
async def save_message(request: Request):
    """Persist one chat message sent by the client.

    The JSON body must carry truthy ``session_id``, ``role`` and ``content``;
    otherwise an error payload is returned instead of saving.
    """
    payload = await request.json()
    fields = (payload.get("session_id"), payload.get("role"), payload.get("content"))
    if not all(fields):
        return {"error": "Missing required fields"}
    session_id, role, content = fields
    await ChatDatabase.save_message(session_id, role, content)
    return {"status": "ok"}
@app.get("/history/recent")
async def get_recent_history():
    """List recently updated conversations using ChatDatabase's default limit."""
    return await ChatDatabase.get_recent_conversations()
@app.get("/history/{session_id}")
async def get_conversation(session_id: str):
    """Return every stored message of ``session_id`` as a list of dicts."""
    return await ChatDatabase.get_conversation_messages(session_id)
@app.post("/chat/text")
async def chat_text(request: Request):
    """Handle a non-realtime text chat request.

    Body fields are ``message`` (required), ``web_search_enabled``,
    ``system_prompt`` and ``session_id``. Failures are reported as
    ``{"error": ...}`` payloads rather than raised.
    """
    try:
        body = await request.json()
        message = body.get("message", "")
        if not message:
            return {"error": "메시지가 비어있습니다."}
        return await process_text_chat(
            message,
            body.get("web_search_enabled", False),
            body.get("system_prompt", ""),
            body.get("session_id"),
        )
    except Exception as e:
        print(f"Error in chat_text endpoint: {e}")
        return {"error": "채팅 처리 중 오류가 발생했습니다."}
@app.post("/text_message/{webrtc_id}")
async def receive_text_message(webrtc_id: str, request: Request):
    """Queue a typed user message for the stream handler registered under ``webrtc_id``.

    The handler's ``emit`` picks up the ``text_message`` item and forwards it to
    the model. The reply is ``{"status": "ok"}`` even when no handler matches.
    """
    body = await request.json()
    message = body.get("content", "")
    if webrtc_id not in stream.handlers:
        return {"status": "ok"}
    conn_handler = stream.handlers[webrtc_id]
    await conn_handler.output_queue.put({"type": "text_message", "content": message})
    return {"status": "ok"}
@app.get("/outputs")
async def outputs(webrtc_id: str):
    """Relay handler outputs for ``webrtc_id`` as Server-Sent Events.

    Search notices become ``event: search``. Assistant transcripts become
    ``event: output`` together with their detected language. Anything else is
    dropped.
    """

    async def output_stream():
        async for output in stream.output_stream(webrtc_id):
            args = getattr(output, "args", None)
            if not args:
                continue
            payload = args[0]
            if not isinstance(payload, dict):
                continue
            if payload.get("type") == "search":
                yield f"event: search\ndata: {json.dumps(payload)}\n\n"
            elif "event" in payload and hasattr(payload["event"], "transcript"):
                data = {
                    "role": "assistant",
                    "content": payload["event"].transcript,
                    "detected_language": payload.get("detected_language")
                }
                yield f"event: output\ndata: {json.dumps(data)}\n\n"

    return StreamingResponse(output_stream(), media_type="text/event-stream")
@app.get("/")
async def index():
    """Serve the embedded page with the RTC configuration substituted in as JSON."""
    rtc_config = None
    if get_space():
        rtc_config = get_twilio_turn_credentials()
    page = HTML_CONTENT.replace("__RTC_CONFIGURATION__", json.dumps(rtc_config))
    return HTMLResponse(content=page)
if __name__ == "__main__":
    import uvicorn

    # MODE picks the launcher: the Gradio UI, fastrtc's fastphone(), or
    # (default) the FastAPI app under uvicorn — all on port 7860.
    launch_mode = os.getenv("MODE")
    if launch_mode == "UI":
        stream.ui.launch(server_port=7860)
    elif launch_mode == "PHONE":
        stream.fastphone(host="0.0.0.0", port=7860)
    else:
        uvicorn.run(app, host="0.0.0.0", port=7860)