import asyncio
import base64
import json
from pathlib import Path
import os
import numpy as np
import openai
from dotenv import load_dotenv
from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse, StreamingResponse
from fastrtc import (
    AdditionalOutputs,
    AsyncStreamHandler,
    Stream,
    get_twilio_turn_credentials,
    wait_for_item,
)
from gradio.utils import get_space
from openai.types.beta.realtime import ResponseAudioTranscriptDoneEvent
import httpx
from typing import Optional, List, Dict
import gradio as gr

load_dotenv()

SAMPLE_RATE = 24000
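# 24 kHz, 16-bit mono PCM is the audio format the OpenAI Realtime API streams in and
# out, so the handler below uses the same rate for both input and output.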
# HTML content embedded as a string
HTML_CONTENT = """<!DOCTYPE html>
<html lang="ko">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>MOUSE 음성 챗</title>
    <style>
        :root {
            --primary-color: #6f42c1;
            --secondary-color: #563d7c;
            --dark-bg: #121212;
            --card-bg: #1e1e1e;
            --text-color: #f8f9fa;
            --border-color: #333;
            --hover-color: #8a5cf6;
        }
        body {
            font-family: "SF Pro Display", -apple-system, BlinkMacSystemFont, sans-serif;
            background-color: var(--dark-bg);
            color: var(--text-color);
            margin: 0;
            padding: 0;
            height: 100vh;
            display: flex;
            flex-direction: column;
            overflow: hidden;
        }
        .container {
            max-width: 900px;
            margin: 0 auto;
            padding: 20px;
            flex-grow: 1;
            display: flex;
            flex-direction: column;
            width: 100%;
            height: calc(100vh - 40px);
            box-sizing: border-box;
        }
        .header {
            text-align: center;
            padding: 20px 0;
            border-bottom: 1px solid var(--border-color);
            margin-bottom: 20px;
            flex-shrink: 0;
        }
        .logo {
            display: flex;
            align-items: center;
            justify-content: center;
            gap: 10px;
        }
        .logo h1 {
            margin: 0;
            background: linear-gradient(135deg, var(--primary-color), #a78bfa);
            -webkit-background-clip: text;
            background-clip: text;
            color: transparent;
            font-size: 32px;
            letter-spacing: 1px;
        }
        /* Web search toggle */
        .search-toggle {
            display: flex;
            align-items: center;
            justify-content: center;
            gap: 10px;
            margin-top: 15px;
        }
        .toggle-switch {
            position: relative;
            width: 50px;
            height: 26px;
            background-color: #ccc;
            border-radius: 13px;
            cursor: pointer;
            transition: background-color 0.3s;
        }
        .toggle-switch.active {
            background-color: var(--primary-color);
        }
        .toggle-slider {
            position: absolute;
            top: 3px;
            left: 3px;
            width: 20px;
            height: 20px;
            background-color: white;
            border-radius: 50%;
            transition: transform 0.3s;
        }
        .toggle-switch.active .toggle-slider {
            transform: translateX(24px);
        }
        .search-label {
            font-size: 14px;
            color: #aaa;
        }
        .chat-container {
            border-radius: 12px;
            background-color: var(--card-bg);
            box-shadow: 0 8px 32px rgba(0, 0, 0, 0.2);
            padding: 20px;
            flex-grow: 1;
            display: flex;
            flex-direction: column;
            border: 1px solid var(--border-color);
            overflow: hidden;
            min-height: 0;
        }
        .chat-messages {
            flex-grow: 1;
            overflow-y: auto;
            padding: 10px;
            scrollbar-width: thin;
            scrollbar-color: var(--primary-color) var(--card-bg);
            min-height: 0;
        }
        .chat-messages::-webkit-scrollbar {
            width: 6px;
        }
        .chat-messages::-webkit-scrollbar-thumb {
            background-color: var(--primary-color);
            border-radius: 6px;
        }
        .message {
            margin-bottom: 20px;
            padding: 14px;
            border-radius: 8px;
            font-size: 16px;
            line-height: 1.6;
            position: relative;
            max-width: 80%;
            animation: fade-in 0.3s ease-out;
        }
        @keyframes fade-in {
            from {
                opacity: 0;
                transform: translateY(10px);
            }
            to {
                opacity: 1;
                transform: translateY(0);
            }
        }
        .message.user {
            background: linear-gradient(135deg, #2c3e50, #34495e);
            margin-left: auto;
            border-bottom-right-radius: 2px;
        }
        .message.assistant {
            background: linear-gradient(135deg, var(--secondary-color), var(--primary-color));
            margin-right: auto;
            border-bottom-left-radius: 2px;
        }
        .message.search-result {
            background: linear-gradient(135deg, #1a5a3e, #2e7d32);
            font-size: 14px;
            padding: 10px;
            margin-bottom: 10px;
        }
        .controls {
            text-align: center;
            margin-top: 20px;
            display: flex;
            justify-content: center;
            flex-shrink: 0;
        }
        button {
            background: linear-gradient(135deg, var(--primary-color), var(--secondary-color));
            color: white;
            border: none;
            padding: 14px 28px;
            font-family: inherit;
            font-size: 16px;
            cursor: pointer;
            transition: all 0.3s;
            text-transform: uppercase;
            letter-spacing: 1px;
            border-radius: 50px;
            display: flex;
            align-items: center;
            justify-content: center;
            gap: 10px;
            box-shadow: 0 4px 10px rgba(111, 66, 193, 0.3);
        }
        button:hover {
            transform: translateY(-2px);
            box-shadow: 0 6px 15px rgba(111, 66, 193, 0.5);
            background: linear-gradient(135deg, var(--hover-color), var(--primary-color));
        }
        button:active {
            transform: translateY(1px);
        }
        #audio-output {
            display: none;
        }
        .icon-with-spinner {
            display: flex;
            align-items: center;
            justify-content: center;
            gap: 12px;
            min-width: 180px;
        }
        .spinner {
            width: 20px;
            height: 20px;
            border: 2px solid #ffffff;
            border-top-color: transparent;
            border-radius: 50%;
            animation: spin 1s linear infinite;
            flex-shrink: 0;
        }
        @keyframes spin {
            to {
                transform: rotate(360deg);
            }
        }
        .audio-visualizer {
            display: flex;
            align-items: center;
            justify-content: center;
            gap: 5px;
            min-width: 80px;
            height: 25px;
        }
        .visualizer-bar {
            width: 4px;
            height: 100%;
            background-color: rgba(255, 255, 255, 0.7);
            border-radius: 2px;
            transform-origin: bottom;
            transform: scaleY(0.1);
            transition: transform 0.1s ease;
        }
        .toast {
            position: fixed;
            top: 20px;
            left: 50%;
            transform: translateX(-50%);
            padding: 16px 24px;
            border-radius: 8px;
            font-size: 14px;
            z-index: 1000;
            display: none;
            box-shadow: 0 4px 12px rgba(0, 0, 0, 0.3);
        }
        .toast.error {
            background-color: #f44336;
            color: white;
        }
        .toast.warning {
            background-color: #ff9800;
            color: white;
        }
        .status-indicator {
            display: inline-flex;
            align-items: center;
            margin-top: 10px;
            font-size: 14px;
            color: #aaa;
        }
        .status-dot {
            width: 8px;
            height: 8px;
            border-radius: 50%;
            margin-right: 8px;
        }
        .status-dot.connected {
            background-color: #4caf50;
        }
        .status-dot.disconnected {
            background-color: #f44336;
        }
        .status-dot.connecting {
            background-color: #ff9800;
            animation: pulse 1.5s infinite;
        }
        @keyframes pulse {
            0% {
                opacity: 0.6;
            }
            50% {
                opacity: 1;
            }
            100% {
                opacity: 0.6;
            }
        }
        .mouse-logo {
            position: relative;
            width: 40px;
            height: 40px;
        }
        .mouse-ears {
            position: absolute;
            width: 15px;
            height: 15px;
            background-color: var(--primary-color);
            border-radius: 50%;
        }
        .mouse-ear-left {
            top: 0;
            left: 5px;
        }
        .mouse-ear-right {
            top: 0;
            right: 5px;
        }
        .mouse-face {
            position: absolute;
            top: 10px;
            left: 5px;
            width: 30px;
            height: 30px;
            background-color: var(--secondary-color);
            border-radius: 50%;
        }
    </style>
</head>
<body>
    <div id="error-toast" class="toast"></div>
    <div class="container">
        <div class="header">
            <div class="logo">
                <div class="mouse-logo">
                    <div class="mouse-ears mouse-ear-left"></div>
                    <div class="mouse-ears mouse-ear-right"></div>
                    <div class="mouse-face"></div>
                </div>
                <h1>MOUSE 음성 챗</h1>
            </div>
            <div class="search-toggle">
                <span class="search-label">웹 검색</span>
                <div id="search-toggle" class="toggle-switch">
                    <div class="toggle-slider"></div>
                </div>
            </div>
            <div class="status-indicator">
                <div id="status-dot" class="status-dot disconnected"></div>
                <span id="status-text">연결 대기 중</span>
            </div>
        </div>
        <div class="chat-container">
            <div class="chat-messages" id="chat-messages"></div>
        </div>
        <div class="controls">
            <button id="start-button">대화 시작</button>
        </div>
    </div>
    <audio id="audio-output"></audio>
    <script>
        let peerConnection;
        let webrtc_id;
        let webSearchEnabled = false;

        const audioOutput = document.getElementById('audio-output');
        const startButton = document.getElementById('start-button');
        const chatMessages = document.getElementById('chat-messages');
        const statusDot = document.getElementById('status-dot');
        const statusText = document.getElementById('status-text');
        const searchToggle = document.getElementById('search-toggle');

        let audioLevel = 0;
        let animationFrame;
        let audioContext, analyser, audioSource;

        // Web search toggle functionality
        searchToggle.addEventListener('click', () => {
            webSearchEnabled = !webSearchEnabled;
            searchToggle.classList.toggle('active', webSearchEnabled);
            console.log('Web search enabled:', webSearchEnabled);
        });

        function updateStatus(state) {
            statusDot.className = 'status-dot ' + state;
            if (state === 'connected') {
                statusText.textContent = '연결됨';
            } else if (state === 'connecting') {
                statusText.textContent = '연결 중...';
            } else {
                statusText.textContent = '연결 대기 중';
            }
        }

        function updateButtonState() {
            const button = document.getElementById('start-button');
            if (peerConnection && (peerConnection.connectionState === 'connecting' || peerConnection.connectionState === 'new')) {
                button.innerHTML = `
                    <div class="icon-with-spinner">
                        <div class="spinner"></div>
                        <span>연결 중...</span>
                    </div>
                `;
                updateStatus('connecting');
            } else if (peerConnection && peerConnection.connectionState === 'connected') {
                button.innerHTML = `
                    <div class="icon-with-spinner">
                        <div class="audio-visualizer" id="audio-visualizer">
                            <div class="visualizer-bar"></div>
                            <div class="visualizer-bar"></div>
                            <div class="visualizer-bar"></div>
                            <div class="visualizer-bar"></div>
                            <div class="visualizer-bar"></div>
                        </div>
                        <span>대화 종료</span>
                    </div>
                `;
                updateStatus('connected');
            } else {
                button.innerHTML = '대화 시작';
                updateStatus('disconnected');
            }
        }

        function setupAudioVisualization(stream) {
            audioContext = new (window.AudioContext || window.webkitAudioContext)();
            analyser = audioContext.createAnalyser();
            audioSource = audioContext.createMediaStreamSource(stream);
            audioSource.connect(analyser);
            analyser.fftSize = 256;
            const bufferLength = analyser.frequencyBinCount;
            const dataArray = new Uint8Array(bufferLength);
            const visualizerBars = document.querySelectorAll('.visualizer-bar');
            const barCount = visualizerBars.length;

            function updateAudioLevel() {
                analyser.getByteFrequencyData(dataArray);
                for (let i = 0; i < barCount; i++) {
                    const start = Math.floor(i * (bufferLength / barCount));
                    const end = Math.floor((i + 1) * (bufferLength / barCount));
                    let sum = 0;
                    for (let j = start; j < end; j++) {
                        sum += dataArray[j];
                    }
                    const average = sum / (end - start) / 255;
                    const scaleY = 0.1 + average * 0.9;
                    visualizerBars[i].style.transform = `scaleY(${scaleY})`;
                }
                animationFrame = requestAnimationFrame(updateAudioLevel);
            }
            updateAudioLevel();
        }

        function showError(message) {
            const toast = document.getElementById('error-toast');
            toast.textContent = message;
            toast.className = 'toast error';
            toast.style.display = 'block';
            setTimeout(() => {
                toast.style.display = 'none';
            }, 5000);
        }

        async function setupWebRTC() {
            const config = __RTC_CONFIGURATION__;
            peerConnection = new RTCPeerConnection(config);

            const timeoutId = setTimeout(() => {
                const toast = document.getElementById('error-toast');
                toast.textContent = "연결이 평소보다 오래 걸리고 있습니다. VPN을 사용 중이신가요?";
                toast.className = 'toast warning';
                toast.style.display = 'block';
                setTimeout(() => {
                    toast.style.display = 'none';
                }, 5000);
            }, 5000);

            try {
                const stream = await navigator.mediaDevices.getUserMedia({
                    audio: true
                });
                setupAudioVisualization(stream);
                stream.getTracks().forEach(track => {
                    peerConnection.addTrack(track, stream);
                });
                peerConnection.addEventListener('track', (evt) => {
                    if (audioOutput.srcObject !== evt.streams[0]) {
                        audioOutput.srcObject = evt.streams[0];
                        audioOutput.play();
                    }
                });

                const dataChannel = peerConnection.createDataChannel('text');
                dataChannel.onmessage = (event) => {
                    const eventJson = JSON.parse(event.data);
                    if (eventJson.type === "error") {
                        showError(eventJson.message);
                    }
                };

                const offer = await peerConnection.createOffer();
                await peerConnection.setLocalDescription(offer);
                await new Promise((resolve) => {
                    if (peerConnection.iceGatheringState === "complete") {
                        resolve();
                    } else {
                        const checkState = () => {
                            if (peerConnection.iceGatheringState === "complete") {
                                peerConnection.removeEventListener("icegatheringstatechange", checkState);
                                resolve();
                            }
                        };
                        peerConnection.addEventListener("icegatheringstatechange", checkState);
                    }
                });

                peerConnection.addEventListener('connectionstatechange', () => {
                    console.log('connectionstatechange', peerConnection.connectionState);
                    if (peerConnection.connectionState === 'connected') {
                        clearTimeout(timeoutId);
                        const toast = document.getElementById('error-toast');
                        toast.style.display = 'none';
                    }
                    updateButtonState();
                });

                webrtc_id = Math.random().toString(36).substring(7);
                const response = await fetch('/webrtc/offer', {
                    method: 'POST',
                    headers: { 'Content-Type': 'application/json' },
                    body: JSON.stringify({
                        sdp: peerConnection.localDescription.sdp,
                        type: peerConnection.localDescription.type,
                        webrtc_id: webrtc_id,
                        web_search_enabled: webSearchEnabled
                    })
                });

                const serverResponse = await response.json();
                if (serverResponse.status === 'failed') {
                    showError(serverResponse.meta.error === 'concurrency_limit_reached'
                        ? `너무 많은 연결입니다. 최대 한도는 ${serverResponse.meta.limit} 입니다.`
                        : serverResponse.meta.error);
                    stop();
                    return;
                }
                await peerConnection.setRemoteDescription(serverResponse);

                const eventSource = new EventSource('/outputs?webrtc_id=' + webrtc_id);
                eventSource.addEventListener("output", (event) => {
                    const eventJson = JSON.parse(event.data);
                    addMessage("assistant", eventJson.content);
                });
                eventSource.addEventListener("search", (event) => {
                    const eventJson = JSON.parse(event.data);
                    if (eventJson.query) {
                        addMessage("search-result", `웹 검색 중: "${eventJson.query}"`);
                    }
                });
            } catch (err) {
                clearTimeout(timeoutId);
                console.error('Error setting up WebRTC:', err);
                showError('연결을 설정하지 못했습니다. 다시 시도해 주세요.');
                stop();
            }
        }

        function addMessage(role, content) {
            const messageDiv = document.createElement('div');
            messageDiv.classList.add('message', role);
            messageDiv.textContent = content;
            chatMessages.appendChild(messageDiv);
            chatMessages.scrollTop = chatMessages.scrollHeight;
        }

        function stop() {
            if (animationFrame) {
                cancelAnimationFrame(animationFrame);
            }
            if (audioContext) {
                audioContext.close();
                audioContext = null;
                analyser = null;
                audioSource = null;
            }
            if (peerConnection) {
                if (peerConnection.getTransceivers) {
                    peerConnection.getTransceivers().forEach(transceiver => {
                        if (transceiver.stop) {
                            transceiver.stop();
                        }
                    });
                }
                if (peerConnection.getSenders) {
                    peerConnection.getSenders().forEach(sender => {
                        if (sender.track && sender.track.stop) sender.track.stop();
                    });
                }
                console.log('closing');
                peerConnection.close();
            }
            updateButtonState();
            audioLevel = 0;
        }

        startButton.addEventListener('click', () => {
            console.log('clicked');
            console.log(peerConnection, peerConnection?.connectionState);
            if (!peerConnection || peerConnection.connectionState !== 'connected') {
                setupWebRTC();
            } else {
                console.log('stopping');
                stop();
            }
        });
    </script>
</body>
</html>"""
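
# The page above expects the server to substitute __RTC_CONFIGURATION__ (done in index()
# below) and to expose POST /webrtc/offer plus an SSE endpoint at GET /outputs, which the
# embedded JavaScript uses for signalling and for transcript/search events.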

class BraveSearchClient:
    """Brave Search API client"""

    def __init__(self, api_key: str):
        self.api_key = api_key
        self.base_url = "https://api.search.brave.com/res/v1/web/search"

    async def search(self, query: str, count: int = 10) -> List[Dict]:
        """Perform a web search using Brave Search API"""
        if not self.api_key:
            return []

        headers = {
            "Accept": "application/json",
            "X-Subscription-Token": self.api_key
        }
        params = {
            "q": query,
            "count": count,
            "lang": "ko"
        }

        async with httpx.AsyncClient() as client:
            try:
                response = await client.get(self.base_url, headers=headers, params=params)
                response.raise_for_status()
                data = response.json()
                results = []
                if "web" in data and "results" in data["web"]:
                    for result in data["web"]["results"][:count]:
                        results.append({
                            "title": result.get("title", ""),
                            "url": result.get("url", ""),
                            "description": result.get("description", "")
                        })
                return results
            except Exception as e:
                print(f"Brave Search error: {e}")
                return []


# Initialize search client globally
brave_api_key = os.getenv("BSEARCH_API")
search_client = BraveSearchClient(brave_api_key) if brave_api_key else None
print(f"Search client initialized: {search_client is not None}, API key present: {bool(brave_api_key)}")

# Store web search settings by connection
web_search_settings = {}
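# The web-search flag arrives in the custom /webrtc/offer request body rather than through
# the fastrtc handler itself, so it is stashed here keyed by webrtc_id (with a timestamp)
# and read back in OpenAIHandler.copy() and start_up().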

def update_chatbot(chatbot: list[dict], response: ResponseAudioTranscriptDoneEvent):
    chatbot.append({"role": "assistant", "content": response.transcript})
    return chatbot

class OpenAIHandler(AsyncStreamHandler):
    def __init__(self, web_search_enabled: bool = False, webrtc_id: Optional[str] = None) -> None:
        super().__init__(
            expected_layout="mono",
            output_sample_rate=SAMPLE_RATE,
            output_frame_size=480,
            input_sample_rate=SAMPLE_RATE,
        )
        self.connection = None
        self.output_queue = asyncio.Queue()
        self.search_client = search_client
        self.function_call_in_progress = False
        self.current_function_args = ""
        self.current_call_id = None
        self.webrtc_id = webrtc_id
        self.web_search_enabled = web_search_enabled
        print(f"Handler created with web_search_enabled={web_search_enabled}, webrtc_id={webrtc_id}")
    def copy(self):
        # Get the most recent settings
        if web_search_settings:
            # Get the most recent webrtc_id
            recent_ids = sorted(web_search_settings.keys(),
                                key=lambda k: web_search_settings[k].get('timestamp', 0),
                                reverse=True)
            if recent_ids:
                recent_id = recent_ids[0]
                settings = web_search_settings[recent_id]
                web_search_enabled = settings.get('enabled', False)
                print(f"Handler.copy() using recent settings - webrtc_id={recent_id}, web_search_enabled={web_search_enabled}")
                return OpenAIHandler(web_search_enabled=web_search_enabled, webrtc_id=recent_id)

        print("Handler.copy() called - creating new handler with default settings")
        return OpenAIHandler(web_search_enabled=False)
    async def search_web(self, query: str) -> str:
        """Perform web search and return formatted results"""
        if not self.search_client or not self.web_search_enabled:
            return "웹 검색이 비활성화되어 있습니다."

        print(f"Searching web for: {query}")
        results = await self.search_client.search(query)
        if not results:
            return f"'{query}'에 대한 검색 결과를 찾을 수 없습니다."

        # Format search results
        formatted_results = []
        for i, result in enumerate(results, 1):
            formatted_results.append(
                f"{i}. {result['title']}\n"
                f"   URL: {result['url']}\n"
                f"   {result['description']}\n"
            )
        return f"웹 검색 결과 '{query}':\n\n" + "\n".join(formatted_results)
    async def start_up(self):
        """Connect to realtime API with function calling enabled"""
        # First check if we have the most recent settings
        if web_search_settings:
            recent_ids = sorted(web_search_settings.keys(),
                                key=lambda k: web_search_settings[k].get('timestamp', 0),
                                reverse=True)
            if recent_ids:
                recent_id = recent_ids[0]
                settings = web_search_settings[recent_id]
                self.web_search_enabled = settings.get('enabled', False)
                self.webrtc_id = recent_id
                print(f"start_up: Updated settings from storage - webrtc_id={self.webrtc_id}, web_search_enabled={self.web_search_enabled}")

        print(f"Starting up handler with web_search_enabled={self.web_search_enabled}")
        self.client = openai.AsyncOpenAI()

        # Define the web search function
        tools = []
        instructions = "You are a helpful assistant. Respond in Korean when the user speaks Korean."
        if self.web_search_enabled and self.search_client:
            # The Realtime API expects flat tool definitions (name/description/parameters
            # at the top level), unlike the nested Chat Completions format.
            tools = [{
                "type": "function",
                "name": "web_search",
                "description": "Search the web for current information. Use this for weather, news, prices, current events, or any time-sensitive topics.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "query": {
                            "type": "string",
                            "description": "The search query in Korean or English"
                        }
                    },
                    "required": ["query"]
                }
            }]
            print("Web search function added to tools")
            instructions = (
                "You are a helpful assistant with web search capabilities. "
                "IMPORTANT: You MUST use the web_search function for ANY of these topics:\n"
                "- Weather (날씨, 기온, 비, 눈)\n"
                "- News (뉴스, 소식)\n"
                "- Current events (현재, 최근, 오늘, 지금)\n"
                "- Prices (가격, 환율, 주가)\n"
                "- Sports scores or results\n"
                "- Any question about 2024 or 2025\n"
                "- Any time-sensitive information\n\n"
                "When in doubt, USE web_search. It's better to search and provide accurate information "
                "than to guess or use outdated information. Always respond in Korean when the user speaks Korean."
            )

        async with self.client.beta.realtime.connect(
            model="gpt-4o-mini-realtime-preview-2024-12-17"
        ) as conn:
            # Update session with tools
            session_update = {
                "turn_detection": {"type": "server_vad"},
                "instructions": instructions,
                "tools": tools,
                "tool_choice": "auto" if tools else "none"
            }
            await conn.session.update(session=session_update)
            self.connection = conn
            print(f"Connected with tools: {len(tools)} functions")
            async for event in self.connection:
                # Debug logging for function calls
                if event.type.startswith("response.function_call"):
                    print(f"Function event: {event.type}")

                if event.type == "response.audio_transcript.done":
                    await self.output_queue.put(AdditionalOutputs(event))
                elif event.type == "response.audio.delta":
                    await self.output_queue.put(
                        (
                            self.output_sample_rate,
                            np.frombuffer(
                                base64.b64decode(event.delta), dtype=np.int16
                            ).reshape(1, -1),
                        ),
                    )
                # Handle function calls
                elif event.type == "response.function_call_arguments.start":
                    print("Function call started")
                    self.function_call_in_progress = True
                    self.current_function_args = ""
                    self.current_call_id = getattr(event, 'call_id', None)
                elif event.type == "response.function_call_arguments.delta":
                    # Some client versions never emit a ".start" event, so begin
                    # accumulating on the first delta as well.
                    if not self.function_call_in_progress:
                        self.function_call_in_progress = True
                        self.current_function_args = ""
                        self.current_call_id = getattr(event, 'call_id', None)
                    self.current_function_args += event.delta
elif event.type == "response.function_call_arguments.done": | |
if self.function_call_in_progress: | |
print(f"Function call done, args: {self.current_function_args}") | |
try: | |
args = json.loads(self.current_function_args) | |
query = args.get("query", "") | |
# Emit search event to client | |
await self.output_queue.put(AdditionalOutputs({ | |
"type": "search", | |
"query": query | |
})) | |
# Perform the search | |
search_results = await self.search_web(query) | |
print(f"Search results length: {len(search_results)}") | |
# Send function result back to the model | |
if self.connection and self.current_call_id: | |
await self.connection.conversation.item.create( | |
item={ | |
"type": "function_call_output", | |
"call_id": self.current_call_id, | |
"output": search_results | |
} | |
) | |
await self.connection.response.create() | |
except Exception as e: | |
print(f"Function call error: {e}") | |
finally: | |
self.function_call_in_progress = False | |
self.current_function_args = "" | |
self.current_call_id = None | |
    async def receive(self, frame: tuple[int, np.ndarray]) -> None:
        if not self.connection:
            return
        try:
            _, array = frame
            array = array.squeeze()
            audio_message = base64.b64encode(array.tobytes()).decode("utf-8")
            await self.connection.input_audio_buffer.append(audio=audio_message)
        except Exception as e:
            print(f"Error in receive: {e}")
            # Connection might be closed, ignore the error

    async def emit(self) -> tuple[int, np.ndarray] | AdditionalOutputs | None:
        return await wait_for_item(self.output_queue)

    async def shutdown(self) -> None:
        if self.connection:
            await self.connection.close()
            self.connection = None

# Create initial handler instance
handler = OpenAIHandler(web_search_enabled=False)

# Create components
chatbot = gr.Chatbot(type="messages")

# Create stream with handler instance
stream = Stream(
    handler,  # Pass instance, not factory
    mode="send-receive",
    modality="audio",
    additional_inputs=[chatbot],
    additional_outputs=[chatbot],
    additional_outputs_handler=update_chatbot,
    rtc_configuration=get_twilio_turn_credentials() if get_space() else None,
    concurrency_limit=5 if get_space() else None,
    time_limit=300 if get_space() else None,
)
app = FastAPI()

# Mount stream
stream.mount(app)
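# stream.mount() registers fastrtc's own signalling routes (including POST /webrtc/offer)
# on the app; the custom route below is meant to take precedence so the web_search_enabled
# flag can be captured before the offer is forwarded to the stream's handler.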

# Intercept offer to capture settings
@app.post("/webrtc/offer")
async def custom_offer(request: Request):
    """Intercept offer to capture web search settings"""
    body = await request.json()

    webrtc_id = body.get("webrtc_id")
    web_search_enabled = body.get("web_search_enabled", False)
    print(f"Custom offer - webrtc_id: {webrtc_id}, web_search_enabled: {web_search_enabled}")

    # Store settings with timestamp
    if webrtc_id:
        web_search_settings[webrtc_id] = {
            'enabled': web_search_enabled,
            'timestamp': asyncio.get_event_loop().time()
        }

    # Remove our custom route temporarily
    custom_route = None
    for i, route in enumerate(app.routes):
        if hasattr(route, 'path') and route.path == "/webrtc/offer" and route.endpoint == custom_offer:
            custom_route = app.routes.pop(i)
            break

    # Forward to stream's offer handler
    response = await stream.offer(body)

    # Re-add our custom route
    if custom_route:
        app.routes.insert(0, custom_route)

    return response


# Make sure the interception route is matched before the /webrtc/offer route added by
# stream.mount(); routes are matched in registration order.
app.routes.insert(0, app.routes.pop())

@app.get("/outputs")
async def outputs(webrtc_id: str):
    """Stream outputs including search events"""
    async def output_stream():
        async for output in stream.output_stream(webrtc_id):
            if hasattr(output, 'args') and output.args:
                # Check if it's a search event
                if isinstance(output.args[0], dict) and output.args[0].get('type') == 'search':
                    yield f"event: search\ndata: {json.dumps(output.args[0])}\n\n"
                # Regular transcript event
                elif hasattr(output.args[0], 'transcript'):
                    s = json.dumps({"role": "assistant", "content": output.args[0].transcript})
                    yield f"event: output\ndata: {s}\n\n"

    return StreamingResponse(output_stream(), media_type="text/event-stream")
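# The SSE event names ("search" and "output") match the listeners the embedded page
# registers on its EventSource connection.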

@app.get("/")
async def index():
    """Serve the HTML page"""
    rtc_config = get_twilio_turn_credentials() if get_space() else None
    html_content = HTML_CONTENT.replace("__RTC_CONFIGURATION__", json.dumps(rtc_config))
    return HTMLResponse(content=html_content)

if __name__ == "__main__":
    import uvicorn

    mode = os.getenv("MODE")
    if mode == "UI":
        stream.ui.launch(server_port=7860)
    elif mode == "PHONE":
        stream.fastphone(host="0.0.0.0", port=7860)
    else:
        uvicorn.run(app, host="0.0.0.0", port=7860)
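
# Rough usage sketch (assumes this file is saved as app.py; adjust as needed):
#
#   OPENAI_API_KEY=sk-...    # read by openai.AsyncOpenAI()
#   BSEARCH_API=...          # optional; enables the Brave web-search tool
#   MODE=UI or MODE=PHONE    # optional; default serves the FastAPI app
#
#   python app.py            # then open http://localhost:7860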