seawolf2357 committed (verified)
Commit d142e85 · Parent: 60f7986

Update app.py

Files changed (1): app.py (+813, -55)
app.py CHANGED
@@ -3,7 +3,6 @@ import base64
import json
from pathlib import Path
import os
- import gradio as gr
import numpy as np
import openai
from dotenv import load_dotenv
@@ -18,18 +17,659 @@ from fastrtc import (
)
from gradio.utils import get_space
from openai.types.beta.realtime import ResponseAudioTranscriptDoneEvent

load_dotenv()

- cur_dir = Path(__file__).parent
-
SAMPLE_RATE = 24000

class OpenAIHandler(AsyncStreamHandler):
-     def __init__(
-         self,
-     ) -> None:
        super().__init__(
            expected_layout="mono",
            output_sample_rate=SAMPLE_RATE,
@@ -38,26 +678,77 @@ class OpenAIHandler(AsyncStreamHandler):
        )
        self.connection = None
        self.output_queue = asyncio.Queue()

    def copy(self):
-         return OpenAIHandler()

-     async def start_up(
-         self,
-     ):
-         """Connect to realtime API. Run forever in separate thread to keep connection open."""
        self.client = openai.AsyncOpenAI()
        async with self.client.beta.realtime.connect(
            model="gpt-4o-mini-realtime-preview-2024-12-17"
        ) as conn:
-             await conn.session.update(
-                 session={"turn_detection": {"type": "server_vad"}}
-             )
            self.connection = conn
            async for event in self.connection:
                if event.type == "response.audio_transcript.done":
                    await self.output_queue.put(AdditionalOutputs(event))
-                 if event.type == "response.audio.delta":
                    await self.output_queue.put(
                        (
                            self.output_sample_rate,
@@ -66,6 +757,48 @@
                            ).reshape(1, -1),
                        ),
                    )

    async def receive(self, frame: tuple[int, np.ndarray]) -> None:
        if not self.connection:
@@ -73,7 +806,7 @@
        _, array = frame
        array = array.squeeze()
        audio_message = base64.b64encode(array.tobytes()).decode("utf-8")
-         await self.connection.input_audio_buffer.append(audio=audio_message)  # type: ignore

    async def emit(self) -> tuple[int, np.ndarray] | AdditionalOutputs | None:
        return await wait_for_item(self.output_queue)
@@ -84,58 +817,83 @@
        self.connection = None


- def update_chatbot(chatbot: list[dict], response: ResponseAudioTranscriptDoneEvent):
-     chatbot.append({"role": "assistant", "content": response.transcript})
-     return chatbot
-
-
- chatbot = gr.Chatbot(type="messages")
- latest_message = gr.Textbox(type="text", visible=False)
- stream = Stream(
-     OpenAIHandler(),
-     mode="send-receive",
-     modality="audio",
-     additional_inputs=[chatbot],
-     additional_outputs=[chatbot],
-     additional_outputs_handler=update_chatbot,
-     rtc_configuration=get_twilio_turn_credentials() if get_space() else None,
-     concurrency_limit=5 if get_space() else None,
-     time_limit=90 if get_space() else None,
- )

app = FastAPI()

- stream.mount(app)


- @app.get("/")
- async def _():
-     rtc_config = get_twilio_turn_credentials() if get_space() else None
-     html_content = (cur_dir / "index.html").read_text()
-     html_content = html_content.replace("__RTC_CONFIGURATION__", json.dumps(rtc_config))
-     return HTMLResponse(content=html_content)


@app.get("/outputs")
- def _(webrtc_id: str):
    async def output_stream():
-         import json
-
-         async for output in stream.output_stream(webrtc_id):
-             s = json.dumps({"role": "assistant", "content": output.args[0].transcript})
-             yield f"event: output\ndata: {s}\n\n"

    return StreamingResponse(output_stream(), media_type="text/event-stream")


- if __name__ == "__main__":
-     import os
-
-     if (mode := os.getenv("MODE")) == "UI":
-         stream.ui.launch(server_port=7860)
-     elif mode == "PHONE":
-         stream.fastphone(host="0.0.0.0", port=7860)
-     else:
-         import uvicorn
-
-         uvicorn.run(app, host="0.0.0.0", port=7860)

app.py after this commit (context lines unmarked, added lines prefixed with "+"):

import json
from pathlib import Path
import os

import numpy as np
import openai
from dotenv import load_dotenv
 
)
from gradio.utils import get_space
from openai.types.beta.realtime import ResponseAudioTranscriptDoneEvent
+ import httpx
+ from typing import Optional, List, Dict

load_dotenv()

SAMPLE_RATE = 24000

+ # HTML content embedded as a string
+ HTML_CONTENT = """<!DOCTYPE html>
+ <html lang="ko">
+
+ <head>
+     <meta charset="UTF-8">
+     <meta name="viewport" content="width=device-width, initial-scale=1.0">
+     <title>MOUSE 음성 챗</title>
+     <style>
+         :root {
+             --primary-color: #6f42c1;
+             --secondary-color: #563d7c;
+             --dark-bg: #121212;
+             --card-bg: #1e1e1e;
+             --text-color: #f8f9fa;
+             --border-color: #333;
+             --hover-color: #8a5cf6;
+         }
+         body {
+             font-family: "SF Pro Display", -apple-system, BlinkMacSystemFont, sans-serif;
+             background-color: var(--dark-bg);
+             color: var(--text-color);
+             margin: 0;
+             padding: 0;
+             height: 100vh;
+             display: flex;
+             flex-direction: column;
+             overflow: hidden;
+         }
+         .container {
+             max-width: 900px;
+             margin: 0 auto;
+             padding: 20px;
+             flex-grow: 1;
+             display: flex;
+             flex-direction: column;
+             width: 100%;
+             height: calc(100vh - 40px);
+             box-sizing: border-box;
+         }
+         .header {
+             text-align: center;
+             padding: 20px 0;
+             border-bottom: 1px solid var(--border-color);
+             margin-bottom: 20px;
+             flex-shrink: 0;
+         }
+         .logo {
+             display: flex;
+             align-items: center;
+             justify-content: center;
+             gap: 10px;
+         }
+         .logo h1 {
+             margin: 0;
+             background: linear-gradient(135deg, var(--primary-color), #a78bfa);
+             -webkit-background-clip: text;
+             background-clip: text;
+             color: transparent;
+             font-size: 32px;
+             letter-spacing: 1px;
+         }
+         /* Web search toggle */
+         .search-toggle {
+             display: flex;
+             align-items: center;
+             justify-content: center;
+             gap: 10px;
+             margin-top: 15px;
+         }
+         .toggle-switch {
+             position: relative;
+             width: 50px;
+             height: 26px;
+             background-color: #ccc;
+             border-radius: 13px;
+             cursor: pointer;
+             transition: background-color 0.3s;
+         }
+         .toggle-switch.active {
+             background-color: var(--primary-color);
+         }
+         .toggle-slider {
+             position: absolute;
+             top: 3px;
+             left: 3px;
+             width: 20px;
+             height: 20px;
+             background-color: white;
+             border-radius: 50%;
+             transition: transform 0.3s;
+         }
+         .toggle-switch.active .toggle-slider {
+             transform: translateX(24px);
+         }
+         .search-label {
+             font-size: 14px;
+             color: #aaa;
+         }
+         .chat-container {
+             border-radius: 12px;
+             background-color: var(--card-bg);
+             box-shadow: 0 8px 32px rgba(0, 0, 0, 0.2);
+             padding: 20px;
+             flex-grow: 1;
+             display: flex;
+             flex-direction: column;
+             border: 1px solid var(--border-color);
+             overflow: hidden;
+             min-height: 0;
+         }
+         .chat-messages {
+             flex-grow: 1;
+             overflow-y: auto;
+             padding: 10px;
+             scrollbar-width: thin;
+             scrollbar-color: var(--primary-color) var(--card-bg);
+             min-height: 0;
+         }
+         .chat-messages::-webkit-scrollbar {
+             width: 6px;
+         }
+         .chat-messages::-webkit-scrollbar-thumb {
+             background-color: var(--primary-color);
+             border-radius: 6px;
+         }
+         .message {
+             margin-bottom: 20px;
+             padding: 14px;
+             border-radius: 8px;
+             font-size: 16px;
+             line-height: 1.6;
+             position: relative;
+             max-width: 80%;
+             animation: fade-in 0.3s ease-out;
+         }
+         @keyframes fade-in {
+             from {
+                 opacity: 0;
+                 transform: translateY(10px);
+             }
+             to {
+                 opacity: 1;
+                 transform: translateY(0);
+             }
+         }
+         .message.user {
+             background: linear-gradient(135deg, #2c3e50, #34495e);
+             margin-left: auto;
+             border-bottom-right-radius: 2px;
+         }
+         .message.assistant {
+             background: linear-gradient(135deg, var(--secondary-color), var(--primary-color));
+             margin-right: auto;
+             border-bottom-left-radius: 2px;
+         }
+         .message.search-result {
+             background: linear-gradient(135deg, #1a5a3e, #2e7d32);
+             font-size: 14px;
+             padding: 10px;
+             margin-bottom: 10px;
+         }
+         .controls {
+             text-align: center;
+             margin-top: 20px;
+             display: flex;
+             justify-content: center;
+             flex-shrink: 0;
+         }
+         button {
+             background: linear-gradient(135deg, var(--primary-color), var(--secondary-color));
+             color: white;
+             border: none;
+             padding: 14px 28px;
+             font-family: inherit;
+             font-size: 16px;
+             cursor: pointer;
+             transition: all 0.3s;
+             text-transform: uppercase;
+             letter-spacing: 1px;
+             border-radius: 50px;
+             display: flex;
+             align-items: center;
+             justify-content: center;
+             gap: 10px;
+             box-shadow: 0 4px 10px rgba(111, 66, 193, 0.3);
+         }
+         button:hover {
+             transform: translateY(-2px);
+             box-shadow: 0 6px 15px rgba(111, 66, 193, 0.5);
+             background: linear-gradient(135deg, var(--hover-color), var(--primary-color));
+         }
+         button:active {
+             transform: translateY(1px);
+         }
+         #audio-output {
+             display: none;
+         }
+         .icon-with-spinner {
+             display: flex;
+             align-items: center;
+             justify-content: center;
+             gap: 12px;
+             min-width: 180px;
+         }
+         .spinner {
+             width: 20px;
+             height: 20px;
+             border: 2px solid #ffffff;
+             border-top-color: transparent;
+             border-radius: 50%;
+             animation: spin 1s linear infinite;
+             flex-shrink: 0;
+         }
+         @keyframes spin {
+             to {
+                 transform: rotate(360deg);
+             }
+         }
+         .audio-visualizer {
+             display: flex;
+             align-items: center;
+             justify-content: center;
+             gap: 5px;
+             min-width: 80px;
+             height: 25px;
+         }
+         .visualizer-bar {
+             width: 4px;
+             height: 100%;
+             background-color: rgba(255, 255, 255, 0.7);
+             border-radius: 2px;
+             transform-origin: bottom;
+             transform: scaleY(0.1);
+             transition: transform 0.1s ease;
+         }
+         .toast {
+             position: fixed;
+             top: 20px;
+             left: 50%;
+             transform: translateX(-50%);
+             padding: 16px 24px;
+             border-radius: 8px;
+             font-size: 14px;
+             z-index: 1000;
+             display: none;
+             box-shadow: 0 4px 12px rgba(0, 0, 0, 0.3);
+         }
+         .toast.error {
+             background-color: #f44336;
+             color: white;
+         }
+         .toast.warning {
+             background-color: #ff9800;
+             color: white;
+         }
+         .status-indicator {
+             display: inline-flex;
+             align-items: center;
+             margin-top: 10px;
+             font-size: 14px;
+             color: #aaa;
+         }
+         .status-dot {
+             width: 8px;
+             height: 8px;
+             border-radius: 50%;
+             margin-right: 8px;
+         }
+         .status-dot.connected {
+             background-color: #4caf50;
+         }
+         .status-dot.disconnected {
+             background-color: #f44336;
+         }
+         .status-dot.connecting {
+             background-color: #ff9800;
+             animation: pulse 1.5s infinite;
+         }
+         @keyframes pulse {
+             0% {
+                 opacity: 0.6;
+             }
+             50% {
+                 opacity: 1;
+             }
+             100% {
+                 opacity: 0.6;
+             }
+         }
+         .mouse-logo {
+             position: relative;
+             width: 40px;
+             height: 40px;
+         }
+         .mouse-ears {
+             position: absolute;
+             width: 15px;
+             height: 15px;
+             background-color: var(--primary-color);
+             border-radius: 50%;
+         }
+         .mouse-ear-left {
+             top: 0;
+             left: 5px;
+         }
+         .mouse-ear-right {
+             top: 0;
+             right: 5px;
+         }
+         .mouse-face {
+             position: absolute;
+             top: 10px;
+             left: 5px;
+             width: 30px;
+             height: 30px;
+             background-color: var(--secondary-color);
+             border-radius: 50%;
+         }
+     </style>
+ </head>
+
+ <body>
+     <div id="error-toast" class="toast"></div>
+     <div class="container">
+         <div class="header">
+             <div class="logo">
+                 <div class="mouse-logo">
+                     <div class="mouse-ears mouse-ear-left"></div>
+                     <div class="mouse-ears mouse-ear-right"></div>
+                     <div class="mouse-face"></div>
+                 </div>
+                 <h1>MOUSE 음성 챗</h1>
+             </div>
+             <div class="search-toggle">
+                 <span class="search-label">웹 검색</span>
+                 <div id="search-toggle" class="toggle-switch">
+                     <div class="toggle-slider"></div>
+                 </div>
+             </div>
+             <div class="status-indicator">
+                 <div id="status-dot" class="status-dot disconnected"></div>
+                 <span id="status-text">연결 대기 중</span>
+             </div>
+         </div>
+         <div class="chat-container">
+             <div class="chat-messages" id="chat-messages"></div>
+         </div>
+         <div class="controls">
+             <button id="start-button">대화 시작</button>
+         </div>
+     </div>
+     <audio id="audio-output"></audio>
+
+     <script>
+         let peerConnection;
+         let webrtc_id;
+         let webSearchEnabled = false;
+         const audioOutput = document.getElementById('audio-output');
+         const startButton = document.getElementById('start-button');
+         const chatMessages = document.getElementById('chat-messages');
+         const statusDot = document.getElementById('status-dot');
+         const statusText = document.getElementById('status-text');
+         const searchToggle = document.getElementById('search-toggle');
+         let audioLevel = 0;
+         let animationFrame;
+         let audioContext, analyser, audioSource;
+
+         // Web search toggle functionality
+         searchToggle.addEventListener('click', () => {
+             webSearchEnabled = !webSearchEnabled;
+             searchToggle.classList.toggle('active', webSearchEnabled);
+         });
+
+         function updateStatus(state) {
+             statusDot.className = 'status-dot ' + state;
+             if (state === 'connected') {
+                 statusText.textContent = '연결됨';
+             } else if (state === 'connecting') {
+                 statusText.textContent = '연결 중...';
+             } else {
+                 statusText.textContent = '연결 대기 중';
+             }
+         }
+         function updateButtonState() {
+             const button = document.getElementById('start-button');
+             if (peerConnection && (peerConnection.connectionState === 'connecting' || peerConnection.connectionState === 'new')) {
+                 button.innerHTML = `
+                     <div class="icon-with-spinner">
+                         <div class="spinner"></div>
+                         <span>연결 중...</span>
+                     </div>
+                 `;
+                 updateStatus('connecting');
+             } else if (peerConnection && peerConnection.connectionState === 'connected') {
+                 button.innerHTML = `
+                     <div class="icon-with-spinner">
+                         <div class="audio-visualizer" id="audio-visualizer">
+                             <div class="visualizer-bar"></div>
+                             <div class="visualizer-bar"></div>
+                             <div class="visualizer-bar"></div>
+                             <div class="visualizer-bar"></div>
+                             <div class="visualizer-bar"></div>
+                         </div>
+                         <span>대화 종료</span>
+                     </div>
+                 `;
+                 updateStatus('connected');
+             } else {
+                 button.innerHTML = '대화 시작';
+                 updateStatus('disconnected');
+             }
+         }
+         function setupAudioVisualization(stream) {
+             audioContext = new (window.AudioContext || window.webkitAudioContext)();
+             analyser = audioContext.createAnalyser();
+             audioSource = audioContext.createMediaStreamSource(stream);
+             audioSource.connect(analyser);
+             analyser.fftSize = 256;
+             const bufferLength = analyser.frequencyBinCount;
+             const dataArray = new Uint8Array(bufferLength);
+
+             const visualizerBars = document.querySelectorAll('.visualizer-bar');
+             const barCount = visualizerBars.length;
+
+             function updateAudioLevel() {
+                 analyser.getByteFrequencyData(dataArray);
+
+                 for (let i = 0; i < barCount; i++) {
+                     const start = Math.floor(i * (bufferLength / barCount));
+                     const end = Math.floor((i + 1) * (bufferLength / barCount));
+
+                     let sum = 0;
+                     for (let j = start; j < end; j++) {
+                         sum += dataArray[j];
+                     }
+
+                     const average = sum / (end - start) / 255;
+                     const scaleY = 0.1 + average * 0.9;
+                     visualizerBars[i].style.transform = `scaleY(${scaleY})`;
+                 }
+
+                 animationFrame = requestAnimationFrame(updateAudioLevel);
+             }
+
+             updateAudioLevel();
+         }
+         function showError(message) {
+             const toast = document.getElementById('error-toast');
+             toast.textContent = message;
+             toast.className = 'toast error';
+             toast.style.display = 'block';
+             setTimeout(() => {
+                 toast.style.display = 'none';
+             }, 5000);
+         }
+         async function setupWebRTC() {
+             const config = __RTC_CONFIGURATION__;
+             peerConnection = new RTCPeerConnection(config);
+             const timeoutId = setTimeout(() => {
+                 const toast = document.getElementById('error-toast');
+                 toast.textContent = "연결이 평소보다 오래 걸리고 있습니다. VPN을 사용 중이신가요?";
+                 toast.className = 'toast warning';
+                 toast.style.display = 'block';
+                 setTimeout(() => {
+                     toast.style.display = 'none';
+                 }, 5000);
+             }, 5000);
+             try {
+                 const stream = await navigator.mediaDevices.getUserMedia({
+                     audio: true
+                 });
+                 setupAudioVisualization(stream);
+                 stream.getTracks().forEach(track => {
+                     peerConnection.addTrack(track, stream);
+                 });
+                 peerConnection.addEventListener('track', (evt) => {
+                     if (audioOutput.srcObject !== evt.streams[0]) {
+                         audioOutput.srcObject = evt.streams[0];
+                         audioOutput.play();
+                     }
+                 });
+                 const dataChannel = peerConnection.createDataChannel('text');
+                 dataChannel.onmessage = (event) => {
+                     const eventJson = JSON.parse(event.data);
+                     if (eventJson.type === "error") {
+                         showError(eventJson.message);
+                     }
+                 };
+                 const offer = await peerConnection.createOffer();
+                 await peerConnection.setLocalDescription(offer);
+                 await new Promise((resolve) => {
+                     if (peerConnection.iceGatheringState === "complete") {
+                         resolve();
+                     } else {
+                         const checkState = () => {
+                             if (peerConnection.iceGatheringState === "complete") {
+                                 peerConnection.removeEventListener("icegatheringstatechange", checkState);
+                                 resolve();
+                             }
+                         };
+                         peerConnection.addEventListener("icegatheringstatechange", checkState);
+                     }
+                 });
+                 peerConnection.addEventListener('connectionstatechange', () => {
+                     console.log('connectionstatechange', peerConnection.connectionState);
+                     if (peerConnection.connectionState === 'connected') {
+                         clearTimeout(timeoutId);
+                         const toast = document.getElementById('error-toast');
+                         toast.style.display = 'none';
+                     }
+                     updateButtonState();
+                 });
+                 webrtc_id = Math.random().toString(36).substring(7);
+                 const response = await fetch('/webrtc/offer', {
+                     method: 'POST',
+                     headers: { 'Content-Type': 'application/json' },
+                     body: JSON.stringify({
+                         sdp: peerConnection.localDescription.sdp,
+                         type: peerConnection.localDescription.type,
+                         webrtc_id: webrtc_id,
+                         web_search_enabled: webSearchEnabled
+                     })
+                 });
+                 const serverResponse = await response.json();
+                 if (serverResponse.status === 'failed') {
+                     showError(serverResponse.meta.error === 'concurrency_limit_reached'
+                         ? `너무 많은 연결입니다. 최대 한도는 ${serverResponse.meta.limit} 입니다.`
+                         : serverResponse.meta.error);
+                     stop();
+                     return;
+                 }
+                 await peerConnection.setRemoteDescription(serverResponse);
+                 const eventSource = new EventSource('/outputs?webrtc_id=' + webrtc_id);
+                 eventSource.addEventListener("output", (event) => {
+                     const eventJson = JSON.parse(event.data);
+                     addMessage("assistant", eventJson.content);
+                 });
+                 eventSource.addEventListener("search", (event) => {
+                     const eventJson = JSON.parse(event.data);
+                     if (eventJson.results) {
+                         addMessage("search-result", `웹 검색 중: "${eventJson.query}"`);
+                     }
+                 });
+             } catch (err) {
+                 clearTimeout(timeoutId);
+                 console.error('Error setting up WebRTC:', err);
+                 showError('연결을 설정하지 못했습니다. 다시 시도해 주세요.');
+                 stop();
+             }
+         }
+         function addMessage(role, content) {
+             const messageDiv = document.createElement('div');
+             messageDiv.classList.add('message', role);
+             messageDiv.textContent = content;
+             chatMessages.appendChild(messageDiv);
+             chatMessages.scrollTop = chatMessages.scrollHeight;
+         }
+         function stop() {
+             if (animationFrame) {
+                 cancelAnimationFrame(animationFrame);
+             }
+             if (audioContext) {
+                 audioContext.close();
+                 audioContext = null;
+                 analyser = null;
+                 audioSource = null;
+             }
+             if (peerConnection) {
+                 if (peerConnection.getTransceivers) {
+                     peerConnection.getTransceivers().forEach(transceiver => {
+                         if (transceiver.stop) {
+                             transceiver.stop();
+                         }
+                     });
+                 }
+                 if (peerConnection.getSenders) {
+                     peerConnection.getSenders().forEach(sender => {
+                         if (sender.track && sender.track.stop) sender.track.stop();
+                     });
+                 }
+                 console.log('closing');
+                 peerConnection.close();
+             }
+             updateButtonState();
+             audioLevel = 0;
+         }
+         startButton.addEventListener('click', () => {
+             console.log('clicked');
+             console.log(peerConnection, peerConnection?.connectionState);
+             if (!peerConnection || peerConnection.connectionState !== 'connected') {
+                 setupWebRTC();
+             } else {
+                 console.log('stopping');
+                 stop();
+             }
+         });
+     </script>
+ </body>
+
+ </html>"""
+
+
+ class BraveSearchClient:
+     """Brave Search API client"""
+     def __init__(self, api_key: str):
+         self.api_key = api_key
+         self.base_url = "https://api.search.brave.com/res/v1/web/search"
+
+     async def search(self, query: str, count: int = 10) -> List[Dict]:
+         """Perform a web search using Brave Search API"""
+         if not self.api_key:
+             return []
+
+         headers = {
+             "Accept": "application/json",
+             "X-Subscription-Token": self.api_key
+         }
+         params = {
+             "q": query,
+             "count": count,
+             "lang": "ko"
+         }
+
+         async with httpx.AsyncClient() as client:
+             try:
+                 response = await client.get(self.base_url, headers=headers, params=params)
+                 response.raise_for_status()
+                 data = response.json()
+
+                 results = []
+                 if "web" in data and "results" in data["web"]:
+                     for result in data["web"]["results"][:count]:
+                         results.append({
+                             "title": result.get("title", ""),
+                             "url": result.get("url", ""),
+                             "description": result.get("description", "")
+                         })
+                 return results
+             except Exception as e:
+                 print(f"Brave Search error: {e}")
+                 return []
+
class OpenAIHandler(AsyncStreamHandler):
+     def __init__(self, web_search_enabled: bool = False, search_client: Optional[BraveSearchClient] = None) -> None:
        super().__init__(
            expected_layout="mono",
            output_sample_rate=SAMPLE_RATE,

        )
        self.connection = None
        self.output_queue = asyncio.Queue()
+         self.web_search_enabled = web_search_enabled
+         self.search_client = search_client
+         self.function_call_in_progress = False
+         self.current_function_args = ""

    def copy(self):
+         return OpenAIHandler(self.web_search_enabled, self.search_client)

+     async def search_web(self, query: str) -> str:
+         """Perform web search and return formatted results"""
+         if not self.search_client or not self.web_search_enabled:
+             return " 검색이 비활성화되어 있습니다."
+
+         results = await self.search_client.search(query)
+         if not results:
+             return f"'{query}'에 대한 검색 결과를 찾을 수 없습니다."
+
+         # Format search results
+         formatted_results = []
+         for i, result in enumerate(results, 1):
+             formatted_results.append(
+                 f"{i}. {result['title']}\n"
+                 f" URL: {result['url']}\n"
+                 f" {result['description']}\n"
+             )
+
+         return f"웹 검색 결과 '{query}':\n\n" + "\n".join(formatted_results)
+
+     async def start_up(self):
+         """Connect to realtime API with function calling enabled"""
        self.client = openai.AsyncOpenAI()
+
+         # Define the web search function
+         tools = []
+         if self.web_search_enabled and self.search_client:
+             tools = [{
+                 "type": "function",
+                 "function": {
+                     "name": "web_search",
+                     "description": "Search the web for information",
+                     "parameters": {
+                         "type": "object",
+                         "properties": {
+                             "query": {
+                                 "type": "string",
+                                 "description": "The search query"
+                             }
+                         },
+                         "required": ["query"]
+                     }
+                 }
+             }]
+
        async with self.client.beta.realtime.connect(
            model="gpt-4o-mini-realtime-preview-2024-12-17"
        ) as conn:
+             # Update session with tools
+             session_update = {
+                 "turn_detection": {"type": "server_vad"},
+                 "tools": tools,
+                 "tool_choice": "auto" if tools else "none"
+             }
+
+             await conn.session.update(session=session_update)
            self.connection = conn
+
            async for event in self.connection:
                if event.type == "response.audio_transcript.done":
                    await self.output_queue.put(AdditionalOutputs(event))
+
+                 elif event.type == "response.audio.delta":
                    await self.output_queue.put(
                        (
                            self.output_sample_rate,

                            ).reshape(1, -1),
                        ),
                    )
+
+                 # Handle function calls
+                 elif event.type == "response.function_call_arguments.start":
+                     self.function_call_in_progress = True
+                     self.current_function_args = ""
+
+                 elif event.type == "response.function_call_arguments.delta":
+                     if self.function_call_in_progress:
+                         self.current_function_args += event.delta
+
+                 elif event.type == "response.function_call_arguments.done":
+                     if self.function_call_in_progress:
+                         try:
+                             args = json.loads(self.current_function_args)
+                             query = args.get("query", "")
+
+                             # Emit search event to client
+                             await self.output_queue.put(AdditionalOutputs({
+                                 "type": "search",
+                                 "query": query,
+                                 "results": True
+                             }))
+
+                             # Perform the search
+                             search_results = await self.search_web(query)
+
+                             # Send function result back to the model
+                             if self.connection:
+                                 await self.connection.conversation.item.create(
+                                     item={
+                                         "type": "function_call_output",
+                                         "call_id": event.call_id,
+                                         "output": search_results
+                                     }
+                                 )
+                                 await self.connection.response.create()
+
+                         except Exception as e:
+                             print(f"Function call error: {e}")
+                         finally:
+                             self.function_call_in_progress = False
+                             self.current_function_args = ""
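A hedged caveat on the handler above, since the Realtime API is still in beta: in Realtime session updates, tool entries are commonly written flat (name/description/parameters at the top level of each entry) rather than nested under a "function" key as in the Chat Completions API, and argument streaming is signaled by response.function_call_arguments.delta/.done; a separate ".start" event may never fire, in which case function_call_in_progress stays False. If the tool is never invoked, a flat definition is the first thing to try:

    # Hypothetical flat tool shape for conn.session.update (for comparison only)
    tools = [{
        "type": "function",
        "name": "web_search",
        "description": "Search the web for information",
        "parameters": {
            "type": "object",
            "properties": {"query": {"type": "string", "description": "The search query"}},
            "required": ["query"],
        },
    }]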

    async def receive(self, frame: tuple[int, np.ndarray]) -> None:
        if not self.connection:

        _, array = frame
        array = array.squeeze()
        audio_message = base64.b64encode(array.tobytes()).decode("utf-8")
+         await self.connection.input_audio_buffer.append(audio=audio_message)

    async def emit(self) -> tuple[int, np.ndarray] | AdditionalOutputs | None:
        return await wait_for_item(self.output_queue)

        self.connection = None


+ # Store active handlers by webrtc_id
+ active_handlers = {}


+ # Initialize search client
+ brave_api_key = os.getenv("BSERACH_API")
+ search_client = BraveSearchClient(brave_api_key) if brave_api_key else None


app = FastAPI()

+
+ def update_chatbot(chatbot: list[dict], response: ResponseAudioTranscriptDoneEvent):
+     chatbot.append({"role": "assistant", "content": response.transcript})
+     return chatbot


+ @app.post("/webrtc/offer")
+ async def webrtc_offer(request: dict):
+     """Handle WebRTC offer with web search preference"""
+     web_search_enabled = request.get("web_search_enabled", False)
+     webrtc_id = request.get("webrtc_id")
+
+     # Create handler with web search capability
+     handler = OpenAIHandler(web_search_enabled=web_search_enabled, search_client=search_client)
+     active_handlers[webrtc_id] = handler
+
+     # Create stream for this connection
+     stream = Stream(
+         handler,
+         mode="send-receive",
+         modality="audio",
+         additional_inputs=[[]],  # Empty chatbot state
+         additional_outputs=[[]],
+         additional_outputs_handler=update_chatbot,
+         rtc_configuration=get_twilio_turn_credentials() if get_space() else None,
+         concurrency_limit=5 if get_space() else None,
+         time_limit=90 if get_space() else None,
+     )
+
+     # Store stream reference
+     handler.stream = stream
+
+     # Mount and handle offer
+     stream.mount(app)
+
+     # Forward the WebRTC offer to the stream
+     return await stream.offer(request)
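For reference, the JSON body this endpoint receives is assembled by setupWebRTC() in the embedded page above; a representative payload (values illustrative, sdp truncated):

    offer_request = {
        "sdp": "v=0\r\no=- 46117314 2 IN IP4 127.0.0.1 ...",  # local session description
        "type": "offer",
        "webrtc_id": "k3j9q",        # random id generated client-side
        "web_search_enabled": True,  # state of the page's toggle switch
    }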


@app.get("/outputs")
+ async def outputs(webrtc_id: str):
+     """Stream outputs including search events"""
    async def output_stream():
+         handler = active_handlers.get(webrtc_id)
+         if not handler or not hasattr(handler, 'stream'):
+             return
+
+         async for output in handler.stream.output_stream(webrtc_id):
+             if hasattr(output, 'args') and output.args:
+                 # Check if it's a search event
+                 if isinstance(output.args[0], dict) and output.args[0].get('type') == 'search':
+                     yield f"event: search\ndata: {json.dumps(output.args[0])}\n\n"
+                 # Regular transcript event
+                 elif hasattr(output.args[0], 'transcript'):
+                     s = json.dumps({"role": "assistant", "content": output.args[0].transcript})
+                     yield f"event: output\ndata: {s}\n\n"

    return StreamingResponse(output_stream(), media_type="text/event-stream")


+ @app.get("/")
+ async def index():
+     """Serve the HTML page"""
+     rtc_config = get_twilio_turn_credentials() if get_space() else None
+     html_content = HTML_CONTENT.replace("__RTC_CONFIGURATION__", json.dumps(rtc_config))
+     return HTMLResponse(content=html_content)


+ if __name__ == "__main__":
+     import uvicorn
+     uvicorn.run(app, host="0.0.0.0", port=7860)
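
To watch the transcript and search events from outside a browser, a minimal SSE client sketch against /outputs (assumptions: the server above runs on localhost:7860, httpx is installed, and the webrtc_id placeholder matches a connection established via /webrtc/offer):

    import asyncio
    import httpx

    async def follow_outputs(webrtc_id: str) -> None:
        params = {"webrtc_id": webrtc_id}
        async with httpx.AsyncClient(timeout=None) as client:
            async with client.stream("GET", "http://localhost:7860/outputs", params=params) as resp:
                # SSE frames arrive as "event: ..." / "data: ..." line pairs
                async for line in resp.aiter_lines():
                    if line.startswith("data: "):
                        print(line.removeprefix("data: "))  # JSON payload per event

    asyncio.run(follow_outputs("k3j9q"))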