import asyncio
import base64
import json
from pathlib import Path
import os
import numpy as np
import openai
from dotenv import load_dotenv
from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse, StreamingResponse
from fastrtc import (
    AdditionalOutputs,
    AsyncStreamHandler,
    Stream,
    get_twilio_turn_credentials,
    wait_for_item,
)
from gradio.utils import get_space
from openai.types.beta.realtime import ResponseAudioTranscriptDoneEvent
import httpx
from typing import Optional, List, Dict
import gradio as gr
import io
from scipy import signal
import wave
import aiosqlite
from langdetect import detect, LangDetectException
from datetime import datetime
import uuid

load_dotenv()
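# Expected environment variables (typically loaded from .env): OPENAI_API_KEY for the
# OpenAI clients used below, and optionally BSEARCH_API for Brave web search plus
# Twilio credentials for get_twilio_turn_credentials() when running on Spaces.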
SAMPLE_RATE = 24000

# Use Persistent Storage path for Hugging Face Space
# In HF Spaces, persistent storage is at /data
if os.path.exists("/data"):
    PERSISTENT_DIR = "/data"
else:
    PERSISTENT_DIR = "./data"
os.makedirs(PERSISTENT_DIR, exist_ok=True)
DB_PATH = os.path.join(PERSISTENT_DIR, "personal_assistant.db")
print(f"Using persistent directory: {PERSISTENT_DIR}")
print(f"Database path: {DB_PATH}")
# HTML content embedded as a string | |
HTML_CONTENT = """<!DOCTYPE html> | |
<html lang="ko"> | |
<head> | |
<meta charset="UTF-8"> | |
<meta name="viewport" content="width=device-width, initial-scale=1.0"> | |
<title>Personal AI Assistant</title> | |
<style> | |
:root { | |
--primary-color: #6f42c1; | |
--secondary-color: #563d7c; | |
--dark-bg: #121212; | |
--card-bg: #1e1e1e; | |
--text-color: #f8f9fa; | |
--border-color: #333; | |
--hover-color: #8a5cf6; | |
--memory-color: #4a9eff; | |
} | |
body { | |
font-family: "SF Pro Display", -apple-system, BlinkMacSystemFont, sans-serif; | |
background-color: var(--dark-bg); | |
color: var(--text-color); | |
margin: 0; | |
padding: 0; | |
height: 100vh; | |
display: flex; | |
flex-direction: column; | |
overflow: hidden; | |
} | |
.container { | |
max-width: 1400px; | |
margin: 0 auto; | |
padding: 20px; | |
flex-grow: 1; | |
display: flex; | |
flex-direction: column; | |
width: 100%; | |
height: 100vh; | |
box-sizing: border-box; | |
overflow: hidden; | |
} | |
.header { | |
text-align: center; | |
padding: 15px 0; | |
border-bottom: 1px solid var(--border-color); | |
margin-bottom: 20px; | |
flex-shrink: 0; | |
background-color: var(--card-bg); | |
} | |
.main-content { | |
display: flex; | |
gap: 20px; | |
flex-grow: 1; | |
min-height: 0; | |
overflow: hidden; | |
} | |
.sidebar { | |
width: 350px; | |
flex-shrink: 0; | |
display: flex; | |
flex-direction: column; | |
gap: 20px; | |
overflow-y: auto; | |
max-height: calc(100vh - 120px); | |
} | |
.chat-section { | |
flex-grow: 1; | |
display: flex; | |
flex-direction: column; | |
min-width: 0; | |
} | |
.logo { | |
display: flex; | |
align-items: center; | |
justify-content: center; | |
gap: 10px; | |
} | |
.logo h1 { | |
margin: 0; | |
background: linear-gradient(135deg, var(--primary-color), #a78bfa); | |
-webkit-background-clip: text; | |
background-clip: text; | |
color: transparent; | |
font-size: 32px; | |
letter-spacing: 1px; | |
} | |
/* Settings section */ | |
.settings-section { | |
background-color: var(--card-bg); | |
border-radius: 12px; | |
padding: 20px; | |
border: 1px solid var(--border-color); | |
overflow-y: auto; | |
} | |
.settings-grid { | |
display: flex; | |
flex-direction: column; | |
gap: 15px; | |
margin-bottom: 15px; | |
} | |
.setting-item { | |
display: flex; | |
align-items: center; | |
justify-content: space-between; | |
gap: 10px; | |
} | |
.setting-label { | |
font-size: 14px; | |
color: #aaa; | |
min-width: 60px; | |
} | |
/* Toggle switch */ | |
.toggle-switch { | |
position: relative; | |
width: 50px; | |
height: 26px; | |
background-color: #ccc; | |
border-radius: 13px; | |
cursor: pointer; | |
transition: background-color 0.3s; | |
} | |
.toggle-switch.active { | |
background-color: var(--primary-color); | |
} | |
.toggle-slider { | |
position: absolute; | |
top: 3px; | |
left: 3px; | |
width: 20px; | |
height: 20px; | |
background-color: white; | |
border-radius: 50%; | |
transition: transform 0.3s; | |
} | |
.toggle-switch.active .toggle-slider { | |
transform: translateX(24px); | |
} | |
/* Memory section */ | |
.memory-section { | |
background-color: var(--card-bg); | |
border-radius: 12px; | |
padding: 20px; | |
border: 1px solid var(--border-color); | |
max-height: 300px; | |
overflow-y: auto; | |
} | |
.memory-item { | |
padding: 10px; | |
margin-bottom: 10px; | |
background: linear-gradient(135deg, rgba(74, 158, 255, 0.1), rgba(111, 66, 193, 0.1)); | |
border-radius: 6px; | |
border-left: 3px solid var(--memory-color); | |
} | |
.memory-category { | |
font-size: 12px; | |
color: var(--memory-color); | |
font-weight: bold; | |
text-transform: uppercase; | |
margin-bottom: 5px; | |
} | |
.memory-content { | |
font-size: 14px; | |
color: var(--text-color); | |
} | |
/* History section */ | |
.history-section { | |
background-color: var(--card-bg); | |
border-radius: 12px; | |
padding: 20px; | |
border: 1px solid var(--border-color); | |
max-height: 200px; | |
overflow-y: auto; | |
} | |
.history-item { | |
padding: 10px; | |
margin-bottom: 10px; | |
background-color: var(--dark-bg); | |
border-radius: 6px; | |
cursor: pointer; | |
transition: background-color 0.2s; | |
} | |
.history-item:hover { | |
background-color: var(--hover-color); | |
} | |
.history-date { | |
font-size: 12px; | |
color: #888; | |
} | |
.history-preview { | |
font-size: 14px; | |
margin-top: 5px; | |
overflow: hidden; | |
text-overflow: ellipsis; | |
white-space: nowrap; | |
} | |
/* Text inputs */ | |
.text-input-section { | |
margin-top: 15px; | |
} | |
input[type="text"], textarea { | |
width: 100%; | |
background-color: var(--dark-bg); | |
color: var(--text-color); | |
border: 1px solid var(--border-color); | |
padding: 10px; | |
border-radius: 6px; | |
font-size: 14px; | |
box-sizing: border-box; | |
margin-top: 5px; | |
} | |
input[type="text"]:focus, textarea:focus { | |
outline: none; | |
border-color: var(--primary-color); | |
} | |
textarea { | |
resize: vertical; | |
min-height: 80px; | |
} | |
.chat-container { | |
border-radius: 12px; | |
background-color: var(--card-bg); | |
box-shadow: 0 8px 32px rgba(0, 0, 0, 0.2); | |
padding: 20px; | |
flex-grow: 1; | |
display: flex; | |
flex-direction: column; | |
border: 1px solid var(--border-color); | |
overflow: hidden; | |
min-height: 0; | |
height: 100%; | |
} | |
.chat-messages { | |
flex-grow: 1; | |
overflow-y: auto; | |
padding: 15px; | |
scrollbar-width: thin; | |
scrollbar-color: var(--primary-color) var(--card-bg); | |
min-height: 0; | |
max-height: calc(100vh - 250px); | |
} | |
.chat-messages::-webkit-scrollbar { | |
width: 6px; | |
} | |
.chat-messages::-webkit-scrollbar-thumb { | |
background-color: var(--primary-color); | |
border-radius: 6px; | |
} | |
.message { | |
margin-bottom: 15px; | |
padding: 12px 16px; | |
border-radius: 8px; | |
font-size: 15px; | |
line-height: 1.5; | |
position: relative; | |
max-width: 85%; | |
animation: fade-in 0.3s ease-out; | |
word-wrap: break-word; | |
} | |
@keyframes fade-in { | |
from { | |
opacity: 0; | |
transform: translateY(10px); | |
} | |
to { | |
opacity: 1; | |
transform: translateY(0); | |
} | |
} | |
.message.user { | |
background: linear-gradient(135deg, #2c3e50, #34495e); | |
margin-left: auto; | |
border-bottom-right-radius: 2px; | |
} | |
.message.assistant { | |
background: linear-gradient(135deg, var(--secondary-color), var(--primary-color)); | |
margin-right: auto; | |
border-bottom-left-radius: 2px; | |
} | |
.message.search-result { | |
background: linear-gradient(135deg, #1a5a3e, #2e7d32); | |
font-size: 14px; | |
padding: 10px; | |
margin-bottom: 10px; | |
} | |
.message.memory-update { | |
background: linear-gradient(135deg, rgba(74, 158, 255, 0.2), rgba(111, 66, 193, 0.2)); | |
font-size: 13px; | |
padding: 8px 12px; | |
margin-bottom: 10px; | |
border-left: 3px solid var(--memory-color); | |
} | |
.language-info { | |
font-size: 12px; | |
color: #888; | |
margin-left: 5px; | |
} | |
.controls { | |
text-align: center; | |
margin-top: auto; | |
display: flex; | |
justify-content: center; | |
gap: 10px; | |
flex-shrink: 0; | |
padding-top: 20px; | |
} | |
/* Responsive design */ | |
@media (max-width: 1024px) { | |
.sidebar { | |
width: 300px; | |
} | |
} | |
@media (max-width: 768px) { | |
.main-content { | |
flex-direction: column; | |
} | |
.sidebar { | |
width: 100%; | |
margin-bottom: 20px; | |
} | |
.chat-section { | |
height: 400px; | |
} | |
} | |
button { | |
background: linear-gradient(135deg, var(--primary-color), var(--secondary-color)); | |
color: white; | |
border: none; | |
padding: 14px 28px; | |
font-family: inherit; | |
font-size: 16px; | |
cursor: pointer; | |
transition: all 0.3s; | |
text-transform: uppercase; | |
letter-spacing: 1px; | |
border-radius: 50px; | |
display: flex; | |
align-items: center; | |
justify-content: center; | |
gap: 10px; | |
box-shadow: 0 4px 10px rgba(111, 66, 193, 0.3); | |
} | |
button:hover { | |
transform: translateY(-2px); | |
box-shadow: 0 6px 15px rgba(111, 66, 193, 0.5); | |
background: linear-gradient(135deg, var(--hover-color), var(--primary-color)); | |
} | |
button:active { | |
transform: translateY(1px); | |
} | |
#send-button { | |
background: linear-gradient(135deg, #2ecc71, #27ae60); | |
padding: 10px 20px; | |
font-size: 14px; | |
flex-shrink: 0; | |
} | |
#send-button:hover { | |
background: linear-gradient(135deg, #27ae60, #229954); | |
} | |
#end-session-button { | |
background: linear-gradient(135deg, #4a9eff, #3a7ed8); | |
padding: 8px 16px; | |
font-size: 13px; | |
} | |
#end-session-button:hover { | |
background: linear-gradient(135deg, #3a7ed8, #2a5eb8); | |
} | |
#audio-output { | |
display: none; | |
} | |
.icon-with-spinner { | |
display: flex; | |
align-items: center; | |
justify-content: center; | |
gap: 12px; | |
min-width: 180px; | |
} | |
.spinner { | |
width: 20px; | |
height: 20px; | |
border: 2px solid #ffffff; | |
border-top-color: transparent; | |
border-radius: 50%; | |
animation: spin 1s linear infinite; | |
flex-shrink: 0; | |
} | |
@keyframes spin { | |
to { | |
transform: rotate(360deg); | |
} | |
} | |
.audio-visualizer { | |
display: flex; | |
align-items: center; | |
justify-content: center; | |
gap: 5px; | |
min-width: 80px; | |
height: 25px; | |
} | |
.visualizer-bar { | |
width: 4px; | |
height: 100%; | |
background-color: rgba(255, 255, 255, 0.7); | |
border-radius: 2px; | |
transform-origin: bottom; | |
transform: scaleY(0.1); | |
transition: transform 0.1s ease; | |
} | |
.toast { | |
position: fixed; | |
top: 20px; | |
left: 50%; | |
transform: translateX(-50%); | |
padding: 16px 24px; | |
border-radius: 8px; | |
font-size: 14px; | |
z-index: 1000; | |
display: none; | |
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.3); | |
} | |
.toast.error { | |
background-color: #f44336; | |
color: white; | |
} | |
.toast.warning { | |
background-color: #ff9800; | |
color: white; | |
} | |
.toast.success { | |
background-color: #4caf50; | |
color: white; | |
} | |
.status-indicator { | |
display: inline-flex; | |
align-items: center; | |
margin-top: 10px; | |
font-size: 14px; | |
color: #aaa; | |
} | |
.status-dot { | |
width: 8px; | |
height: 8px; | |
border-radius: 50%; | |
margin-right: 8px; | |
} | |
.status-dot.connected { | |
background-color: #4caf50; | |
} | |
.status-dot.disconnected { | |
background-color: #f44336; | |
} | |
.status-dot.connecting { | |
background-color: #ff9800; | |
animation: pulse 1.5s infinite; | |
} | |
@keyframes pulse { | |
0% { | |
opacity: 0.6; | |
} | |
50% { | |
opacity: 1; | |
} | |
100% { | |
opacity: 0.6; | |
} | |
} | |
.user-avatar { | |
width: 40px; | |
height: 40px; | |
background: linear-gradient(135deg, var(--primary-color), var(--secondary-color)); | |
border-radius: 50%; | |
display: flex; | |
align-items: center; | |
justify-content: center; | |
font-size: 20px; | |
font-weight: bold; | |
color: white; | |
} | |
</style> | |
</head> | |
<body> | |
<div id="error-toast" class="toast"></div> | |
<div class="container"> | |
<div class="header"> | |
<div class="logo"> | |
<div class="user-avatar" id="user-avatar">👤</div> | |
<h1>Personal AI Assistant</h1> | |
</div> | |
<div class="status-indicator"> | |
<div id="status-dot" class="status-dot disconnected"></div> | |
<span id="status-text">연결 대기 중</span> | |
</div> | |
</div> | |
<div class="main-content"> | |
<div class="sidebar"> | |
<div class="settings-section"> | |
<h3 style="margin: 0 0 15px 0; color: var(--primary-color);">설정</h3> | |
<div class="settings-grid"> | |
<div class="setting-item"> | |
<span class="setting-label">웹 검색</span> | |
<div id="search-toggle" class="toggle-switch"> | |
<div class="toggle-slider"></div> | |
</div> | |
</div> | |
</div> | |
<div class="text-input-section"> | |
<label for="user-name" class="setting-label">사용자 이름:</label> | |
<input type="text" id="user-name" placeholder="이름을 입력하세요..." /> | |
</div> | |
</div> | |
<div class="memory-section"> | |
<h3 style="margin: 0 0 15px 0; color: var(--memory-color);">기억된 정보</h3> | |
<div id="memory-list"></div> | |
</div> | |
<div class="history-section"> | |
<h3 style="margin: 0 0 15px 0; color: var(--primary-color);">대화 기록</h3> | |
<div id="history-list"></div> | |
</div> | |
<div class="controls"> | |
<button id="start-button">대화 시작</button> | |
<button id="end-session-button" style="display: none;">기억 업데이트</button> | |
</div> | |
</div> | |
<div class="chat-section"> | |
<div class="chat-container"> | |
<h3 style="margin: 0 0 15px 0; color: var(--primary-color);">대화</h3> | |
<div class="chat-messages" id="chat-messages"></div> | |
<div class="text-input-section" style="margin-top: 10px;"> | |
<div style="display: flex; gap: 10px;"> | |
<input type="text" id="text-input" placeholder="텍스트 메시지를 입력하세요..." style="flex-grow: 1;" /> | |
<button id="send-button" style="display: none;">전송</button> | |
</div> | |
</div> | |
</div> | |
</div> | |
</div> | |
</div> | |
<audio id="audio-output"></audio> | |
<script> | |
let peerConnection; | |
let webrtc_id; | |
let webSearchEnabled = false; | |
let currentSessionId = null; | |
let userName = localStorage.getItem('userName') || ''; | |
let userMemories = {}; | |
const audioOutput = document.getElementById('audio-output'); | |
const startButton = document.getElementById('start-button'); | |
const endSessionButton = document.getElementById('end-session-button'); | |
const sendButton = document.getElementById('send-button'); | |
const chatMessages = document.getElementById('chat-messages'); | |
const statusDot = document.getElementById('status-dot'); | |
const statusText = document.getElementById('status-text'); | |
const searchToggle = document.getElementById('search-toggle'); | |
const textInput = document.getElementById('text-input'); | |
const historyList = document.getElementById('history-list'); | |
const memoryList = document.getElementById('memory-list'); | |
const userNameInput = document.getElementById('user-name'); | |
const userAvatar = document.getElementById('user-avatar'); | |
let audioLevel = 0; | |
let animationFrame; | |
let audioContext, analyser, audioSource; | |
let dataChannel = null; | |
let isVoiceActive = false; | |
// Initialize user name | |
userNameInput.value = userName; | |
if (userName) { | |
userAvatar.textContent = userName.charAt(0).toUpperCase(); | |
} | |
userNameInput.addEventListener('input', () => { | |
userName = userNameInput.value; | |
localStorage.setItem('userName', userName); | |
if (userName) { | |
userAvatar.textContent = userName.charAt(0).toUpperCase(); | |
} else { | |
userAvatar.textContent = '👤'; | |
} | |
}); | |
// Start new session | |
async function startNewSession() { | |
const response = await fetch('/session/new', { method: 'POST' }); | |
const data = await response.json(); | |
currentSessionId = data.session_id; | |
console.log('New session started:', currentSessionId); | |
loadHistory(); | |
loadMemories(); | |
} | |
// Load memories | |
async function loadMemories() { | |
try { | |
const response = await fetch('/memory/all'); | |
const memories = await response.json(); | |
userMemories = {}; | |
memoryList.innerHTML = ''; | |
console.log('[LoadMemories] Loaded memories from DB:', memories); | |
memories.forEach(memory => { | |
if (!userMemories[memory.category]) { | |
userMemories[memory.category] = []; | |
} | |
userMemories[memory.category].push(memory.content); | |
const item = document.createElement('div'); | |
item.className = 'memory-item'; | |
item.innerHTML = ` | |
<div class="memory-category">${memory.category}</div> | |
<div class="memory-content">${memory.content}</div> | |
`; | |
memoryList.appendChild(item); | |
}); | |
console.log('[LoadMemories] Formatted memories:', userMemories); | |
console.log('[LoadMemories] Total categories:', Object.keys(userMemories).length); | |
console.log('[LoadMemories] Total items:', Object.values(userMemories).flat().length); | |
} catch (error) { | |
console.error('Failed to load memories:', error); | |
} | |
} | |
// End session and update memories | |
async function endSession() { | |
if (!currentSessionId) return; | |
try { | |
addMessage('memory-update', '대화 내용을 분석하여 기억을 업데이트하고 있습니다...'); | |
const response = await fetch('/session/end', { | |
method: 'POST', | |
headers: { 'Content-Type': 'application/json' }, | |
body: JSON.stringify({ session_id: currentSessionId }) | |
}); | |
const result = await response.json(); | |
if (result.status === 'ok') { | |
showToast('기억이 성공적으로 업데이트되었습니다.', 'success'); | |
loadMemories(); | |
startNewSession(); | |
} | |
} catch (error) { | |
console.error('Failed to end session:', error); | |
showError('기억 업데이트 중 오류가 발생했습니다.'); | |
} | |
} | |
// Load conversation history | |
async function loadHistory() { | |
try { | |
const response = await fetch('/history/recent'); | |
const conversations = await response.json(); | |
historyList.innerHTML = ''; | |
conversations.forEach(conv => { | |
const item = document.createElement('div'); | |
item.className = 'history-item'; | |
item.innerHTML = ` | |
<div class="history-date">${new Date(conv.created_at).toLocaleString()}</div> | |
<div class="history-preview">${conv.summary || '대화 시작'}</div> | |
`; | |
item.onclick = () => loadConversation(conv.id); | |
historyList.appendChild(item); | |
}); | |
} catch (error) { | |
console.error('Failed to load history:', error); | |
} | |
} | |
// Load specific conversation | |
async function loadConversation(sessionId) { | |
try { | |
const response = await fetch(`/history/${sessionId}`); | |
const messages = await response.json(); | |
chatMessages.innerHTML = ''; | |
messages.forEach(msg => { | |
addMessage(msg.role, msg.content, false); | |
}); | |
} catch (error) { | |
console.error('Failed to load conversation:', error); | |
} | |
} | |
// Web search toggle functionality | |
searchToggle.addEventListener('click', () => { | |
webSearchEnabled = !webSearchEnabled; | |
searchToggle.classList.toggle('active', webSearchEnabled); | |
console.log('Web search enabled:', webSearchEnabled); | |
}); | |
// Text input handling | |
textInput.addEventListener('keypress', (e) => { | |
if (e.key === 'Enter' && !e.shiftKey) { | |
e.preventDefault(); | |
sendTextMessage(); | |
} | |
}); | |
sendButton.addEventListener('click', sendTextMessage); | |
endSessionButton.addEventListener('click', endSession); | |
async function sendTextMessage() { | |
const message = textInput.value.trim(); | |
if (!message) return; | |
// Check for stop words | |
const stopWords = ["중단", "그만", "스톱", "stop", "닥쳐", "멈춰", "중지"]; | |
if (stopWords.some(word => message.toLowerCase().includes(word))) { | |
addMessage('assistant', '대화를 중단합니다.'); | |
return; | |
} | |
// Add user message to chat | |
addMessage('user', message); | |
textInput.value = ''; | |
// Show sending indicator | |
const typingIndicator = document.createElement('div'); | |
typingIndicator.classList.add('message', 'assistant'); | |
typingIndicator.textContent = '입력 중...'; | |
typingIndicator.id = 'typing-indicator'; | |
chatMessages.appendChild(typingIndicator); | |
chatMessages.scrollTop = chatMessages.scrollHeight; | |
try { | |
// Send to text chat endpoint | |
const response = await fetch('/chat/text', { | |
method: 'POST', | |
headers: { 'Content-Type': 'application/json' }, | |
body: JSON.stringify({ | |
message: message, | |
web_search_enabled: webSearchEnabled, | |
session_id: currentSessionId, | |
user_name: userName, | |
memories: userMemories | |
}) | |
}); | |
const data = await response.json(); | |
// Remove typing indicator | |
const indicator = document.getElementById('typing-indicator'); | |
if (indicator) indicator.remove(); | |
if (data.error) { | |
showError(data.error); | |
} else { | |
// Add assistant response | |
let content = data.response; | |
if (data.detected_language) { | |
content += ` <span class="language-info">[${data.detected_language}]</span>`; | |
} | |
addMessage('assistant', content); | |
} | |
} catch (error) { | |
console.error('Error sending text message:', error); | |
const indicator = document.getElementById('typing-indicator'); | |
if (indicator) indicator.remove(); | |
showError('메시지 전송 중 오류가 발생했습니다.'); | |
} | |
} | |
function updateStatus(state) { | |
statusDot.className = 'status-dot ' + state; | |
if (state === 'connected') { | |
statusText.textContent = '연결됨'; | |
sendButton.style.display = 'block'; | |
endSessionButton.style.display = 'block'; | |
isVoiceActive = true; | |
} else if (state === 'connecting') { | |
statusText.textContent = '연결 중...'; | |
sendButton.style.display = 'none'; | |
endSessionButton.style.display = 'none'; | |
} else { | |
statusText.textContent = '연결 대기 중'; | |
sendButton.style.display = 'block'; | |
endSessionButton.style.display = 'block'; | |
isVoiceActive = false; | |
} | |
} | |
function showToast(message, type = 'info') { | |
const toast = document.getElementById('error-toast'); | |
toast.textContent = message; | |
toast.className = `toast ${type}`; | |
toast.style.display = 'block'; | |
setTimeout(() => { | |
toast.style.display = 'none'; | |
}, 5000); | |
} | |
function showError(message) { | |
showToast(message, 'error'); | |
} | |
function updateButtonState() { | |
const button = document.getElementById('start-button'); | |
if (peerConnection && (peerConnection.connectionState === 'connecting' || peerConnection.connectionState === 'new')) { | |
button.innerHTML = ` | |
<div class="icon-with-spinner"> | |
<div class="spinner"></div> | |
<span>연결 중...</span> | |
</div> | |
`; | |
updateStatus('connecting'); | |
} else if (peerConnection && peerConnection.connectionState === 'connected') { | |
button.innerHTML = ` | |
<div class="icon-with-spinner"> | |
<div class="audio-visualizer" id="audio-visualizer"> | |
<div class="visualizer-bar"></div> | |
<div class="visualizer-bar"></div> | |
<div class="visualizer-bar"></div> | |
<div class="visualizer-bar"></div> | |
<div class="visualizer-bar"></div> | |
</div> | |
<span>대화 종료</span> | |
</div> | |
`; | |
updateStatus('connected'); | |
} else { | |
button.innerHTML = '대화 시작'; | |
updateStatus('disconnected'); | |
} | |
} | |
function setupAudioVisualization(stream) { | |
audioContext = new (window.AudioContext || window.webkitAudioContext)(); | |
analyser = audioContext.createAnalyser(); | |
audioSource = audioContext.createMediaStreamSource(stream); | |
audioSource.connect(analyser); | |
analyser.fftSize = 256; | |
const bufferLength = analyser.frequencyBinCount; | |
const dataArray = new Uint8Array(bufferLength); | |
const visualizerBars = document.querySelectorAll('.visualizer-bar'); | |
const barCount = visualizerBars.length; | |
function updateAudioLevel() { | |
analyser.getByteFrequencyData(dataArray); | |
for (let i = 0; i < barCount; i++) { | |
const start = Math.floor(i * (bufferLength / barCount)); | |
const end = Math.floor((i + 1) * (bufferLength / barCount)); | |
let sum = 0; | |
for (let j = start; j < end; j++) { | |
sum += dataArray[j]; | |
} | |
const average = sum / (end - start) / 255; | |
const scaleY = 0.1 + average * 0.9; | |
visualizerBars[i].style.transform = `scaleY(${scaleY})`; | |
} | |
animationFrame = requestAnimationFrame(updateAudioLevel); | |
} | |
updateAudioLevel(); | |
} | |
async function setupWebRTC() { | |
            // If memories have not been loaded yet, load them first
if (Object.keys(userMemories).length === 0) { | |
console.log('[WebRTC] No memories loaded, loading now...'); | |
await loadMemories(); | |
} | |
const config = __RTC_CONFIGURATION__; | |
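            // __RTC_CONFIGURATION__ is a template placeholder; the server is expected to
            // substitute the actual RTCConfiguration (e.g. Twilio TURN credentials)
            // before serving this page.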
peerConnection = new RTCPeerConnection(config); | |
const timeoutId = setTimeout(() => { | |
showToast("연결이 평소보다 오래 걸리고 있습니다. VPN을 사용 중이신가요?", 'warning'); | |
}, 5000); | |
try { | |
const stream = await navigator.mediaDevices.getUserMedia({ | |
audio: true | |
}); | |
setupAudioVisualization(stream); | |
stream.getTracks().forEach(track => { | |
peerConnection.addTrack(track, stream); | |
}); | |
peerConnection.addEventListener('track', (evt) => { | |
if (audioOutput.srcObject !== evt.streams[0]) { | |
audioOutput.srcObject = evt.streams[0]; | |
audioOutput.play(); | |
} | |
}); | |
// Create data channel for text messages | |
dataChannel = peerConnection.createDataChannel('text'); | |
dataChannel.onopen = () => { | |
console.log('Data channel opened'); | |
}; | |
dataChannel.onmessage = (event) => { | |
const eventJson = JSON.parse(event.data); | |
if (eventJson.type === "error") { | |
showError(eventJson.message); | |
} | |
}; | |
const offer = await peerConnection.createOffer(); | |
await peerConnection.setLocalDescription(offer); | |
await new Promise((resolve) => { | |
if (peerConnection.iceGatheringState === "complete") { | |
resolve(); | |
} else { | |
const checkState = () => { | |
if (peerConnection.iceGatheringState === "complete") { | |
peerConnection.removeEventListener("icegatheringstatechange", checkState); | |
resolve(); | |
} | |
}; | |
peerConnection.addEventListener("icegatheringstatechange", checkState); | |
} | |
}); | |
peerConnection.addEventListener('connectionstatechange', () => { | |
console.log('connectionstatechange', peerConnection.connectionState); | |
if (peerConnection.connectionState === 'connected') { | |
clearTimeout(timeoutId); | |
const toast = document.getElementById('error-toast'); | |
toast.style.display = 'none'; | |
} | |
updateButtonState(); | |
}); | |
webrtc_id = Math.random().toString(36).substring(7); | |
console.log('[WebRTC] Sending offer with memories:', userMemories); | |
console.log('[WebRTC] Total memory items:', Object.values(userMemories).flat().length); | |
const response = await fetch('/webrtc/offer', { | |
method: 'POST', | |
headers: { 'Content-Type': 'application/json' }, | |
body: JSON.stringify({ | |
sdp: peerConnection.localDescription.sdp, | |
type: peerConnection.localDescription.type, | |
webrtc_id: webrtc_id, | |
web_search_enabled: webSearchEnabled, | |
session_id: currentSessionId, | |
user_name: userName, | |
memories: userMemories | |
}) | |
}); | |
const serverResponse = await response.json(); | |
if (serverResponse.status === 'failed') { | |
showError(serverResponse.meta.error === 'concurrency_limit_reached' | |
? `너무 많은 연결입니다. 최대 한도는 ${serverResponse.meta.limit} 입니다.` | |
: serverResponse.meta.error); | |
stop(); | |
return; | |
} | |
await peerConnection.setRemoteDescription(serverResponse); | |
const eventSource = new EventSource('/outputs?webrtc_id=' + webrtc_id); | |
eventSource.addEventListener("output", (event) => { | |
const eventJson = JSON.parse(event.data); | |
let content = eventJson.content; | |
if (eventJson.detected_language) { | |
content += ` <span class="language-info">[${eventJson.detected_language}]</span>`; | |
} | |
addMessage("assistant", content); | |
}); | |
eventSource.addEventListener("search", (event) => { | |
const eventJson = JSON.parse(event.data); | |
if (eventJson.query) { | |
addMessage("search-result", `웹 검색 중: "${eventJson.query}"`); | |
} | |
}); | |
} catch (err) { | |
clearTimeout(timeoutId); | |
console.error('Error setting up WebRTC:', err); | |
showError('연결을 설정하지 못했습니다. 다시 시도해 주세요.'); | |
stop(); | |
} | |
} | |
function addMessage(role, content, save = true) { | |
const messageDiv = document.createElement('div'); | |
messageDiv.classList.add('message', role); | |
if (content.includes('<span')) { | |
messageDiv.innerHTML = content; | |
} else { | |
messageDiv.textContent = content; | |
} | |
chatMessages.appendChild(messageDiv); | |
chatMessages.scrollTop = chatMessages.scrollHeight; | |
// Save message to database if save flag is true | |
if (save && currentSessionId && role !== 'memory-update' && role !== 'search-result') { | |
fetch('/message/save', { | |
method: 'POST', | |
headers: { 'Content-Type': 'application/json' }, | |
body: JSON.stringify({ | |
session_id: currentSessionId, | |
role: role, | |
content: content | |
}) | |
}).catch(error => console.error('Failed to save message:', error)); | |
} | |
} | |
function stop() { | |
console.log('[STOP] Stopping connection...'); | |
// Cancel animation frame first | |
if (animationFrame) { | |
cancelAnimationFrame(animationFrame); | |
animationFrame = null; | |
} | |
// Close audio context | |
if (audioContext) { | |
audioContext.close(); | |
audioContext = null; | |
analyser = null; | |
audioSource = null; | |
} | |
// Close data channel | |
if (dataChannel) { | |
dataChannel.close(); | |
dataChannel = null; | |
} | |
// Close peer connection | |
if (peerConnection) { | |
console.log('[STOP] Current connection state:', peerConnection.connectionState); | |
// Stop all transceivers | |
if (peerConnection.getTransceivers) { | |
peerConnection.getTransceivers().forEach(transceiver => { | |
if (transceiver.stop) { | |
transceiver.stop(); | |
} | |
}); | |
} | |
// Stop all senders | |
if (peerConnection.getSenders) { | |
peerConnection.getSenders().forEach(sender => { | |
if (sender.track) { | |
sender.track.stop(); | |
} | |
}); | |
} | |
// Stop all receivers | |
if (peerConnection.getReceivers) { | |
peerConnection.getReceivers().forEach(receiver => { | |
if (receiver.track) { | |
receiver.track.stop(); | |
} | |
}); | |
} | |
// Close the connection | |
peerConnection.close(); | |
// Clear the reference | |
peerConnection = null; | |
console.log('[STOP] Connection closed'); | |
} | |
// Reset audio level | |
audioLevel = 0; | |
isVoiceActive = false; | |
// Update UI | |
updateButtonState(); | |
// Clear any existing webrtc_id | |
if (webrtc_id) { | |
console.log('[STOP] Clearing webrtc_id:', webrtc_id); | |
webrtc_id = null; | |
} | |
} | |
startButton.addEventListener('click', async () => { | |
console.log('clicked'); | |
console.log(peerConnection, peerConnection?.connectionState); | |
            // If memories have not been loaded yet, load them first
if (Object.keys(userMemories).length === 0) { | |
console.log('[StartButton] Loading memories before starting...'); | |
await loadMemories(); | |
} | |
if (!peerConnection || peerConnection.connectionState !== 'connected') { | |
setupWebRTC(); | |
} else { | |
console.log('stopping'); | |
stop(); | |
} | |
}); | |
// Initialize on page load | |
window.addEventListener('DOMContentLoaded', () => { | |
sendButton.style.display = 'block'; | |
endSessionButton.style.display = 'block'; | |
startNewSession(); | |
loadHistory(); | |
loadMemories(); | |
}); | |
</script> | |
</body> | |
</html>""" | |
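
# Brave Search requires an API key (read from the BSEARCH_API environment variable
# further below); without one, search() simply returns an empty result list.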
class BraveSearchClient: | |
"""Brave Search API client""" | |
def __init__(self, api_key: str): | |
self.api_key = api_key | |
self.base_url = "https://api.search.brave.com/res/v1/web/search" | |
async def search(self, query: str, count: int = 10) -> List[Dict]: | |
"""Perform a web search using Brave Search API""" | |
if not self.api_key: | |
return [] | |
headers = { | |
"Accept": "application/json", | |
"X-Subscription-Token": self.api_key | |
} | |
params = { | |
"q": query, | |
"count": count, | |
"lang": "ko" | |
} | |
async with httpx.AsyncClient() as client: | |
try: | |
response = await client.get(self.base_url, headers=headers, params=params) | |
response.raise_for_status() | |
data = response.json() | |
results = [] | |
if "web" in data and "results" in data["web"]: | |
for result in data["web"]["results"][:count]: | |
results.append({ | |
"title": result.get("title", ""), | |
"url": result.get("url", ""), | |
"description": result.get("description", "") | |
}) | |
return results | |
except Exception as e: | |
print(f"Brave Search error: {e}") | |
return [] | |
# Database helper class | |
class PersonalAssistantDB: | |
"""Database manager for personal assistant""" | |
    @staticmethod
    async def init():
"""Initialize database tables""" | |
async with aiosqlite.connect(DB_PATH) as db: | |
# Conversations table | |
await db.execute(""" | |
CREATE TABLE IF NOT EXISTS conversations ( | |
id TEXT PRIMARY KEY, | |
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, | |
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, | |
summary TEXT | |
) | |
""") | |
# Messages table | |
await db.execute(""" | |
CREATE TABLE IF NOT EXISTS messages ( | |
id INTEGER PRIMARY KEY AUTOINCREMENT, | |
session_id TEXT NOT NULL, | |
role TEXT NOT NULL, | |
content TEXT NOT NULL, | |
detected_language TEXT, | |
timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP, | |
FOREIGN KEY (session_id) REFERENCES conversations(id) | |
) | |
""") | |
# User memories table - stores personal information | |
await db.execute(""" | |
CREATE TABLE IF NOT EXISTS user_memories ( | |
id INTEGER PRIMARY KEY AUTOINCREMENT, | |
category TEXT NOT NULL, | |
content TEXT NOT NULL, | |
confidence REAL DEFAULT 1.0, | |
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, | |
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, | |
source_session_id TEXT, | |
FOREIGN KEY (source_session_id) REFERENCES conversations(id) | |
) | |
""") | |
# Create indexes for better performance | |
await db.execute("CREATE INDEX IF NOT EXISTS idx_memories_category ON user_memories(category)") | |
await db.execute("CREATE INDEX IF NOT EXISTS idx_messages_session ON messages(session_id)") | |
await db.commit() | |
    @staticmethod
    async def create_session(session_id: str):
"""Create a new conversation session""" | |
async with aiosqlite.connect(DB_PATH) as db: | |
await db.execute( | |
"INSERT INTO conversations (id) VALUES (?)", | |
(session_id,) | |
) | |
await db.commit() | |
    @staticmethod
    async def save_message(session_id: str, role: str, content: str):
"""Save a message to the database""" | |
# Check for None or empty content | |
if not content: | |
print(f"[SAVE_MESSAGE] Empty content for {role} message, skipping") | |
return | |
# Detect language | |
detected_language = None | |
try: | |
if content and len(content) > 10: | |
detected_language = detect(content) | |
        except Exception as e:  # LangDetectException is already a subclass of Exception
print(f"Language detection error: {e}") | |
async with aiosqlite.connect(DB_PATH) as db: | |
await db.execute( | |
"""INSERT INTO messages (session_id, role, content, detected_language) | |
VALUES (?, ?, ?, ?)""", | |
(session_id, role, content, detected_language) | |
) | |
# Update conversation's updated_at timestamp | |
await db.execute( | |
"UPDATE conversations SET updated_at = CURRENT_TIMESTAMP WHERE id = ?", | |
(session_id,) | |
) | |
# Update conversation summary (use first user message as summary) | |
if role == "user": | |
cursor = await db.execute( | |
"SELECT summary FROM conversations WHERE id = ?", | |
(session_id,) | |
) | |
row = await cursor.fetchone() | |
if row and not row[0]: | |
summary = content[:100] + "..." if len(content) > 100 else content | |
await db.execute( | |
"UPDATE conversations SET summary = ? WHERE id = ?", | |
(summary, session_id) | |
) | |
await db.commit() | |
    @staticmethod
    async def get_recent_conversations(limit: int = 10):
"""Get recent conversations""" | |
async with aiosqlite.connect(DB_PATH) as db: | |
cursor = await db.execute( | |
"""SELECT id, created_at, summary | |
FROM conversations | |
ORDER BY updated_at DESC | |
LIMIT ?""", | |
(limit,) | |
) | |
rows = await cursor.fetchall() | |
return [ | |
{ | |
"id": row[0], | |
"created_at": row[1], | |
"summary": row[2] or "새 대화" | |
} | |
for row in rows | |
] | |
    @staticmethod
    async def get_conversation_messages(session_id: str):
"""Get all messages for a conversation""" | |
async with aiosqlite.connect(DB_PATH) as db: | |
cursor = await db.execute( | |
"""SELECT role, content, detected_language, timestamp | |
FROM messages | |
WHERE session_id = ? | |
ORDER BY timestamp ASC""", | |
(session_id,) | |
) | |
rows = await cursor.fetchall() | |
return [ | |
{ | |
"role": row[0], | |
"content": row[1], | |
"detected_language": row[2], | |
"timestamp": row[3] | |
} | |
for row in rows | |
] | |
    @staticmethod
    async def save_memory(category: str, content: str, session_id: Optional[str] = None, confidence: float = 1.0):
"""Save or update a user memory""" | |
async with aiosqlite.connect(DB_PATH) as db: | |
# Check if similar memory exists | |
cursor = await db.execute( | |
"""SELECT id, content FROM user_memories | |
WHERE category = ? AND content LIKE ? | |
LIMIT 1""", | |
(category, f"%{content[:20]}%") | |
) | |
existing = await cursor.fetchone() | |
if existing: | |
# Update existing memory | |
await db.execute( | |
"""UPDATE user_memories | |
SET content = ?, confidence = ?, updated_at = CURRENT_TIMESTAMP, | |
source_session_id = ? | |
WHERE id = ?""", | |
(content, confidence, session_id, existing[0]) | |
) | |
else: | |
# Insert new memory | |
await db.execute( | |
"""INSERT INTO user_memories (category, content, confidence, source_session_id) | |
VALUES (?, ?, ?, ?)""", | |
(category, content, confidence, session_id) | |
) | |
await db.commit() | |
    @staticmethod
    async def get_all_memories():
"""Get all user memories""" | |
async with aiosqlite.connect(DB_PATH) as db: | |
cursor = await db.execute( | |
"""SELECT category, content, confidence, updated_at | |
FROM user_memories | |
ORDER BY category, updated_at DESC""" | |
) | |
rows = await cursor.fetchall() | |
return [ | |
{ | |
"category": row[0], | |
"content": row[1], | |
"confidence": row[2], | |
"updated_at": row[3] | |
} | |
for row in rows | |
] | |
    @staticmethod
    async def extract_and_save_memories(session_id: str):
"""Extract memories from conversation and save them""" | |
# Get all messages from the session | |
messages = await PersonalAssistantDB.get_conversation_messages(session_id) | |
if not messages: | |
return | |
# Prepare conversation text for analysis | |
conversation_text = "\n".join([ | |
f"{msg['role']}: {msg['content']}" | |
for msg in messages if msg.get('content') | |
]) | |
# Use GPT to extract memories | |
client = openai.AsyncOpenAI() | |
try: | |
response = await client.chat.completions.create( | |
model="gpt-4.1-mini", | |
messages=[ | |
{ | |
"role": "system", | |
"content": """You are a memory extraction system. Extract personal information from conversations. | |
Categories to extract: | |
- personal_info: 이름, 나이, 성별, 직업, 거주지 | |
- preferences: 좋아하는 것, 싫어하는 것, 취향 | |
- important_dates: 생일, 기념일, 중요한 날짜 | |
- relationships: 가족, 친구, 동료 관계 | |
- hobbies: 취미, 관심사 | |
- health: 건강 상태, 알레르기, 의료 정보 | |
- goals: 목표, 계획, 꿈 | |
- routines: 일상, 습관, 루틴 | |
- work: 직장, 업무, 프로젝트 | |
- education: 학력, 전공, 학습 | |
Return as JSON array with format: | |
[ | |
{ | |
"category": "category_name", | |
"content": "extracted information in Korean", | |
"confidence": 0.0-1.0 | |
} | |
] | |
Only extract clear, factual information. Do not make assumptions.""" | |
}, | |
{ | |
"role": "user", | |
"content": f"Extract memories from this conversation:\n\n{conversation_text}" | |
} | |
], | |
temperature=0.3, | |
max_tokens=2000 | |
) | |
# Parse and save memories | |
memories_text = response.choices[0].message.content | |
# Extract JSON from response | |
import re | |
json_match = re.search(r'\[.*\]', memories_text, re.DOTALL) | |
if json_match: | |
memories = json.loads(json_match.group()) | |
for memory in memories: | |
if memory.get('content') and len(memory['content']) > 5: | |
await PersonalAssistantDB.save_memory( | |
category=memory.get('category', 'general'), | |
content=memory['content'], | |
session_id=session_id, | |
confidence=memory.get('confidence', 0.8) | |
) | |
print(f"Extracted and saved {len(memories)} memories from session {session_id}") | |
except Exception as e: | |
print(f"Error extracting memories: {e}") | |
# Initialize search client globally | |
brave_api_key = os.getenv("BSEARCH_API") | |
search_client = BraveSearchClient(brave_api_key) if brave_api_key else None | |
print(f"Search client initialized: {search_client is not None}, API key present: {bool(brave_api_key)}") | |
# Store connection settings | |
connection_settings = {} | |
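# Keyed by webrtc_id: each entry holds the per-connection options sent with the
# WebRTC offer (web_search_enabled, session_id, user_name, memories, timestamp)
# and is read back in OpenAIHandler.copy()/start_up().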
# Initialize OpenAI client for text chat | |
client = openai.AsyncOpenAI() | |
def update_chatbot(chatbot: list[dict], response: ResponseAudioTranscriptDoneEvent): | |
chatbot.append({"role": "assistant", "content": response.transcript}) | |
return chatbot | |
def format_memories_for_prompt(memories: Dict[str, List[str]]) -> str: | |
"""Format memories for inclusion in system prompt""" | |
if not memories: | |
return "" | |
memory_text = "\n\n=== 기억된 정보 ===\n" | |
memory_count = 0 | |
for category, items in memories.items(): | |
if items and isinstance(items, list): | |
            valid_items = [item for item in items if item]  # skip None and empty strings
if valid_items: | |
memory_text += f"\n[{category}]\n" | |
for item in valid_items: | |
memory_text += f"- {item}\n" | |
memory_count += 1 | |
print(f"[FORMAT_MEMORIES] Formatted {memory_count} memory items") | |
return memory_text if memory_count > 0 else "" | |
async def process_text_chat(message: str, web_search_enabled: bool, session_id: str,
                            user_name: str = "", memories: Optional[Dict] = None) -> Dict[str, str]:
    """Process a text chat turn with the gpt-4.1-mini chat model"""
try: | |
# Check for empty or None message | |
if not message: | |
return {"error": "메시지가 비어있습니다."} | |
# Check for stop words | |
stop_words = ["중단", "그만", "스톱", "stop", "닥쳐", "멈춰", "중지"] | |
if any(word in message.lower() for word in stop_words): | |
return { | |
"response": "대화를 중단합니다.", | |
"detected_language": "ko" | |
} | |
# Build system prompt with memories | |
base_prompt = f"""You are a personal AI assistant for {user_name if user_name else 'the user'}. | |
You remember all previous conversations and personal information about the user. | |
Be friendly, helpful, and personalized in your responses. | |
Always use the information you remember to make conversations more personal and relevant. | |
IMPORTANT: Give only ONE response. Do not repeat or give multiple answers.""" | |
# Add memories to prompt | |
if memories: | |
memory_text = format_memories_for_prompt(memories) | |
base_prompt += memory_text | |
messages = [{"role": "system", "content": base_prompt}] | |
# Handle web search if enabled | |
if web_search_enabled and search_client and message: | |
search_keywords = ["날씨", "기온", "비", "눈", "뉴스", "소식", "현재", "최근", | |
"오늘", "지금", "가격", "환율", "주가", "weather", "news", | |
"current", "today", "price", "2024", "2025"] | |
should_search = any(keyword in message.lower() for keyword in search_keywords) | |
if should_search: | |
search_results = await search_client.search(message) | |
if search_results: | |
search_context = "웹 검색 결과:\n\n" | |
for i, result in enumerate(search_results[:5], 1): | |
search_context += f"{i}. {result['title']}\n{result['description']}\n\n" | |
messages.append({ | |
"role": "system", | |
"content": "다음 웹 검색 결과를 참고하여 답변하세요:\n\n" + search_context | |
}) | |
messages.append({"role": "user", "content": message}) | |
        # Call the gpt-4.1-mini chat model
response = await client.chat.completions.create( | |
model="gpt-4.1-mini", | |
messages=messages, | |
temperature=0.7, | |
max_tokens=2000 | |
) | |
response_text = response.choices[0].message.content | |
# Detect language | |
detected_language = None | |
try: | |
if response_text and len(response_text) > 10: | |
detected_language = detect(response_text) | |
        except Exception:
            pass
# Save messages to database | |
if session_id: | |
await PersonalAssistantDB.save_message(session_id, "user", message) | |
await PersonalAssistantDB.save_message(session_id, "assistant", response_text) | |
return { | |
"response": response_text, | |
"detected_language": detected_language | |
} | |
except Exception as e: | |
print(f"Error in text chat: {e}") | |
return {"error": str(e)} | |
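
# OpenAIHandler bridges fastrtc's WebRTC audio stream to the OpenAI Realtime API:
# receive() forwards microphone frames to the realtime session, start_up() opens
# the session and consumes its event stream, and emit() hands audio chunks and
# transcript events back to the client.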
class OpenAIHandler(AsyncStreamHandler): | |
    def __init__(self, web_search_enabled: bool = False, webrtc_id: Optional[str] = None,
                 session_id: Optional[str] = None, user_name: str = "", memories: Optional[Dict] = None) -> None:
super().__init__( | |
expected_layout="mono", | |
output_sample_rate=SAMPLE_RATE, | |
output_frame_size=480, | |
input_sample_rate=SAMPLE_RATE, | |
) | |
self.connection = None | |
self.output_queue = asyncio.Queue() | |
self.search_client = search_client | |
self.function_call_in_progress = False | |
self.current_function_args = "" | |
self.current_call_id = None | |
self.webrtc_id = webrtc_id | |
self.web_search_enabled = web_search_enabled | |
self.session_id = session_id | |
self.user_name = user_name | |
self.memories = memories or {} | |
self.is_responding = False | |
self.should_stop = False | |
        # Log a summary of the memories passed to this handler
memory_count = sum(len(items) for items in self.memories.values() if isinstance(items, list)) | |
print(f"[INIT] Handler created with:") | |
print(f" - web_search={web_search_enabled}") | |
print(f" - session_id={session_id}") | |
print(f" - user={user_name}") | |
print(f" - memory categories={list(self.memories.keys())}") | |
print(f" - total memory items={memory_count}") | |
    def copy(self):
        # Pull the most recently stored connection settings
        if connection_settings:
            recent_ids = sorted(connection_settings.keys(),
                                key=lambda k: connection_settings[k].get('timestamp', 0),
                                reverse=True)
            if recent_ids:
                recent_id = recent_ids[0]
                settings = connection_settings[recent_id]
                print(f"[COPY] Copying settings from {recent_id}:")
                print(f"[COPY] - web_search: {settings.get('web_search_enabled', False)}")
                print(f"[COPY] - session_id: {settings.get('session_id')}")
                print(f"[COPY] - user_name: {settings.get('user_name', '')}")
                memories = settings.get('memories', {})
                # If no memories came with the offer, load them directly from the DB
                # (synchronously, since copy() is not a coroutine)
                if not memories:
                    print("[COPY] No memories in settings, loading from DB...")
                    try:
                        # Check whether an event loop is already running
                        loop = asyncio.get_event_loop()
                        if loop.is_running():
                            # A loop is already running, so load in a worker thread
                            import concurrent.futures
                            with concurrent.futures.ThreadPoolExecutor() as executor:
                                future = executor.submit(self._load_memories_sync)
                                memories_list = future.result()
                        else:
                            # No running loop; run the coroutine on this loop
                            memories_list = loop.run_until_complete(PersonalAssistantDB.get_all_memories())
                    except Exception:
                        # Fall back to a brand-new event loop
                        new_loop = asyncio.new_event_loop()
                        asyncio.set_event_loop(new_loop)
                        memories_list = new_loop.run_until_complete(PersonalAssistantDB.get_all_memories())
                        new_loop.close()
                    # Group memories by category
                    for memory in memories_list:
                        category = memory['category']
                        if category not in memories:
                            memories[category] = []
                        memories[category].append(memory['content'])
                    print(f"[COPY] Loaded {len(memories_list)} memories from DB")
                print(f"[COPY] - memories count: {sum(len(items) for items in memories.values() if isinstance(items, list))}")
                return OpenAIHandler(
                    web_search_enabled=settings.get('web_search_enabled', False),
                    webrtc_id=recent_id,
                    session_id=settings.get('session_id'),
                    user_name=settings.get('user_name', ''),
                    memories=memories
                )
        print("[COPY] No settings found, creating default handler")
        return OpenAIHandler(web_search_enabled=False)
def _load_memories_sync(self): | |
"""동기적으로 메모리 로드 (Thread에서 실행용)""" | |
loop = asyncio.new_event_loop() | |
asyncio.set_event_loop(loop) | |
result = loop.run_until_complete(PersonalAssistantDB.get_all_memories()) | |
loop.close() | |
return result | |
async def search_web(self, query: str) -> str: | |
"""Perform web search and return formatted results""" | |
if not self.search_client or not self.web_search_enabled: | |
return "웹 검색이 비활성화되어 있습니다." | |
print(f"Searching web for: {query}") | |
results = await self.search_client.search(query) | |
if not results: | |
return f"'{query}'에 대한 검색 결과를 찾을 수 없습니다." | |
formatted_results = [] | |
for i, result in enumerate(results, 1): | |
formatted_results.append( | |
f"{i}. {result['title']}\n" | |
f" URL: {result['url']}\n" | |
f" {result['description']}\n" | |
) | |
return f"웹 검색 결과 '{query}':\n\n" + "\n".join(formatted_results) | |
async def process_text_message(self, message: str): | |
"""Process text message from user""" | |
if self.connection: | |
await self.connection.conversation.item.create( | |
item={ | |
"type": "message", | |
"role": "user", | |
"content": [{"type": "input_text", "text": message}] | |
} | |
) | |
await self.connection.response.create() | |
async def start_up(self): | |
"""Connect to realtime API""" | |
if connection_settings and self.webrtc_id: | |
if self.webrtc_id in connection_settings: | |
settings = connection_settings[self.webrtc_id] | |
self.web_search_enabled = settings.get('web_search_enabled', False) | |
self.session_id = settings.get('session_id') | |
self.user_name = settings.get('user_name', '') | |
self.memories = settings.get('memories', {}) | |
print(f"[START_UP] Updated settings from storage for {self.webrtc_id}") | |
        # If no memories were provided, load them from the DB
if not self.memories: | |
print(f"[START_UP] No memories found, loading from DB...") | |
memories_list = await PersonalAssistantDB.get_all_memories() | |
            # Group memories by category
self.memories = {} | |
for memory in memories_list: | |
category = memory['category'] | |
if category not in self.memories: | |
self.memories[category] = [] | |
self.memories[category].append(memory['content']) | |
print(f"[START_UP] Loaded {len(memories_list)} memories from DB") | |
print(f"[START_UP] Final memory count: {sum(len(items) for items in self.memories.values() if isinstance(items, list))}") | |
self.client = openai.AsyncOpenAI() | |
print(f"[REALTIME API] Connecting...") | |
# Build system prompt with memories | |
base_instructions = f"""You are a personal AI assistant for {self.user_name if self.user_name else 'the user'}. | |
You remember all previous conversations and personal information about the user. | |
Be friendly, helpful, and personalized in your responses. | |
Always use the information you remember to make conversations more personal and relevant. | |
IMPORTANT: Give only ONE response per user input. Do not repeat yourself or give multiple answers.""" | |
# Add memories to prompt | |
if self.memories: | |
memory_text = format_memories_for_prompt(self.memories) | |
base_instructions += memory_text | |
print(f"[START_UP] Added memories to system prompt: {len(memory_text)} characters") | |
else: | |
print(f"[START_UP] No memories to add to system prompt") | |
# Define the web search function | |
tools = [] | |
if self.web_search_enabled and self.search_client: | |
            # The Realtime API expects a flat tool definition (name, description and
            # parameters at the top level), unlike the nested Chat Completions format.
            tools = [{
                "type": "function",
                "name": "web_search",
                "description": "Search the web for current information. Use this for weather, news, prices, current events, or any time-sensitive topics.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "query": {
                            "type": "string",
                            "description": "The search query"
                        }
                    },
                    "required": ["query"]
                }
            }]
search_instructions = ( | |
"\n\nYou have web search capabilities. " | |
"Use web_search for current information like weather, news, prices, etc." | |
) | |
instructions = base_instructions + search_instructions | |
else: | |
instructions = base_instructions | |
async with self.client.beta.realtime.connect( | |
model="gpt-4o-mini-realtime-preview-2024-12-17" | |
) as conn: | |
session_update = { | |
"turn_detection": { | |
"type": "server_vad", | |
"threshold": 0.5, | |
"prefix_padding_ms": 300, | |
"silence_duration_ms": 200 | |
}, | |
"instructions": instructions, | |
"tools": tools, | |
"tool_choice": "auto" if tools else "none", | |
"temperature": 0.7, | |
"max_response_output_tokens": 4096, | |
"modalities": ["text", "audio"], | |
"voice": "alloy" | |
} | |
try: | |
await conn.session.update(session=session_update) | |
self.connection = conn | |
print(f"Connected with tools: {len(tools)} functions") | |
print(f"Session update successful") | |
except Exception as e: | |
print(f"Error updating session: {e}") | |
raise | |
async for event in self.connection: | |
# Debug log for all events | |
if hasattr(event, 'type'): | |
if event.type not in ["response.audio.delta", "response.audio.done"]: | |
print(f"[EVENT] Type: {event.type}") | |
# Handle user input audio transcription | |
if event.type == "conversation.item.input_audio_transcription.completed": | |
if hasattr(event, 'transcript') and event.transcript: | |
user_text = event.transcript.lower() | |
stop_words = ["중단", "그만", "스톱", "stop", "닥쳐", "멈춰", "중지"] | |
if any(word in user_text for word in stop_words): | |
print(f"[STOP DETECTED] User said: {event.transcript}") | |
self.should_stop = True | |
if self.connection: | |
try: | |
await self.connection.response.cancel() | |
                                except Exception:
                                    pass
continue | |
# Save user message to database | |
if self.session_id: | |
await PersonalAssistantDB.save_message(self.session_id, "user", event.transcript) | |
# Handle user transcription for stop detection (alternative event) | |
elif event.type == "conversation.item.created": | |
if hasattr(event, 'item') and hasattr(event.item, 'role') and event.item.role == "user": | |
if hasattr(event.item, 'content') and event.item.content: | |
for content_item in event.item.content: | |
if hasattr(content_item, 'transcript') and content_item.transcript: | |
user_text = content_item.transcript.lower() | |
stop_words = ["중단", "그만", "스톱", "stop", "닥쳐", "멈춰", "중지"] | |
if any(word in user_text for word in stop_words): | |
print(f"[STOP DETECTED] User said: {content_item.transcript}") | |
self.should_stop = True | |
if self.connection: | |
try: | |
await self.connection.response.cancel() | |
                                            except Exception:
                                                pass
continue | |
# Save user message to database | |
if self.session_id: | |
await PersonalAssistantDB.save_message(self.session_id, "user", content_item.transcript) | |
elif event.type == "response.audio_transcript.done": | |
# Prevent multiple responses | |
if self.is_responding: | |
print("[DUPLICATE RESPONSE] Skipping duplicate response") | |
continue | |
self.is_responding = True | |
print(f"[RESPONSE] Transcript: {event.transcript[:100] if event.transcript else 'None'}...") | |
# Detect language | |
detected_language = None | |
try: | |
if event.transcript and len(event.transcript) > 10: | |
detected_language = detect(event.transcript) | |
except Exception as e: | |
print(f"Language detection error: {e}") | |
# Save to database | |
if self.session_id and event.transcript: | |
await PersonalAssistantDB.save_message(self.session_id, "assistant", event.transcript) | |
output_data = { | |
"event": event, | |
"detected_language": detected_language | |
} | |
await self.output_queue.put(AdditionalOutputs(output_data)) | |
elif event.type == "response.done": | |
# Reset responding flag when response is complete | |
self.is_responding = False | |
self.should_stop = False | |
print("[RESPONSE DONE] Response completed") | |
elif event.type == "response.audio.delta": | |
# Check if we should stop | |
if self.should_stop: | |
continue | |
if hasattr(event, 'delta'): | |
await self.output_queue.put( | |
( | |
self.output_sample_rate, | |
np.frombuffer( | |
base64.b64decode(event.delta), dtype=np.int16 | |
).reshape(1, -1), | |
), | |
) | |
# Handle errors | |
elif event.type == "error": | |
print(f"[ERROR] {event}") | |
self.is_responding = False | |
# Handle function calls | |
elif event.type == "response.function_call_arguments.start": | |
print(f"Function call started") | |
self.function_call_in_progress = True | |
self.current_function_args = "" | |
self.current_call_id = getattr(event, 'call_id', None) | |
elif event.type == "response.function_call_arguments.delta": | |
if self.function_call_in_progress: | |
self.current_function_args += event.delta | |
elif event.type == "response.function_call_arguments.done": | |
if self.function_call_in_progress: | |
print(f"Function call done, args: {self.current_function_args}") | |
try: | |
args = json.loads(self.current_function_args) | |
query = args.get("query", "") | |
# Emit search event to client | |
await self.output_queue.put(AdditionalOutputs({ | |
"type": "search", | |
"query": query | |
})) | |
# Perform the search | |
search_results = await self.search_web(query) | |
print(f"Search results length: {len(search_results)}") | |
# Send function result back to the model | |
if self.connection and self.current_call_id: | |
await self.connection.conversation.item.create( | |
item={ | |
"type": "function_call_output", | |
"call_id": self.current_call_id, | |
"output": search_results | |
} | |
) | |
await self.connection.response.create() | |
except Exception as e: | |
print(f"Function call error: {e}") | |
finally: | |
self.function_call_in_progress = False | |
self.current_function_args = "" | |
self.current_call_id = None | |
async def receive(self, frame: tuple[int, np.ndarray]) -> None: | |
if not self.connection: | |
print(f"[RECEIVE] No connection, skipping") | |
return | |
try: | |
if frame is None or len(frame) < 2: | |
print(f"[RECEIVE] Invalid frame") | |
return | |
_, array = frame | |
if array is None: | |
print(f"[RECEIVE] Null array") | |
return | |
array = array.squeeze() | |
audio_message = base64.b64encode(array.tobytes()).decode("utf-8") | |
await self.connection.input_audio_buffer.append(audio=audio_message) | |
except Exception as e: | |
print(f"Error in receive: {e}") | |
async def emit(self) -> tuple[int, np.ndarray] | AdditionalOutputs | None: | |
item = await wait_for_item(self.output_queue) | |
if isinstance(item, dict) and item.get('type') == 'text_message': | |
await self.process_text_message(item['content']) | |
return None | |
return item | |
async def shutdown(self) -> None: | |
print(f"[SHUTDOWN] Called") | |
if self.connection: | |
await self.connection.close() | |
self.connection = None | |
print("[REALTIME API] Connection closed") | |
# Create initial handler instance | |
handler = OpenAIHandler(web_search_enabled=False) | |
# Create components | |
chatbot = gr.Chatbot(type="messages") | |
# Create stream with handler instance | |
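# fastrtc uses this instance as a template, copying it (via the handler's copy()
# method) for each incoming connection, so per-connection state lives on the copies.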
stream = Stream( | |
handler, | |
mode="send-receive", | |
modality="audio", | |
additional_inputs=[chatbot], | |
additional_outputs=[chatbot], | |
additional_outputs_handler=update_chatbot, | |
rtc_configuration=get_twilio_turn_credentials() if get_space() else None, | |
concurrency_limit=5 if get_space() else None, | |
time_limit=300 if get_space() else None, | |
) | |
app = FastAPI() | |
# Mount stream | |
stream.mount(app) | |
# Initialize database on startup
@app.on_event("startup")
async def startup_event():
try: | |
await PersonalAssistantDB.init() | |
print(f"Database initialized at: {DB_PATH}") | |
print(f"Persistent directory: {PERSISTENT_DIR}") | |
print(f"DB file exists: {os.path.exists(DB_PATH)}") | |
# Check if we're in Hugging Face Space | |
if os.path.exists("/data"): | |
print("Running in Hugging Face Space with persistent storage") | |
# List files in persistent directory | |
try: | |
files = os.listdir(PERSISTENT_DIR) | |
print(f"Files in persistent directory: {files}") | |
except Exception as e: | |
print(f"Error listing files: {e}") | |
except Exception as e: | |
print(f"Error during startup: {e}") | |
# Try to create directory if it doesn't exist | |
os.makedirs(PERSISTENT_DIR, exist_ok=True) | |
await PersonalAssistantDB.init() | |
# Intercept offer to capture settings
@app.post("/webrtc/offer")
async def custom_offer(request: Request):
"""Intercept offer to capture settings""" | |
body = await request.json() | |
webrtc_id = body.get("webrtc_id") | |
web_search_enabled = body.get("web_search_enabled", False) | |
session_id = body.get("session_id") | |
user_name = body.get("user_name", "") | |
memories = body.get("memories", {}) | |
print(f"[OFFER] Received offer with webrtc_id: {webrtc_id}") | |
print(f"[OFFER] web_search_enabled: {web_search_enabled}") | |
print(f"[OFFER] session_id: {session_id}") | |
print(f"[OFFER] user_name: {user_name}") | |
print(f"[OFFER] memories categories: {list(memories.keys())}") | |
print(f"[OFFER] memories total items: {sum(len(items) for items in memories.values() if isinstance(items, list))}") | |
# If no memories were sent with the offer, load them from the DB
if not memories and session_id: | |
print(f"[OFFER] No memories received, loading from DB...") | |
memories_list = await PersonalAssistantDB.get_all_memories() | |
# Group the memories by category
memories = {} | |
for memory in memories_list: | |
category = memory['category'] | |
if category not in memories: | |
memories[category] = [] | |
memories[category].append(memory['content']) | |
print(f"[OFFER] Loaded {len(memories_list)} memories from DB") | |
# Store settings with timestamp | |
if webrtc_id: | |
connection_settings[webrtc_id] = { | |
'web_search_enabled': web_search_enabled, | |
'session_id': session_id, | |
'user_name': user_name, | |
'memories': memories,  # memories passed by the client or loaded from the DB above
'timestamp': asyncio.get_event_loop().time() | |
} | |
print(f"[OFFER] Stored settings for {webrtc_id} with {sum(len(items) for items in memories.values() if isinstance(items, list))} memory items") | |
# Remove our custom route temporarily | |
custom_route = None | |
for i, route in enumerate(app.routes): | |
if hasattr(route, 'path') and route.path == "/webrtc/offer" and route.endpoint == custom_offer: | |
custom_route = app.routes.pop(i) | |
break | |
# Forward to stream's offer handler | |
print(f"[OFFER] Forwarding to stream.offer()") | |
response = await stream.offer(body) | |
# Re-add our custom route | |
if custom_route: | |
app.routes.insert(0, custom_route) | |
print(f"[OFFER] Response status: {response.get('status', 'unknown') if isinstance(response, dict) else 'OK'}") | |
return response | |
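# Session lifecycle endpoints: a session row is created per conversation, and
# long-term memories are extracted from it when the session ends.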
async def create_new_session(): | |
"""Create a new chat session""" | |
session_id = str(uuid.uuid4()) | |
await PersonalAssistantDB.create_session(session_id) | |
return {"session_id": session_id} | |
async def end_session(request: Request): | |
"""End session and extract memories""" | |
body = await request.json() | |
session_id = body.get("session_id") | |
if not session_id: | |
return {"error": "session_id required"} | |
# Extract and save memories from the conversation | |
await PersonalAssistantDB.extract_and_save_memories(session_id) | |
return {"status": "ok"} | |
async def save_message(request: Request): | |
"""Save a message to the database""" | |
body = await request.json() | |
session_id = body.get("session_id") | |
role = body.get("role") | |
content = body.get("content") | |
if not all([session_id, role, content]): | |
return {"error": "Missing required fields"} | |
await PersonalAssistantDB.save_message(session_id, role, content) | |
return {"status": "ok"} | |
async def get_recent_history(): | |
"""Get recent conversation history""" | |
conversations = await PersonalAssistantDB.get_recent_conversations() | |
return conversations | |
async def get_conversation(session_id: str): | |
"""Get messages for a specific conversation""" | |
messages = await PersonalAssistantDB.get_conversation_messages(session_id) | |
return messages | |
async def get_all_memories(): | |
"""Get all user memories""" | |
memories = await PersonalAssistantDB.get_all_memories() | |
return memories | |
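# Plain text chat endpoint: bypasses the realtime audio pipeline and delegates to
# process_text_chat() (defined earlier), which uses GPT-4o-mini.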
async def chat_text(request: Request): | |
"""Handle text chat messages using GPT-4o-mini""" | |
try: | |
body = await request.json() | |
message = body.get("message", "") | |
web_search_enabled = body.get("web_search_enabled", False) | |
session_id = body.get("session_id") | |
user_name = body.get("user_name", "") | |
memories = body.get("memories", {}) | |
if not message: | |
return {"error": "메시지가 비어있습니다."} | |
# Process text chat | |
result = await process_text_chat(message, web_search_enabled, session_id, user_name, memories) | |
return result | |
except Exception as e: | |
print(f"Error in chat_text endpoint: {e}") | |
return {"error": "채팅 처리 중 오류가 발생했습니다."} | |
async def receive_text_message(webrtc_id: str, request: Request): | |
"""Receive text message from client""" | |
body = await request.json() | |
message = body.get("content", "") | |
# Find the handler for this connection | |
if webrtc_id in stream.handlers: | |
handler = stream.handlers[webrtc_id] | |
# Queue the text message for processing | |
await handler.output_queue.put({ | |
'type': 'text_message', | |
'content': message | |
}) | |
return {"status": "ok"} | |
async def outputs(webrtc_id: str): | |
"""Stream outputs including search events""" | |
async def output_stream(): | |
async for output in stream.output_stream(webrtc_id): | |
if hasattr(output, 'args') and output.args: | |
# Check if it's a search event | |
if isinstance(output.args[0], dict) and output.args[0].get('type') == 'search': | |
yield f"event: search\ndata: {json.dumps(output.args[0])}\n\n" | |
# Regular transcript event with language info | |
elif isinstance(output.args[0], dict) and 'event' in output.args[0]: | |
event_data = output.args[0] | |
if 'event' in event_data and hasattr(event_data['event'], 'transcript'): | |
data = { | |
"role": "assistant", | |
"content": event_data['event'].transcript, | |
"detected_language": event_data.get('detected_language') | |
} | |
yield f"event: output\ndata: {json.dumps(data)}\n\n" | |
return StreamingResponse(output_stream(), media_type="text/event-stream") | |
@app.get("/")
async def index():
"""Serve the HTML page""" | |
rtc_config = get_twilio_turn_credentials() if get_space() else None | |
html_content = HTML_CONTENT.replace("__RTC_CONFIGURATION__", json.dumps(rtc_config)) | |
return HTMLResponse(content=html_content) | |
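# Launch modes: MODE=UI serves the built-in Gradio UI, MODE=PHONE exposes the
# stream over a phone number via fastphone, otherwise the FastAPI app (custom
# HTML UI plus API routes) is served with uvicorn.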
if __name__ == "__main__": | |
import uvicorn | |
mode = os.getenv("MODE") | |
if mode == "UI": | |
stream.ui.launch(server_port=7860) | |
elif mode == "PHONE": | |
stream.fastphone(host="0.0.0.0", port=7860) | |
else: | |
uvicorn.run(app, host="0.0.0.0", port=7860) |