# Private-AI / app-backup.py
import asyncio
import base64
import json
import os
import numpy as np
import openai
from dotenv import load_dotenv
from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse, StreamingResponse
from fastrtc import (
AdditionalOutputs,
AsyncStreamHandler,
Stream,
get_twilio_turn_credentials,
wait_for_item,
)
from gradio.utils import get_space
from openai.types.beta.realtime import ResponseAudioTranscriptDoneEvent
import httpx
from typing import List, Dict
import gradio as gr
load_dotenv()
SAMPLE_RATE = 24000
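# Environment sketch (assumptions, not part of the original file): load_dotenv()
# above reads a local .env, so the two secrets this app relies on can be set as:
#
#   OPENAI_API_KEY=...   # picked up implicitly by openai.AsyncOpenAI()
#   BSEARCH_API=...      # Brave Search subscription token, read further below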
# HTML content embedded as a string
HTML_CONTENT = """<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>MOUSE Voice Chat</title>
<style>
:root {
--primary-color: #6f42c1;
--secondary-color: #563d7c;
--dark-bg: #121212;
--card-bg: #1e1e1e;
--text-color: #f8f9fa;
--border-color: #333;
--hover-color: #8a5cf6;
}
body {
font-family: "SF Pro Display", -apple-system, BlinkMacSystemFont, sans-serif;
background-color: var(--dark-bg);
color: var(--text-color);
margin: 0;
padding: 0;
height: 100vh;
display: flex;
flex-direction: column;
overflow: hidden;
}
.container {
max-width: 900px;
margin: 0 auto;
padding: 20px;
flex-grow: 1;
display: flex;
flex-direction: column;
width: 100%;
height: calc(100vh - 40px);
box-sizing: border-box;
}
.header {
text-align: center;
padding: 20px 0;
border-bottom: 1px solid var(--border-color);
margin-bottom: 20px;
flex-shrink: 0;
}
.logo {
display: flex;
align-items: center;
justify-content: center;
gap: 10px;
}
.logo h1 {
margin: 0;
background: linear-gradient(135deg, var(--primary-color), #a78bfa);
-webkit-background-clip: text;
background-clip: text;
color: transparent;
font-size: 32px;
letter-spacing: 1px;
}
/* Web search toggle */
.search-toggle {
display: flex;
align-items: center;
justify-content: center;
gap: 10px;
margin-top: 15px;
}
.toggle-switch {
position: relative;
width: 50px;
height: 26px;
background-color: #ccc;
border-radius: 13px;
cursor: pointer;
transition: background-color 0.3s;
}
.toggle-switch.active {
background-color: var(--primary-color);
}
.toggle-slider {
position: absolute;
top: 3px;
left: 3px;
width: 20px;
height: 20px;
background-color: white;
border-radius: 50%;
transition: transform 0.3s;
}
.toggle-switch.active .toggle-slider {
transform: translateX(24px);
}
.search-label {
font-size: 14px;
color: #aaa;
}
.chat-container {
border-radius: 12px;
background-color: var(--card-bg);
box-shadow: 0 8px 32px rgba(0, 0, 0, 0.2);
padding: 20px;
flex-grow: 1;
display: flex;
flex-direction: column;
border: 1px solid var(--border-color);
overflow: hidden;
min-height: 0;
}
.chat-messages {
flex-grow: 1;
overflow-y: auto;
padding: 10px;
scrollbar-width: thin;
scrollbar-color: var(--primary-color) var(--card-bg);
min-height: 0;
}
.chat-messages::-webkit-scrollbar {
width: 6px;
}
.chat-messages::-webkit-scrollbar-thumb {
background-color: var(--primary-color);
border-radius: 6px;
}
.message {
margin-bottom: 20px;
padding: 14px;
border-radius: 8px;
font-size: 16px;
line-height: 1.6;
position: relative;
max-width: 80%;
animation: fade-in 0.3s ease-out;
}
@keyframes fade-in {
from {
opacity: 0;
transform: translateY(10px);
}
to {
opacity: 1;
transform: translateY(0);
}
}
.message.user {
background: linear-gradient(135deg, #2c3e50, #34495e);
margin-left: auto;
border-bottom-right-radius: 2px;
}
.message.assistant {
background: linear-gradient(135deg, var(--secondary-color), var(--primary-color));
margin-right: auto;
border-bottom-left-radius: 2px;
}
.message.search-result {
background: linear-gradient(135deg, #1a5a3e, #2e7d32);
font-size: 14px;
padding: 10px;
margin-bottom: 10px;
}
.controls {
text-align: center;
margin-top: 20px;
display: flex;
justify-content: center;
flex-shrink: 0;
}
button {
background: linear-gradient(135deg, var(--primary-color), var(--secondary-color));
color: white;
border: none;
padding: 14px 28px;
font-family: inherit;
font-size: 16px;
cursor: pointer;
transition: all 0.3s;
text-transform: uppercase;
letter-spacing: 1px;
border-radius: 50px;
display: flex;
align-items: center;
justify-content: center;
gap: 10px;
box-shadow: 0 4px 10px rgba(111, 66, 193, 0.3);
}
button:hover {
transform: translateY(-2px);
box-shadow: 0 6px 15px rgba(111, 66, 193, 0.5);
background: linear-gradient(135deg, var(--hover-color), var(--primary-color));
}
button:active {
transform: translateY(1px);
}
#audio-output {
display: none;
}
.icon-with-spinner {
display: flex;
align-items: center;
justify-content: center;
gap: 12px;
min-width: 180px;
}
.spinner {
width: 20px;
height: 20px;
border: 2px solid #ffffff;
border-top-color: transparent;
border-radius: 50%;
animation: spin 1s linear infinite;
flex-shrink: 0;
}
@keyframes spin {
to {
transform: rotate(360deg);
}
}
.audio-visualizer {
display: flex;
align-items: center;
justify-content: center;
gap: 5px;
min-width: 80px;
height: 25px;
}
.visualizer-bar {
width: 4px;
height: 100%;
background-color: rgba(255, 255, 255, 0.7);
border-radius: 2px;
transform-origin: bottom;
transform: scaleY(0.1);
transition: transform 0.1s ease;
}
.toast {
position: fixed;
top: 20px;
left: 50%;
transform: translateX(-50%);
padding: 16px 24px;
border-radius: 8px;
font-size: 14px;
z-index: 1000;
display: none;
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.3);
}
.toast.error {
background-color: #f44336;
color: white;
}
.toast.warning {
background-color: #ff9800;
color: white;
}
.status-indicator {
display: inline-flex;
align-items: center;
margin-top: 10px;
font-size: 14px;
color: #aaa;
}
.status-dot {
width: 8px;
height: 8px;
border-radius: 50%;
margin-right: 8px;
}
.status-dot.connected {
background-color: #4caf50;
}
.status-dot.disconnected {
background-color: #f44336;
}
.status-dot.connecting {
background-color: #ff9800;
animation: pulse 1.5s infinite;
}
@keyframes pulse {
0% {
opacity: 0.6;
}
50% {
opacity: 1;
}
100% {
opacity: 0.6;
}
}
.mouse-logo {
position: relative;
width: 40px;
height: 40px;
}
.mouse-ears {
position: absolute;
width: 15px;
height: 15px;
background-color: var(--primary-color);
border-radius: 50%;
}
.mouse-ear-left {
top: 0;
left: 5px;
}
.mouse-ear-right {
top: 0;
right: 5px;
}
.mouse-face {
position: absolute;
top: 10px;
left: 5px;
width: 30px;
height: 30px;
background-color: var(--secondary-color);
border-radius: 50%;
}
</style>
</head>
<body>
<div id="error-toast" class="toast"></div>
<div class="container">
<div class="header">
<div class="logo">
<div class="mouse-logo">
<div class="mouse-ears mouse-ear-left"></div>
<div class="mouse-ears mouse-ear-right"></div>
<div class="mouse-face"></div>
</div>
                <h1>MOUSE Voice Chat</h1>
</div>
<div class="search-toggle">
<span class="search-label">웹 검색</span>
<div id="search-toggle" class="toggle-switch">
<div class="toggle-slider"></div>
</div>
</div>
<div class="status-indicator">
<div id="status-dot" class="status-dot disconnected"></div>
<span id="status-text">연결 대기 중</span>
</div>
</div>
<div class="chat-container">
<div class="chat-messages" id="chat-messages"></div>
</div>
<div class="controls">
<button id="start-button">대화 시작</button>
</div>
</div>
<audio id="audio-output"></audio>
<script>
let peerConnection;
        let webrtc_id;
        let webSearchEnabled = false;
        let eventSource;
const audioOutput = document.getElementById('audio-output');
const startButton = document.getElementById('start-button');
const chatMessages = document.getElementById('chat-messages');
const statusDot = document.getElementById('status-dot');
const statusText = document.getElementById('status-text');
const searchToggle = document.getElementById('search-toggle');
let audioLevel = 0;
let animationFrame;
let audioContext, analyser, audioSource;
// Web search toggle functionality
searchToggle.addEventListener('click', () => {
webSearchEnabled = !webSearchEnabled;
searchToggle.classList.toggle('active', webSearchEnabled);
console.log('Web search enabled:', webSearchEnabled);
});
        function updateStatus(state) {
            statusDot.className = 'status-dot ' + state;
            if (state === 'connected') {
                statusText.textContent = 'Connected';
            } else if (state === 'connecting') {
                statusText.textContent = 'Connecting...';
            } else {
                statusText.textContent = 'Waiting to connect';
            }
        }
function updateButtonState() {
const button = document.getElementById('start-button');
if (peerConnection && (peerConnection.connectionState === 'connecting' || peerConnection.connectionState === 'new')) {
button.innerHTML = `
<div class="icon-with-spinner">
<div class="spinner"></div>
                        <span>Connecting...</span>
</div>
`;
updateStatus('connecting');
} else if (peerConnection && peerConnection.connectionState === 'connected') {
button.innerHTML = `
<div class="icon-with-spinner">
<div class="audio-visualizer" id="audio-visualizer">
<div class="visualizer-bar"></div>
<div class="visualizer-bar"></div>
<div class="visualizer-bar"></div>
<div class="visualizer-bar"></div>
<div class="visualizer-bar"></div>
</div>
                        <span>End Conversation</span>
</div>
`;
updateStatus('connected');
} else {
                button.innerHTML = 'Start Conversation';
updateStatus('disconnected');
}
}
function setupAudioVisualization(stream) {
audioContext = new (window.AudioContext || window.webkitAudioContext)();
analyser = audioContext.createAnalyser();
audioSource = audioContext.createMediaStreamSource(stream);
audioSource.connect(analyser);
analyser.fftSize = 256;
const bufferLength = analyser.frequencyBinCount;
const dataArray = new Uint8Array(bufferLength);
const visualizerBars = document.querySelectorAll('.visualizer-bar');
const barCount = visualizerBars.length;
function updateAudioLevel() {
analyser.getByteFrequencyData(dataArray);
for (let i = 0; i < barCount; i++) {
const start = Math.floor(i * (bufferLength / barCount));
const end = Math.floor((i + 1) * (bufferLength / barCount));
let sum = 0;
for (let j = start; j < end; j++) {
sum += dataArray[j];
}
const average = sum / (end - start) / 255;
const scaleY = 0.1 + average * 0.9;
visualizerBars[i].style.transform = `scaleY(${scaleY})`;
}
animationFrame = requestAnimationFrame(updateAudioLevel);
}
updateAudioLevel();
}
function showError(message) {
const toast = document.getElementById('error-toast');
toast.textContent = message;
toast.className = 'toast error';
toast.style.display = 'block';
setTimeout(() => {
toast.style.display = 'none';
}, 5000);
}
async function setupWebRTC() {
const config = __RTC_CONFIGURATION__;
peerConnection = new RTCPeerConnection(config);
const timeoutId = setTimeout(() => {
const toast = document.getElementById('error-toast');
toast.textContent = "연결이 평소보다 오래 걸리고 있습니다. VPN을 사용 중이신가요?";
toast.className = 'toast warning';
toast.style.display = 'block';
setTimeout(() => {
toast.style.display = 'none';
}, 5000);
}, 5000);
try {
const stream = await navigator.mediaDevices.getUserMedia({
audio: true
});
setupAudioVisualization(stream);
stream.getTracks().forEach(track => {
peerConnection.addTrack(track, stream);
});
peerConnection.addEventListener('track', (evt) => {
if (audioOutput.srcObject !== evt.streams[0]) {
audioOutput.srcObject = evt.streams[0];
audioOutput.play();
}
});
const dataChannel = peerConnection.createDataChannel('text');
dataChannel.onmessage = (event) => {
const eventJson = JSON.parse(event.data);
if (eventJson.type === "error") {
showError(eventJson.message);
}
};
const offer = await peerConnection.createOffer();
await peerConnection.setLocalDescription(offer);
await new Promise((resolve) => {
if (peerConnection.iceGatheringState === "complete") {
resolve();
} else {
const checkState = () => {
if (peerConnection.iceGatheringState === "complete") {
peerConnection.removeEventListener("icegatheringstatechange", checkState);
resolve();
}
};
peerConnection.addEventListener("icegatheringstatechange", checkState);
}
});
peerConnection.addEventListener('connectionstatechange', () => {
console.log('connectionstatechange', peerConnection.connectionState);
if (peerConnection.connectionState === 'connected') {
clearTimeout(timeoutId);
const toast = document.getElementById('error-toast');
toast.style.display = 'none';
}
updateButtonState();
});
webrtc_id = Math.random().toString(36).substring(7);
const response = await fetch('/webrtc/offer', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({
sdp: peerConnection.localDescription.sdp,
type: peerConnection.localDescription.type,
webrtc_id: webrtc_id,
web_search_enabled: webSearchEnabled
})
});
const serverResponse = await response.json();
if (serverResponse.status === 'failed') {
                    showError(serverResponse.meta.error === 'concurrency_limit_reached'
                        ? `Too many connections. The maximum allowed is ${serverResponse.meta.limit}.`
                        : serverResponse.meta.error);
stop();
return;
}
await peerConnection.setRemoteDescription(serverResponse);
                eventSource = new EventSource('/outputs?webrtc_id=' + webrtc_id);
eventSource.addEventListener("output", (event) => {
const eventJson = JSON.parse(event.data);
addMessage("assistant", eventJson.content);
});
eventSource.addEventListener("search", (event) => {
const eventJson = JSON.parse(event.data);
if (eventJson.query) {
addMessage("search-result", `웹 검색 중: "${eventJson.query}"`);
}
});
} catch (err) {
clearTimeout(timeoutId);
console.error('Error setting up WebRTC:', err);
                showError('Failed to establish the connection. Please try again.');
stop();
}
}
function addMessage(role, content) {
const messageDiv = document.createElement('div');
messageDiv.classList.add('message', role);
messageDiv.textContent = content;
chatMessages.appendChild(messageDiv);
chatMessages.scrollTop = chatMessages.scrollHeight;
}
        function stop() {
            if (eventSource) {
                eventSource.close();
                eventSource = null;
            }
            if (animationFrame) {
                cancelAnimationFrame(animationFrame);
            }
if (audioContext) {
audioContext.close();
audioContext = null;
analyser = null;
audioSource = null;
}
if (peerConnection) {
if (peerConnection.getTransceivers) {
peerConnection.getTransceivers().forEach(transceiver => {
if (transceiver.stop) {
transceiver.stop();
}
});
}
if (peerConnection.getSenders) {
peerConnection.getSenders().forEach(sender => {
if (sender.track && sender.track.stop) sender.track.stop();
});
}
console.log('closing');
peerConnection.close();
}
updateButtonState();
audioLevel = 0;
}
startButton.addEventListener('click', () => {
console.log('clicked');
console.log(peerConnection, peerConnection?.connectionState);
if (!peerConnection || peerConnection.connectionState !== 'connected') {
setupWebRTC();
} else {
console.log('stopping');
stop();
}
});
</script>
</body>
</html>"""
class BraveSearchClient:
"""Brave Search API client"""
def __init__(self, api_key: str):
self.api_key = api_key
self.base_url = "https://api.search.brave.com/res/v1/web/search"
async def search(self, query: str, count: int = 10) -> List[Dict]:
"""Perform a web search using Brave Search API"""
if not self.api_key:
return []
headers = {
"Accept": "application/json",
"X-Subscription-Token": self.api_key
}
params = {
"q": query,
"count": count,
"lang": "ko"
}
async with httpx.AsyncClient() as client:
try:
response = await client.get(self.base_url, headers=headers, params=params)
response.raise_for_status()
data = response.json()
results = []
if "web" in data and "results" in data["web"]:
for result in data["web"]["results"][:count]:
results.append({
"title": result.get("title", ""),
"url": result.get("url", ""),
"description": result.get("description", "")
})
return results
except Exception as e:
print(f"Brave Search error: {e}")
return []
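# Standalone usage sketch for BraveSearchClient (illustrative only; assumes the
# BSEARCH_API environment variable holds a valid Brave Search token):
#
#   async def _demo():
#       client = BraveSearchClient(os.environ["BSEARCH_API"])
#       for r in await client.search("Seoul weather", count=3):
#           print(r["title"], "-", r["url"])
#
#   asyncio.run(_demo())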
# Initialize search client globally
brave_api_key = os.getenv("BSEARCH_API")
search_client = BraveSearchClient(brave_api_key) if brave_api_key else None
print(f"Search client initialized: {search_client is not None}, API key present: {bool(brave_api_key)}")
# Store web search settings by connection
web_search_settings = {}
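# Shape: {webrtc_id: {"enabled": bool, "timestamp": float}} - written by the
# /webrtc/offer interceptor below and read back in OpenAIHandler.copy()/start_up().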
def update_chatbot(chatbot: list[dict], response: ResponseAudioTranscriptDoneEvent):
chatbot.append({"role": "assistant", "content": response.transcript})
return chatbot
class OpenAIHandler(AsyncStreamHandler):
    def __init__(self, web_search_enabled: bool = False, webrtc_id: str | None = None) -> None:
super().__init__(
expected_layout="mono",
output_sample_rate=SAMPLE_RATE,
output_frame_size=480,
input_sample_rate=SAMPLE_RATE,
)
self.connection = None
self.output_queue = asyncio.Queue()
self.search_client = search_client
self.function_call_in_progress = False
self.current_function_args = ""
self.current_call_id = None
self.webrtc_id = webrtc_id
self.web_search_enabled = web_search_enabled
print(f"Handler created with web_search_enabled={web_search_enabled}, webrtc_id={webrtc_id}")
def copy(self):
# Get the most recent settings
if web_search_settings:
# Get the most recent webrtc_id
recent_ids = sorted(web_search_settings.keys(),
key=lambda k: web_search_settings[k].get('timestamp', 0),
reverse=True)
if recent_ids:
recent_id = recent_ids[0]
settings = web_search_settings[recent_id]
web_search_enabled = settings.get('enabled', False)
print(f"Handler.copy() using recent settings - webrtc_id={recent_id}, web_search_enabled={web_search_enabled}")
return OpenAIHandler(web_search_enabled=web_search_enabled, webrtc_id=recent_id)
print(f"Handler.copy() called - creating new handler with default settings")
return OpenAIHandler(web_search_enabled=False)
async def search_web(self, query: str) -> str:
"""Perform web search and return formatted results"""
if not self.search_client or not self.web_search_enabled:
return "웹 검색이 비활성화되어 있습니다."
print(f"Searching web for: {query}")
results = await self.search_client.search(query)
if not results:
return f"'{query}'에 대한 검색 결과를 찾을 수 없습니다."
# Format search results
formatted_results = []
for i, result in enumerate(results, 1):
formatted_results.append(
f"{i}. {result['title']}\n"
f" URL: {result['url']}\n"
f" {result['description']}\n"
)
return f"웹 검색 결과 '{query}':\n\n" + "\n".join(formatted_results)
async def start_up(self):
"""Connect to realtime API with function calling enabled"""
# First check if we have the most recent settings
if web_search_settings:
recent_ids = sorted(web_search_settings.keys(),
key=lambda k: web_search_settings[k].get('timestamp', 0),
reverse=True)
if recent_ids:
recent_id = recent_ids[0]
settings = web_search_settings[recent_id]
self.web_search_enabled = settings.get('enabled', False)
self.webrtc_id = recent_id
print(f"start_up: Updated settings from storage - webrtc_id={self.webrtc_id}, web_search_enabled={self.web_search_enabled}")
print(f"Starting up handler with web_search_enabled={self.web_search_enabled}")
self.client = openai.AsyncOpenAI()
# Define the web search function
tools = []
instructions = "You are a helpful assistant. Respond in Korean when the user speaks Korean."
        if self.web_search_enabled and self.search_client:
            # NOTE: the Realtime API expects a flat tool schema with name,
            # description, and parameters at the top level, unlike the nested
            # {"function": {...}} shape used by the Chat Completions API.
            tools = [{
                "type": "function",
                "name": "web_search",
                "description": "Search the web for current information. Use this for weather, news, prices, current events, or any time-sensitive topics.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "query": {
                            "type": "string",
                            "description": "The search query in Korean or English"
                        }
                    },
                    "required": ["query"]
                }
            }]
print("Web search function added to tools")
instructions = (
"You are a helpful assistant with web search capabilities. "
"IMPORTANT: You MUST use the web_search function for ANY of these topics:\n"
"- Weather (날씨, 기온, 비, 눈)\n"
"- News (뉴스, 소식)\n"
"- Current events (현재, 최근, 오늘, 지금)\n"
"- Prices (가격, 환율, 주가)\n"
"- Sports scores or results\n"
"- Any question about 2024 or 2025\n"
"- Any time-sensitive information\n\n"
"When in doubt, USE web_search. It's better to search and provide accurate information "
"than to guess or use outdated information. Always respond in Korean when the user speaks Korean."
)
async with self.client.beta.realtime.connect(
model="gpt-4o-mini-realtime-preview-2024-12-17"
) as conn:
# Update session with tools
session_update = {
"turn_detection": {"type": "server_vad"},
"instructions": instructions,
"tools": tools,
"tool_choice": "auto" if tools else "none"
}
await conn.session.update(session=session_update)
self.connection = conn
print(f"Connected with tools: {len(tools)} functions")
async for event in self.connection:
# Debug logging for function calls
if event.type.startswith("response.function_call"):
print(f"Function event: {event.type}")
if event.type == "response.audio_transcript.done":
await self.output_queue.put(AdditionalOutputs(event))
elif event.type == "response.audio.delta":
await self.output_queue.put(
(
self.output_sample_rate,
np.frombuffer(
base64.b64decode(event.delta), dtype=np.int16
).reshape(1, -1),
),
)
                # Handle streamed function-call arguments. The Realtime API emits
                # response.function_call_arguments.delta events followed by a
                # single .done event; there is no ".start" event, so the delta
                # handler opens the accumulation state itself.
                elif event.type == "response.function_call_arguments.delta":
                    if not self.function_call_in_progress:
                        self.function_call_in_progress = True
                        self.current_function_args = ""
                        self.current_call_id = getattr(event, 'call_id', None)
                    self.current_function_args += event.delta
                elif event.type == "response.function_call_arguments.done":
                    if self.function_call_in_progress:
                        # The done event carries the complete argument string and
                        # call_id; prefer those when present.
                        self.current_function_args = getattr(event, 'arguments', self.current_function_args)
                        self.current_call_id = getattr(event, 'call_id', self.current_call_id)
                        print(f"Function call done, args: {self.current_function_args}")
                        try:
                            args = json.loads(self.current_function_args)
                            query = args.get("query", "")
# Emit search event to client
await self.output_queue.put(AdditionalOutputs({
"type": "search",
"query": query
}))
# Perform the search
search_results = await self.search_web(query)
print(f"Search results length: {len(search_results)}")
# Send function result back to the model
if self.connection and self.current_call_id:
await self.connection.conversation.item.create(
item={
"type": "function_call_output",
"call_id": self.current_call_id,
"output": search_results
}
)
await self.connection.response.create()
except Exception as e:
print(f"Function call error: {e}")
finally:
self.function_call_in_progress = False
self.current_function_args = ""
self.current_call_id = None
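    # Realtime event flow handled above, summarized from the code itself:
    #   session.update(tools, server_vad)   -> server-side VAD segments the mic audio
    #   response.audio.delta                -> PCM16 chunks pushed to output_queue
    #   response.audio_transcript.done      -> transcript forwarded as AdditionalOutputs
    #   function_call_arguments.delta/.done -> search_web() runs, its result is sent
    #     back via conversation.item.create(function_call_output) + response.create()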
async def receive(self, frame: tuple[int, np.ndarray]) -> None:
if not self.connection:
return
try:
_, array = frame
array = array.squeeze()
audio_message = base64.b64encode(array.tobytes()).decode("utf-8")
await self.connection.input_audio_buffer.append(audio=audio_message)
except Exception as e:
print(f"Error in receive: {e}")
# Connection might be closed, ignore the error
async def emit(self) -> tuple[int, np.ndarray] | AdditionalOutputs | None:
return await wait_for_item(self.output_queue)
async def shutdown(self) -> None:
if self.connection:
await self.connection.close()
self.connection = None
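    # Lifecycle note: fastrtc treats the instance passed to Stream() below as a
    # template and is expected to call .copy() for each new connection, which is
    # why copy() re-reads the most recent entry in web_search_settings.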
# Create initial handler instance
handler = OpenAIHandler(web_search_enabled=False)
# Create components
chatbot = gr.Chatbot(type="messages")
# Create stream with handler instance
stream = Stream(
handler, # Pass instance, not factory
mode="send-receive",
modality="audio",
additional_inputs=[chatbot],
additional_outputs=[chatbot],
additional_outputs_handler=update_chatbot,
rtc_configuration=get_twilio_turn_credentials() if get_space() else None,
concurrency_limit=5 if get_space() else None,
time_limit=300 if get_space() else None,
)
app = FastAPI()
# Mount stream
stream.mount(app)
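# stream.mount(app) registers fastrtc's built-in WebRTC routes, including
# POST /webrtc/offer, which the custom handler below shadows so it can capture
# the web_search_enabled flag before delegating back to stream.offer().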
# Intercept offer to capture settings
@app.post("/webrtc/offer", include_in_schema=False)
async def custom_offer(request: Request):
"""Intercept offer to capture web search settings"""
body = await request.json()
webrtc_id = body.get("webrtc_id")
web_search_enabled = body.get("web_search_enabled", False)
print(f"Custom offer - webrtc_id: {webrtc_id}, web_search_enabled: {web_search_enabled}")
# Store settings with timestamp
if webrtc_id:
web_search_settings[webrtc_id] = {
'enabled': web_search_enabled,
            'timestamp': asyncio.get_running_loop().time()
}
# Remove our custom route temporarily
custom_route = None
for i, route in enumerate(app.routes):
if hasattr(route, 'path') and route.path == "/webrtc/offer" and route.endpoint == custom_offer:
custom_route = app.routes.pop(i)
break
# Forward to stream's offer handler
response = await stream.offer(body)
# Re-add our custom route
if custom_route:
app.routes.insert(0, custom_route)
return response
@app.get("/outputs")
async def outputs(webrtc_id: str):
"""Stream outputs including search events"""
async def output_stream():
async for output in stream.output_stream(webrtc_id):
if hasattr(output, 'args') and output.args:
# Check if it's a search event
if isinstance(output.args[0], dict) and output.args[0].get('type') == 'search':
yield f"event: search\ndata: {json.dumps(output.args[0])}\n\n"
# Regular transcript event
elif hasattr(output.args[0], 'transcript'):
s = json.dumps({"role": "assistant", "content": output.args[0].transcript})
yield f"event: output\ndata: {s}\n\n"
return StreamingResponse(output_stream(), media_type="text/event-stream")
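# SSE consumption sketch (hypothetical webrtc_id; the embedded page uses
# EventSource against this same endpoint):
#
#   curl -N 'http://localhost:7860/outputs?webrtc_id=abc123'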
@app.get("/")
async def index():
"""Serve the HTML page"""
rtc_config = get_twilio_turn_credentials() if get_space() else None
html_content = HTML_CONTENT.replace("__RTC_CONFIGURATION__", json.dumps(rtc_config))
return HTMLResponse(content=html_content)
if __name__ == "__main__":
import uvicorn
mode = os.getenv("MODE")
if mode == "UI":
stream.ui.launch(server_port=7860)
elif mode == "PHONE":
stream.fastphone(host="0.0.0.0", port=7860)
else:
uvicorn.run(app, host="0.0.0.0", port=7860)
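# Run sketch (defaults as coded above): `python app-backup.py` serves the FastAPI
# app on http://0.0.0.0:7860; MODE=UI launches the Gradio UI instead, and
# MODE=PHONE starts the fastphone endpoint on the same port.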