# Private-AI / app.py
# Source: Hugging Face Space "Private-AI" by seawolf2357
# Revision: 0d8a2ef (verified) — "Update app.py" (44.9 kB)
import asyncio
import base64
import json
from pathlib import Path
import os
import numpy as np
import openai
from dotenv import load_dotenv
from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse, StreamingResponse
from fastrtc import (
AdditionalOutputs,
AsyncStreamHandler,
Stream,
get_twilio_turn_credentials,
wait_for_item,
)
from gradio.utils import get_space
from openai.types.beta.realtime import ResponseAudioTranscriptDoneEvent
import httpx
from typing import Optional, List, Dict
import gradio as gr
import logging
from datetime import datetime
# Logging setup
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)  # module-level logger shared by all classes below

load_dotenv()  # pull OPENAI_API_KEY / BSEARCH_API etc. from a local .env if present

# PCM16 sample rate (Hz) used for both the input and output audio streams.
SAMPLE_RATE = 24000
# Single-page client UI (HTML + inline CSS + vanilla-JS WebRTC client) served
# by the "/" route. The literal placeholder __RTC_CONFIGURATION__ is replaced
# with a JSON RTC configuration in index() before the page is returned.
# NOTE: this is runtime data — do not edit casually; the JS protocol
# (offer POST body, /outputs SSE event names) must stay in sync with the
# FastAPI endpoints below.
HTML_CONTENT = """<!DOCTYPE html>
<html lang="ko">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>MOUSE ์Œ์„ฑ ์ฑ—</title>
<style>
:root {
--primary-color: #6f42c1;
--secondary-color: #563d7c;
--dark-bg: #121212;
--card-bg: #1e1e1e;
--text-color: #f8f9fa;
--border-color: #333;
--hover-color: #8a5cf6;
}
body {
font-family: "SF Pro Display", -apple-system, BlinkMacSystemFont, sans-serif;
background-color: var(--dark-bg);
color: var(--text-color);
margin: 0;
padding: 0;
height: 100vh;
display: flex;
flex-direction: column;
overflow: hidden;
}
.container {
max-width: 900px;
margin: 0 auto;
padding: 20px;
flex-grow: 1;
display: flex;
flex-direction: column;
width: 100%;
height: calc(100vh - 40px);
box-sizing: border-box;
}
.header {
text-align: center;
padding: 20px 0;
border-bottom: 1px solid var(--border-color);
margin-bottom: 20px;
flex-shrink: 0;
}
.logo {
display: flex;
align-items: center;
justify-content: center;
gap: 10px;
}
.logo h1 {
margin: 0;
background: linear-gradient(135deg, var(--primary-color), #a78bfa);
-webkit-background-clip: text;
background-clip: text;
color: transparent;
font-size: 32px;
letter-spacing: 1px;
}
/* Web search toggle */
.search-toggle {
display: flex;
align-items: center;
justify-content: center;
gap: 10px;
margin-top: 15px;
}
.toggle-switch {
position: relative;
width: 50px;
height: 26px;
background-color: #ccc;
border-radius: 13px;
cursor: pointer;
transition: background-color 0.3s;
}
.toggle-switch.active {
background-color: var(--primary-color);
}
.toggle-slider {
position: absolute;
top: 3px;
left: 3px;
width: 20px;
height: 20px;
background-color: white;
border-radius: 50%;
transition: transform 0.3s;
}
.toggle-switch.active .toggle-slider {
transform: translateX(24px);
}
.search-label {
font-size: 14px;
color: #aaa;
}
.chat-container {
border-radius: 12px;
background-color: var(--card-bg);
box-shadow: 0 8px 32px rgba(0, 0, 0, 0.2);
padding: 20px;
flex-grow: 1;
display: flex;
flex-direction: column;
border: 1px solid var(--border-color);
overflow: hidden;
min-height: 0;
}
.chat-messages {
flex-grow: 1;
overflow-y: auto;
padding: 10px;
scrollbar-width: thin;
scrollbar-color: var(--primary-color) var(--card-bg);
min-height: 0;
}
.chat-messages::-webkit-scrollbar {
width: 6px;
}
.chat-messages::-webkit-scrollbar-thumb {
background-color: var(--primary-color);
border-radius: 6px;
}
.message {
margin-bottom: 20px;
padding: 14px;
border-radius: 8px;
font-size: 16px;
line-height: 1.6;
position: relative;
max-width: 80%;
animation: fade-in 0.3s ease-out;
}
@keyframes fade-in {
from {
opacity: 0;
transform: translateY(10px);
}
to {
opacity: 1;
transform: translateY(0);
}
}
.message.user {
background: linear-gradient(135deg, #2c3e50, #34495e);
margin-left: auto;
border-bottom-right-radius: 2px;
}
.message.assistant {
background: linear-gradient(135deg, var(--secondary-color), var(--primary-color));
margin-right: auto;
border-bottom-left-radius: 2px;
}
.message.search-result {
background: linear-gradient(135deg, #1a5a3e, #2e7d32);
font-size: 14px;
padding: 10px;
margin-bottom: 10px;
}
.controls {
text-align: center;
margin-top: 20px;
display: flex;
justify-content: center;
flex-shrink: 0;
}
button {
background: linear-gradient(135deg, var(--primary-color), var(--secondary-color));
color: white;
border: none;
padding: 14px 28px;
font-family: inherit;
font-size: 16px;
cursor: pointer;
transition: all 0.3s;
text-transform: uppercase;
letter-spacing: 1px;
border-radius: 50px;
display: flex;
align-items: center;
justify-content: center;
gap: 10px;
box-shadow: 0 4px 10px rgba(111, 66, 193, 0.3);
}
button:hover {
transform: translateY(-2px);
box-shadow: 0 6px 15px rgba(111, 66, 193, 0.5);
background: linear-gradient(135deg, var(--hover-color), var(--primary-color));
}
button:active {
transform: translateY(1px);
}
#audio-output {
display: none;
}
.icon-with-spinner {
display: flex;
align-items: center;
justify-content: center;
gap: 12px;
min-width: 180px;
}
.spinner {
width: 20px;
height: 20px;
border: 2px solid #ffffff;
border-top-color: transparent;
border-radius: 50%;
animation: spin 1s linear infinite;
flex-shrink: 0;
}
@keyframes spin {
to {
transform: rotate(360deg);
}
}
.audio-visualizer {
display: flex;
align-items: center;
justify-content: center;
gap: 5px;
min-width: 80px;
height: 25px;
}
.visualizer-bar {
width: 4px;
height: 100%;
background-color: rgba(255, 255, 255, 0.7);
border-radius: 2px;
transform-origin: bottom;
transform: scaleY(0.1);
transition: transform 0.1s ease;
}
.toast {
position: fixed;
top: 20px;
left: 50%;
transform: translateX(-50%);
padding: 16px 24px;
border-radius: 8px;
font-size: 14px;
z-index: 1000;
display: none;
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.3);
}
.toast.error {
background-color: #f44336;
color: white;
}
.toast.warning {
background-color: #ff9800;
color: white;
}
.status-indicator {
display: inline-flex;
align-items: center;
margin-top: 10px;
font-size: 14px;
color: #aaa;
}
.status-dot {
width: 8px;
height: 8px;
border-radius: 50%;
margin-right: 8px;
}
.status-dot.connected {
background-color: #4caf50;
}
.status-dot.disconnected {
background-color: #f44336;
}
.status-dot.connecting {
background-color: #ff9800;
animation: pulse 1.5s infinite;
}
@keyframes pulse {
0% {
opacity: 0.6;
}
50% {
opacity: 1;
}
100% {
opacity: 0.6;
}
}
.mouse-logo {
position: relative;
width: 40px;
height: 40px;
}
.mouse-ears {
position: absolute;
width: 15px;
height: 15px;
background-color: var(--primary-color);
border-radius: 50%;
}
.mouse-ear-left {
top: 0;
left: 5px;
}
.mouse-ear-right {
top: 0;
right: 5px;
}
.mouse-face {
position: absolute;
top: 10px;
left: 5px;
width: 30px;
height: 30px;
background-color: var(--secondary-color);
border-radius: 50%;
}
</style>
</head>
<body>
<div id="error-toast" class="toast"></div>
<div class="container">
<div class="header">
<div class="logo">
<div class="mouse-logo">
<div class="mouse-ears mouse-ear-left"></div>
<div class="mouse-ears mouse-ear-right"></div>
<div class="mouse-face"></div>
</div>
<h1>MOUSE ์Œ์„ฑ ์ฑ—</h1>
</div>
<div class="search-toggle">
<span class="search-label">์›น ๊ฒ€์ƒ‰</span>
<div id="search-toggle" class="toggle-switch">
<div class="toggle-slider"></div>
</div>
</div>
<div class="status-indicator">
<div id="status-dot" class="status-dot disconnected"></div>
<span id="status-text">์—ฐ๊ฒฐ ๋Œ€๊ธฐ ์ค‘</span>
</div>
</div>
<div class="chat-container">
<div class="chat-messages" id="chat-messages"></div>
</div>
<div class="controls">
<button id="start-button">๋Œ€ํ™” ์‹œ์ž‘</button>
</div>
</div>
<audio id="audio-output"></audio>
<script>
let peerConnection;
let webrtc_id;
let webSearchEnabled = false;
let reconnectAttempts = 0;
let heartbeatInterval;
let connectionHealthInterval;
const audioOutput = document.getElementById('audio-output');
const startButton = document.getElementById('start-button');
const chatMessages = document.getElementById('chat-messages');
const statusDot = document.getElementById('status-dot');
const statusText = document.getElementById('status-text');
const searchToggle = document.getElementById('search-toggle');
let audioLevel = 0;
let animationFrame;
let audioContext, analyser, audioSource;
// Web search toggle functionality
searchToggle.addEventListener('click', () => {
webSearchEnabled = !webSearchEnabled;
searchToggle.classList.toggle('active', webSearchEnabled);
console.log('Web search enabled:', webSearchEnabled);
});
function updateStatus(state) {
statusDot.className = 'status-dot ' + state;
if (state === 'connected') {
statusText.textContent = '์—ฐ๊ฒฐ๋จ';
} else if (state === 'connecting') {
statusText.textContent = '์—ฐ๊ฒฐ ์ค‘...';
} else {
statusText.textContent = '์—ฐ๊ฒฐ ๋Œ€๊ธฐ ์ค‘';
}
}
function updateButtonState() {
const button = document.getElementById('start-button');
if (peerConnection && (peerConnection.connectionState === 'connecting' || peerConnection.connectionState === 'new')) {
button.innerHTML = `
<div class="icon-with-spinner">
<div class="spinner"></div>
<span>์—ฐ๊ฒฐ ์ค‘...</span>
</div>
`;
updateStatus('connecting');
} else if (peerConnection && peerConnection.connectionState === 'connected') {
button.innerHTML = `
<div class="icon-with-spinner">
<div class="audio-visualizer" id="audio-visualizer">
<div class="visualizer-bar"></div>
<div class="visualizer-bar"></div>
<div class="visualizer-bar"></div>
<div class="visualizer-bar"></div>
<div class="visualizer-bar"></div>
</div>
<span>๋Œ€ํ™” ์ข…๋ฃŒ</span>
</div>
`;
updateStatus('connected');
} else {
button.innerHTML = '๋Œ€ํ™” ์‹œ์ž‘';
updateStatus('disconnected');
}
}
function setupAudioVisualization(stream) {
audioContext = new (window.AudioContext || window.webkitAudioContext)();
analyser = audioContext.createAnalyser();
audioSource = audioContext.createMediaStreamSource(stream);
audioSource.connect(analyser);
analyser.fftSize = 256;
const bufferLength = analyser.frequencyBinCount;
const dataArray = new Uint8Array(bufferLength);
const visualizerBars = document.querySelectorAll('.visualizer-bar');
const barCount = visualizerBars.length;
function updateAudioLevel() {
analyser.getByteFrequencyData(dataArray);
for (let i = 0; i < barCount; i++) {
const start = Math.floor(i * (bufferLength / barCount));
const end = Math.floor((i + 1) * (bufferLength / barCount));
let sum = 0;
for (let j = start; j < end; j++) {
sum += dataArray[j];
}
const average = sum / (end - start) / 255;
const scaleY = 0.1 + average * 0.9;
visualizerBars[i].style.transform = `scaleY(${scaleY})`;
}
animationFrame = requestAnimationFrame(updateAudioLevel);
}
updateAudioLevel();
}
function showError(message) {
const toast = document.getElementById('error-toast');
toast.textContent = message;
toast.className = 'toast error';
toast.style.display = 'block';
setTimeout(() => {
toast.style.display = 'none';
}, 5000);
}
// ์—ฐ๊ฒฐ ์ƒํƒœ ๋ชจ๋‹ˆํ„ฐ๋ง ํ•จ์ˆ˜
function startConnectionHealthCheck() {
if (connectionHealthInterval) {
clearInterval(connectionHealthInterval);
}
connectionHealthInterval = setInterval(() => {
if (peerConnection) {
const state = peerConnection.connectionState;
const iceState = peerConnection.iceConnectionState;
console.log(`Connection state: ${state}, ICE state: ${iceState}`);
if (state === 'failed' || state === 'closed' || iceState === 'failed') {
console.log('Connection lost, attempting to reconnect...');
handleConnectionLoss();
}
}
}, 3000); // 3์ดˆ๋งˆ๋‹ค ์ฒดํฌ
}
// ์—ฐ๊ฒฐ ์†์‹ค ์ฒ˜๋ฆฌ
function handleConnectionLoss() {
if (reconnectAttempts < 3) {
reconnectAttempts++;
showError(`์—ฐ๊ฒฐ์ด ๋Š์–ด์กŒ์Šต๋‹ˆ๋‹ค. ์žฌ์—ฐ๊ฒฐ ์‹œ๋„ ์ค‘... (${reconnectAttempts}/3)`);
stop();
setTimeout(() => {
setupWebRTC();
}, 2000);
} else {
showError('์—ฐ๊ฒฐ์„ ๋ณต๊ตฌํ•  ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค. ์ƒˆ๋กœ๊ณ ์นจ ํ›„ ๋‹ค์‹œ ์‹œ๋„ํ•ด์ฃผ์„ธ์š”.');
stop();
}
}
async function setupWebRTC() {
const config = __RTC_CONFIGURATION__;
peerConnection = new RTCPeerConnection(config);
const timeoutId = setTimeout(() => {
const toast = document.getElementById('error-toast');
toast.textContent = "์—ฐ๊ฒฐ์ด ํ‰์†Œ๋ณด๋‹ค ์˜ค๋ž˜ ๊ฑธ๋ฆฌ๊ณ  ์žˆ์Šต๋‹ˆ๋‹ค. VPN์„ ์‚ฌ์šฉ ์ค‘์ด์‹ ๊ฐ€์š”?";
toast.className = 'toast warning';
toast.style.display = 'block';
setTimeout(() => {
toast.style.display = 'none';
}, 5000);
}, 5000);
try {
const stream = await navigator.mediaDevices.getUserMedia({
audio: true
});
setupAudioVisualization(stream);
stream.getTracks().forEach(track => {
peerConnection.addTrack(track, stream);
});
peerConnection.addEventListener('track', (evt) => {
if (audioOutput.srcObject !== evt.streams[0]) {
audioOutput.srcObject = evt.streams[0];
audioOutput.play();
}
});
const dataChannel = peerConnection.createDataChannel('text');
// Heartbeat ๋ฉ”์‹œ์ง€ ์ „์†ก
dataChannel.onopen = () => {
console.log('Data channel opened');
if (heartbeatInterval) clearInterval(heartbeatInterval);
heartbeatInterval = setInterval(() => {
if (dataChannel.readyState === 'open') {
dataChannel.send(JSON.stringify({ type: 'heartbeat' }));
}
}, 30000); // 30์ดˆ๋งˆ๋‹ค heartbeat
};
dataChannel.onmessage = (event) => {
const eventJson = JSON.parse(event.data);
if (eventJson.type === "error") {
showError(eventJson.message);
} else if (eventJson.type === "connection_lost") {
handleConnectionLoss();
}
};
dataChannel.onclose = () => {
console.log('Data channel closed');
if (heartbeatInterval) clearInterval(heartbeatInterval);
};
const offer = await peerConnection.createOffer();
await peerConnection.setLocalDescription(offer);
await new Promise((resolve) => {
if (peerConnection.iceGatheringState === "complete") {
resolve();
} else {
const checkState = () => {
if (peerConnection.iceGatheringState === "complete") {
peerConnection.removeEventListener("icegatheringstatechange", checkState);
resolve();
}
};
peerConnection.addEventListener("icegatheringstatechange", checkState);
}
});
// ๋ชจ๋“  ์—ฐ๊ฒฐ ์ƒํƒœ ์ด๋ฒคํŠธ ๋ชจ๋‹ˆํ„ฐ๋ง
peerConnection.addEventListener('connectionstatechange', () => {
console.log('connectionstatechange', peerConnection.connectionState);
if (peerConnection.connectionState === 'connected') {
clearTimeout(timeoutId);
const toast = document.getElementById('error-toast');
toast.style.display = 'none';
reconnectAttempts = 0;
startConnectionHealthCheck();
} else if (peerConnection.connectionState === 'failed') {
handleConnectionLoss();
}
updateButtonState();
});
peerConnection.addEventListener('iceconnectionstatechange', () => {
console.log('ICE connection state:', peerConnection.iceConnectionState);
if (peerConnection.iceConnectionState === 'disconnected') {
showError('๋„คํŠธ์›Œํฌ ์—ฐ๊ฒฐ์ด ๋ถˆ์•ˆ์ •ํ•ฉ๋‹ˆ๋‹ค');
} else if (peerConnection.iceConnectionState === 'failed') {
handleConnectionLoss();
}
});
webrtc_id = Math.random().toString(36).substring(7);
const response = await fetch('/webrtc/offer', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({
sdp: peerConnection.localDescription.sdp,
type: peerConnection.localDescription.type,
webrtc_id: webrtc_id,
web_search_enabled: webSearchEnabled
})
});
const serverResponse = await response.json();
if (serverResponse.status === 'failed') {
showError(serverResponse.meta.error === 'concurrency_limit_reached'
? `๋„ˆ๋ฌด ๋งŽ์€ ์—ฐ๊ฒฐ์ž…๋‹ˆ๋‹ค. ์ตœ๋Œ€ ํ•œ๋„๋Š” ${serverResponse.meta.limit} ์ž…๋‹ˆ๋‹ค.`
: serverResponse.meta.error);
stop();
return;
}
await peerConnection.setRemoteDescription(serverResponse);
const eventSource = new EventSource('/outputs?webrtc_id=' + webrtc_id);
eventSource.addEventListener("output", (event) => {
const eventJson = JSON.parse(event.data);
addMessage("assistant", eventJson.content);
});
eventSource.addEventListener("search", (event) => {
const eventJson = JSON.parse(event.data);
if (eventJson.query) {
addMessage("search-result", `์›น ๊ฒ€์ƒ‰ ์ค‘: "${eventJson.query}"`);
}
});
eventSource.addEventListener("error", (event) => {
console.error('EventSource error:', event);
handleConnectionLoss();
});
} catch (err) {
clearTimeout(timeoutId);
console.error('Error setting up WebRTC:', err);
showError('์—ฐ๊ฒฐ์„ ์„ค์ •ํ•˜์ง€ ๋ชปํ–ˆ์Šต๋‹ˆ๋‹ค. ๋‹ค์‹œ ์‹œ๋„ํ•ด ์ฃผ์„ธ์š”.');
stop();
}
}
function addMessage(role, content) {
const messageDiv = document.createElement('div');
messageDiv.classList.add('message', role);
messageDiv.textContent = content;
chatMessages.appendChild(messageDiv);
chatMessages.scrollTop = chatMessages.scrollHeight;
}
function stop() {
if (heartbeatInterval) {
clearInterval(heartbeatInterval);
}
if (connectionHealthInterval) {
clearInterval(connectionHealthInterval);
}
if (animationFrame) {
cancelAnimationFrame(animationFrame);
}
if (audioContext) {
audioContext.close();
audioContext = null;
analyser = null;
audioSource = null;
}
if (peerConnection) {
if (peerConnection.getTransceivers) {
peerConnection.getTransceivers().forEach(transceiver => {
if (transceiver.stop) {
transceiver.stop();
}
});
}
if (peerConnection.getSenders) {
peerConnection.getSenders().forEach(sender => {
if (sender.track && sender.track.stop) sender.track.stop();
});
}
console.log('closing');
peerConnection.close();
}
updateButtonState();
audioLevel = 0;
}
startButton.addEventListener('click', () => {
console.log('clicked');
console.log(peerConnection, peerConnection?.connectionState);
if (!peerConnection || peerConnection.connectionState !== 'connected') {
setupWebRTC();
} else {
console.log('stopping');
stop();
}
});
</script>
</body>
</html>"""
class BraveSearchClient:
    """Thin async client for the Brave Search web-search REST endpoint.

    Best-effort: any failure (HTTP error, bad payload, missing key) yields an
    empty result list rather than an exception.
    """

    def __init__(self, api_key: str):
        self.api_key = api_key
        self.base_url = "https://api.search.brave.com/res/v1/web/search"

    async def search(self, query: str, count: int = 10) -> List[Dict]:
        """Perform a web search using Brave Search API"""
        # No subscription token -> behave as if the search returned nothing.
        if not self.api_key:
            return []
        request_headers = {
            "Accept": "application/json",
            "X-Subscription-Token": self.api_key
        }
        query_params = {
            "q": query,
            "count": count,
            "lang": "ko"
        }
        async with httpx.AsyncClient() as http:
            try:
                resp = await http.get(self.base_url, headers=request_headers, params=query_params)
                resp.raise_for_status()
                payload = resp.json()
                if "web" in payload and "results" in payload["web"]:
                    # Normalize each hit down to the three fields the app uses.
                    return [
                        {
                            "title": hit.get("title", ""),
                            "url": hit.get("url", ""),
                            "description": hit.get("description", "")
                        }
                        for hit in payload["web"]["results"][:count]
                    ]
                return []
            except Exception as e:
                logger.error(f"Brave Search error: {e}")
                return []
# Initialize search client globally (None when no BSEARCH_API key is set,
# which disables the web_search tool entirely).
brave_api_key = os.getenv("BSEARCH_API")
search_client = BraveSearchClient(brave_api_key) if brave_api_key else None
logger.info(f"Search client initialized: {search_client is not None}, API key present: {bool(brave_api_key)}")
# Store web search settings by connection:
# webrtc_id -> {'enabled': bool, 'timestamp': float} (written by custom_offer,
# read by OpenAIHandler.copy()/start_up()).
web_search_settings = {}
def update_chatbot(chatbot: list[dict], response: ResponseAudioTranscriptDoneEvent):
    """Append the finished assistant transcript to the chat history (in place)."""
    entry = {"role": "assistant", "content": response.transcript}
    chatbot.append(entry)
    return chatbot
class OpenAIHandler(AsyncStreamHandler):
    """fastrtc stream handler bridging WebRTC audio to the OpenAI realtime API.

    Audio frames received from the browser are forwarded to the realtime
    connection; audio/transcript events coming back are pushed onto an output
    queue consumed by emit(). When web search is enabled for the connection,
    a `web_search` tool (backed by the module-level BraveSearchClient) is
    exposed to the model and its calls are serviced locally.
    """

    def __init__(self, web_search_enabled: bool = False, webrtc_id: str = None) -> None:
        super().__init__(
            expected_layout="mono",
            output_sample_rate=SAMPLE_RATE,
            output_frame_size=480,
            input_sample_rate=SAMPLE_RATE,
        )
        self.connection = None  # live realtime-API connection, set in start_up()
        self.output_queue = asyncio.Queue()  # items consumed by emit()
        self.search_client = search_client  # module-level client; may be None
        # Accumulator state for one in-flight streamed function call.
        self.function_call_in_progress = False
        self.current_function_args = ""
        self.current_call_id = None
        self.webrtc_id = webrtc_id
        self.web_search_enabled = web_search_enabled
        self.keep_alive_task = None  # background task created in start_up()
        self.last_activity = datetime.now()  # updated on every send/receive
        self.connection_active = True
        logger.info(f"Handler created with web_search_enabled={web_search_enabled}, webrtc_id={webrtc_id}")

    def copy(self):
        """Create a per-connection handler instance (called by fastrtc).

        NOTE(review): picks the settings with the newest timestamp in
        web_search_settings — with concurrent connections this may attach
        another client's settings to this handler; confirm acceptable.
        """
        # Get the most recent settings
        if web_search_settings:
            # Get the most recent webrtc_id
            recent_ids = sorted(web_search_settings.keys(),
                                key=lambda k: web_search_settings[k].get('timestamp', 0),
                                reverse=True)
            if recent_ids:
                recent_id = recent_ids[0]
                settings = web_search_settings[recent_id]
                web_search_enabled = settings.get('enabled', False)
                logger.info(f"Handler.copy() using recent settings - webrtc_id={recent_id}, web_search_enabled={web_search_enabled}")
                return OpenAIHandler(web_search_enabled=web_search_enabled, webrtc_id=recent_id)
        logger.info(f"Handler.copy() called - creating new handler with default settings")
        return OpenAIHandler(web_search_enabled=False)

    async def search_web(self, query: str) -> str:
        """Perform web search and return formatted results.

        Returns a Korean-language plain-text block suitable for feeding back
        to the model as a function_call_output.
        """
        if not self.search_client or not self.web_search_enabled:
            return "์›น ๊ฒ€์ƒ‰์ด ๋น„ํ™œ์„ฑํ™”๋˜์–ด ์žˆ์Šต๋‹ˆ๋‹ค."
        logger.info(f"Searching web for: {query}")
        results = await self.search_client.search(query)
        if not results:
            return f"'{query}'์— ๋Œ€ํ•œ ๊ฒ€์ƒ‰ ๊ฒฐ๊ณผ๋ฅผ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค."
        # Format search results as a numbered title/URL/description list.
        formatted_results = []
        for i, result in enumerate(results, 1):
            formatted_results.append(
                f"{i}. {result['title']}\n"
                f" URL: {result['url']}\n"
                f" {result['description']}\n"
            )
        return f"์›น ๊ฒ€์ƒ‰ ๊ฒฐ๊ณผ '{query}':\n\n" + "\n".join(formatted_results)

    async def keep_alive(self):
        """Keep the connection alive with periodic activity checks.

        Runs until connection_active goes False; polls every 30 s and, after
        5 minutes without activity, verifies the realtime connection object
        still exists (the OpenAI connection itself needs no explicit ping).
        """
        while self.connection_active:
            try:
                await asyncio.sleep(30)  # check every 30 seconds
                # Has it been more than 5 minutes since the last activity?
                inactive_time = (datetime.now() - self.last_activity).total_seconds()
                if inactive_time > 300:  # 5 minutes
                    logger.warning(f"Connection inactive for {inactive_time} seconds")
                    # Verify the connection still exists.
                    if self.connection:
                        logger.debug("Connection alive - sending keepalive")
                        # The OpenAI connection is kept alive automatically.
                    else:
                        logger.error("Connection lost in keep_alive")
                        self.connection_active = False
                        break
            except Exception as e:
                logger.error(f"Keep-alive error: {e}")
                self.connection_active = False
                break

    async def start_up(self):
        """Connect to realtime API with function calling enabled.

        Blocks for the lifetime of the realtime connection, dispatching
        incoming events (audio deltas, transcripts, streamed function calls)
        onto the output queue / back to the model.
        """
        # First check if we have the most recent settings (copy() may have run
        # before custom_offer stored this connection's preference).
        if web_search_settings:
            recent_ids = sorted(web_search_settings.keys(),
                                key=lambda k: web_search_settings[k].get('timestamp', 0),
                                reverse=True)
            if recent_ids:
                recent_id = recent_ids[0]
                settings = web_search_settings[recent_id]
                self.web_search_enabled = settings.get('enabled', False)
                self.webrtc_id = recent_id
                logger.info(f"start_up: Updated settings from storage - webrtc_id={self.webrtc_id}, web_search_enabled={self.web_search_enabled}")
        logger.info(f"Starting up handler with web_search_enabled={self.web_search_enabled}")
        self.client = openai.AsyncOpenAI()  # reads OPENAI_API_KEY from env
        self.connection_active = True
        # Define the web search function tool (only when enabled and a Brave
        # client exists).
        tools = []
        instructions = "You are a helpful assistant. Respond in Korean when the user speaks Korean."
        if self.web_search_enabled and self.search_client:
            # NOTE(review): realtime-API tool entries are usually flat
            # ({"type": "function", "name": ..., "parameters": ...}); verify
            # the nested "function" wrapper used here is accepted by the API.
            tools = [{
                "type": "function",
                "function": {
                    "name": "web_search",
                    "description": "Search the web for current information. Use this for weather, news, prices, current events, or any time-sensitive topics.",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "query": {
                                "type": "string",
                                "description": "The search query in Korean or English"
                            }
                        },
                        "required": ["query"]
                    }
                }
            }]
            logger.info("Web search function added to tools")
            instructions = (
                "You are a helpful assistant with web search capabilities. "
                "IMPORTANT: You MUST use the web_search function for ANY of these topics:\n"
                "- Weather (๋‚ ์”จ, ๊ธฐ์˜จ, ๋น„, ๋ˆˆ)\n"
                "- News (๋‰ด์Šค, ์†Œ์‹)\n"
                "- Current events (ํ˜„์žฌ, ์ตœ๊ทผ, ์˜ค๋Š˜, ์ง€๊ธˆ)\n"
                "- Prices (๊ฐ€๊ฒฉ, ํ™˜์œจ, ์ฃผ๊ฐ€)\n"
                "- Sports scores or results\n"
                "- Any question about 2024 or 2025\n"
                "- Any time-sensitive information\n\n"
                "When in doubt, USE web_search. It's better to search and provide accurate information "
                "than to guess or use outdated information. Always respond in Korean when the user speaks Korean."
            )
        try:
            async with self.client.beta.realtime.connect(
                model="gpt-4o-mini-realtime-preview-2024-12-17"
            ) as conn:
                # Update session with tools
                session_update = {
                    "turn_detection": {"type": "server_vad"},
                    "instructions": instructions,
                    "tools": tools,
                    "tool_choice": "auto" if tools else "none"
                }
                await conn.session.update(session=session_update)
                self.connection = conn
                self.last_activity = datetime.now()
                logger.info(f"Connected with tools: {len(tools)} functions")
                # Start keep-alive task
                self.keep_alive_task = asyncio.create_task(self.keep_alive())
                async for event in self.connection:
                    self.last_activity = datetime.now()
                    # Debug logging for function calls
                    if event.type.startswith("response.function_call"):
                        logger.debug(f"Function event: {event.type}")
                    if event.type == "response.audio_transcript.done":
                        # Final transcript: forwarded to /outputs via emit().
                        await self.output_queue.put(AdditionalOutputs(event))
                    elif event.type == "response.audio.delta":
                        # Base64 PCM16 chunk -> (sample_rate, int16 ndarray) frame.
                        await self.output_queue.put(
                            (
                                self.output_sample_rate,
                                np.frombuffer(
                                    base64.b64decode(event.delta), dtype=np.int16
                                ).reshape(1, -1),
                            ),
                        )
                    # Handle function calls (arguments arrive as a stream).
                    elif event.type == "response.function_call_arguments.start":
                        # NOTE(review): some realtime API versions emit no
                        # ".start" event (call_id arrives via
                        # response.output_item.added) — confirm current_call_id
                        # actually gets populated, otherwise the search result
                        # is never returned to the model.
                        logger.info(f"Function call started")
                        self.function_call_in_progress = True
                        self.current_function_args = ""
                        self.current_call_id = getattr(event, 'call_id', None)
                    elif event.type == "response.function_call_arguments.delta":
                        if self.function_call_in_progress:
                            self.current_function_args += event.delta
                    elif event.type == "response.function_call_arguments.done":
                        if self.function_call_in_progress:
                            logger.info(f"Function call done, args: {self.current_function_args}")
                            try:
                                args = json.loads(self.current_function_args)
                                query = args.get("query", "")
                                # Emit search event to client (shows "searching" bubble).
                                await self.output_queue.put(AdditionalOutputs({
                                    "type": "search",
                                    "query": query
                                }))
                                # Perform the search
                                search_results = await self.search_web(query)
                                logger.info(f"Search results length: {len(search_results)}")
                                # Send function result back to the model
                                if self.connection and self.current_call_id:
                                    await self.connection.conversation.item.create(
                                        item={
                                            "type": "function_call_output",
                                            "call_id": self.current_call_id,
                                            "output": search_results
                                        }
                                    )
                                    await self.connection.response.create()
                            except Exception as e:
                                logger.error(f"Function call error: {e}")
                            finally:
                                # Reset accumulator for the next call.
                                self.function_call_in_progress = False
                                self.current_function_args = ""
                                self.current_call_id = None
        except Exception as e:
            logger.error(f"Connection error in start_up: {e}")
            self.connection_active = False
            # Notify the client of the connection error.
            await self.output_queue.put(AdditionalOutputs({
                "type": "connection_lost",
                "message": "์„œ๋ฒ„ ์—ฐ๊ฒฐ์ด ๋Š์–ด์กŒ์Šต๋‹ˆ๋‹ค"
            }))

    async def receive(self, frame: tuple[int, np.ndarray]) -> None:
        """Forward one inbound audio frame to the realtime input buffer."""
        if not self.connection or not self.connection_active:
            return
        try:
            self.last_activity = datetime.now()
            _, array = frame
            array = array.squeeze()
            # Realtime API expects base64-encoded PCM16 bytes.
            audio_message = base64.b64encode(array.tobytes()).decode("utf-8")
            await self.connection.input_audio_buffer.append(audio=audio_message)
        except Exception as e:
            logger.error(f"Error in receive: {e}")
            # If the connection dropped, update state (heuristic match on the
            # error text).
            if "closed" in str(e).lower() or "connection" in str(e).lower():
                self.connection = None
                self.connection_active = False
                # Notify the client that the connection ended.
                await self.output_queue.put(AdditionalOutputs({
                    "type": "connection_lost",
                    "message": "์—ฐ๊ฒฐ์ด ์ข…๋ฃŒ๋˜์—ˆ์Šต๋‹ˆ๋‹ค"
                }))

    async def emit(self) -> tuple[int, np.ndarray] | AdditionalOutputs | None:
        """Yield the next queued outbound item (audio frame or extra output)."""
        return await wait_for_item(self.output_queue)

    async def shutdown(self) -> None:
        """Cancel the keep-alive task and close the realtime connection."""
        logger.info("Shutting down handler")
        self.connection_active = False
        if self.keep_alive_task:
            self.keep_alive_task.cancel()
            try:
                await self.keep_alive_task
            except asyncio.CancelledError:
                pass  # expected on cancellation
        if self.connection:
            await self.connection.close()
            self.connection = None
# Create initial handler instance (fastrtc clones it per connection via copy()).
handler = OpenAIHandler(web_search_enabled=False)

# Create components
chatbot = gr.Chatbot(type="messages")

# Create stream with handler instance — session time limit removed.
stream = Stream(
    handler,  # Pass instance, not factory
    mode="send-receive",
    modality="audio",
    additional_inputs=[chatbot],
    additional_outputs=[chatbot],
    additional_outputs_handler=update_chatbot,
    rtc_configuration=get_twilio_turn_credentials() if get_space() else None,
    concurrency_limit=5 if get_space() else None,
    time_limit=None,  # no time limit
)

app = FastAPI()

# Mount stream (registers fastrtc's WebRTC routes, incl. /webrtc/offer).
stream.mount(app)
# Intercept offer to capture settings
@app.post("/webrtc/offer", include_in_schema=False)
async def custom_offer(request: Request):
    """Intercept the WebRTC offer to capture per-connection web-search settings.

    Stores the client's `web_search_enabled` flag under its `webrtc_id`
    (consumed by OpenAIHandler.copy()/start_up()), then delegates to the
    fastrtc stream's own offer handler by temporarily unshadowing its route.
    """
    body = await request.json()
    webrtc_id = body.get("webrtc_id")
    web_search_enabled = body.get("web_search_enabled", False)
    logger.info(f"Custom offer - webrtc_id: {webrtc_id}, web_search_enabled: {web_search_enabled}")
    # Store settings with timestamp (monotonic loop clock; only relative
    # ordering matters). get_running_loop() is the correct call inside a
    # coroutine — get_event_loop() is deprecated here.
    if webrtc_id:
        web_search_settings[webrtc_id] = {
            'enabled': web_search_enabled,
            'timestamp': asyncio.get_running_loop().time()
        }
    # Remove our custom route temporarily so the request can reach the
    # stream's own /webrtc/offer handler.
    # NOTE(review): mutating app.routes is not safe under concurrent offers —
    # a second request arriving in this window hits the stream handler
    # directly and skips settings capture.
    custom_route = None
    for i, route in enumerate(app.routes):
        if hasattr(route, 'path') and route.path == "/webrtc/offer" and route.endpoint == custom_offer:
            custom_route = app.routes.pop(i)
            break
    try:
        # Forward to stream's offer handler.
        response = await stream.offer(body)
    finally:
        # Always re-add our custom route, even if the offer failed —
        # otherwise settings capture would be permanently disabled.
        if custom_route:
            app.routes.insert(0, custom_route)
    return response
@app.get("/outputs")
async def outputs(webrtc_id: str):
    """Stream outputs including search events"""
    async def event_feed():
        # Translate queued AdditionalOutputs into SSE events for the browser.
        async for item in stream.output_stream(webrtc_id):
            payload_args = getattr(item, 'args', None)
            if not payload_args:
                continue
            first = payload_args[0]
            if isinstance(first, dict):
                # Search / connection-lost notifications come through as dicts.
                kind = first.get('type')
                if kind == 'search':
                    yield f"event: search\ndata: {json.dumps(first)}\n\n"
                elif kind == 'connection_lost':
                    yield f"event: error\ndata: {json.dumps(first)}\n\n"
            elif hasattr(first, 'transcript'):
                # Regular transcript event.
                s = json.dumps({"role": "assistant", "content": first.transcript})
                yield f"event: output\ndata: {s}\n\n"
    return StreamingResponse(event_feed(), media_type="text/event-stream")
@app.get("/")
async def index():
    """Serve the HTML page"""
    # Substitute the RTC configuration placeholder with live TURN credentials
    # (Twilio) when running on a Space, otherwise null.
    turn_config = get_twilio_turn_credentials() if get_space() else None
    page = HTML_CONTENT.replace("__RTC_CONFIGURATION__", json.dumps(turn_config))
    return HTMLResponse(content=page)
if __name__ == "__main__":
    import uvicorn

    # MODE selects the serving front-end:
    #   "UI"    -> launch the Gradio UI bundled with the fastrtc Stream
    #   "PHONE" -> fastrtc phone mode
    #   default -> serve the FastAPI app directly via uvicorn
    mode = os.getenv("MODE")
    if mode == "UI":
        stream.ui.launch(server_port=7860)
    elif mode == "PHONE":
        stream.fastphone(host="0.0.0.0", port=7860)
    else:
        uvicorn.run(app, host="0.0.0.0", port=7860)