seawolf2357 committed on
Commit
701a528
·
verified ·
1 Parent(s): def2999

Delete app-backup.py

Browse files
Files changed (1) hide show
  1. app-backup.py +0 -990
app-backup.py DELETED
@@ -1,990 +0,0 @@
1
- import asyncio
2
- import base64
3
- import json
4
- from pathlib import Path
5
- import os
6
- import numpy as np
7
- import openai
8
- from dotenv import load_dotenv
9
- from fastapi import FastAPI, Request
10
- from fastapi.responses import HTMLResponse, StreamingResponse
11
- from fastrtc import (
12
- AdditionalOutputs,
13
- AsyncStreamHandler,
14
- Stream,
15
- get_twilio_turn_credentials,
16
- wait_for_item,
17
- )
18
- from gradio.utils import get_space
19
- from openai.types.beta.realtime import ResponseAudioTranscriptDoneEvent
20
- import httpx
21
- from typing import Optional, List, Dict
22
- import gradio as gr
23
-
24
- load_dotenv()
25
-
26
- SAMPLE_RATE = 24000
27
-
28
- # HTML content embedded as a string
29
- HTML_CONTENT = """<!DOCTYPE html>
30
- <html lang="ko">
31
-
32
- <head>
33
- <meta charset="UTF-8">
34
- <meta name="viewport" content="width=device-width, initial-scale=1.0">
35
- <title>MOUSE 음성 챗</title>
36
- <style>
37
- :root {
38
- --primary-color: #6f42c1;
39
- --secondary-color: #563d7c;
40
- --dark-bg: #121212;
41
- --card-bg: #1e1e1e;
42
- --text-color: #f8f9fa;
43
- --border-color: #333;
44
- --hover-color: #8a5cf6;
45
- }
46
- body {
47
- font-family: "SF Pro Display", -apple-system, BlinkMacSystemFont, sans-serif;
48
- background-color: var(--dark-bg);
49
- color: var(--text-color);
50
- margin: 0;
51
- padding: 0;
52
- height: 100vh;
53
- display: flex;
54
- flex-direction: column;
55
- overflow: hidden;
56
- }
57
- .container {
58
- max-width: 900px;
59
- margin: 0 auto;
60
- padding: 20px;
61
- flex-grow: 1;
62
- display: flex;
63
- flex-direction: column;
64
- width: 100%;
65
- height: calc(100vh - 40px);
66
- box-sizing: border-box;
67
- }
68
- .header {
69
- text-align: center;
70
- padding: 20px 0;
71
- border-bottom: 1px solid var(--border-color);
72
- margin-bottom: 20px;
73
- flex-shrink: 0;
74
- }
75
- .logo {
76
- display: flex;
77
- align-items: center;
78
- justify-content: center;
79
- gap: 10px;
80
- }
81
- .logo h1 {
82
- margin: 0;
83
- background: linear-gradient(135deg, var(--primary-color), #a78bfa);
84
- -webkit-background-clip: text;
85
- background-clip: text;
86
- color: transparent;
87
- font-size: 32px;
88
- letter-spacing: 1px;
89
- }
90
- /* Web search toggle */
91
- .search-toggle {
92
- display: flex;
93
- align-items: center;
94
- justify-content: center;
95
- gap: 10px;
96
- margin-top: 15px;
97
- }
98
- .toggle-switch {
99
- position: relative;
100
- width: 50px;
101
- height: 26px;
102
- background-color: #ccc;
103
- border-radius: 13px;
104
- cursor: pointer;
105
- transition: background-color 0.3s;
106
- }
107
- .toggle-switch.active {
108
- background-color: var(--primary-color);
109
- }
110
- .toggle-slider {
111
- position: absolute;
112
- top: 3px;
113
- left: 3px;
114
- width: 20px;
115
- height: 20px;
116
- background-color: white;
117
- border-radius: 50%;
118
- transition: transform 0.3s;
119
- }
120
- .toggle-switch.active .toggle-slider {
121
- transform: translateX(24px);
122
- }
123
- .search-label {
124
- font-size: 14px;
125
- color: #aaa;
126
- }
127
- .chat-container {
128
- border-radius: 12px;
129
- background-color: var(--card-bg);
130
- box-shadow: 0 8px 32px rgba(0, 0, 0, 0.2);
131
- padding: 20px;
132
- flex-grow: 1;
133
- display: flex;
134
- flex-direction: column;
135
- border: 1px solid var(--border-color);
136
- overflow: hidden;
137
- min-height: 0;
138
- }
139
- .chat-messages {
140
- flex-grow: 1;
141
- overflow-y: auto;
142
- padding: 10px;
143
- scrollbar-width: thin;
144
- scrollbar-color: var(--primary-color) var(--card-bg);
145
- min-height: 0;
146
- }
147
- .chat-messages::-webkit-scrollbar {
148
- width: 6px;
149
- }
150
- .chat-messages::-webkit-scrollbar-thumb {
151
- background-color: var(--primary-color);
152
- border-radius: 6px;
153
- }
154
- .message {
155
- margin-bottom: 20px;
156
- padding: 14px;
157
- border-radius: 8px;
158
- font-size: 16px;
159
- line-height: 1.6;
160
- position: relative;
161
- max-width: 80%;
162
- animation: fade-in 0.3s ease-out;
163
- }
164
- @keyframes fade-in {
165
- from {
166
- opacity: 0;
167
- transform: translateY(10px);
168
- }
169
- to {
170
- opacity: 1;
171
- transform: translateY(0);
172
- }
173
- }
174
- .message.user {
175
- background: linear-gradient(135deg, #2c3e50, #34495e);
176
- margin-left: auto;
177
- border-bottom-right-radius: 2px;
178
- }
179
- .message.assistant {
180
- background: linear-gradient(135deg, var(--secondary-color), var(--primary-color));
181
- margin-right: auto;
182
- border-bottom-left-radius: 2px;
183
- }
184
- .message.search-result {
185
- background: linear-gradient(135deg, #1a5a3e, #2e7d32);
186
- font-size: 14px;
187
- padding: 10px;
188
- margin-bottom: 10px;
189
- }
190
- .controls {
191
- text-align: center;
192
- margin-top: 20px;
193
- display: flex;
194
- justify-content: center;
195
- flex-shrink: 0;
196
- }
197
- button {
198
- background: linear-gradient(135deg, var(--primary-color), var(--secondary-color));
199
- color: white;
200
- border: none;
201
- padding: 14px 28px;
202
- font-family: inherit;
203
- font-size: 16px;
204
- cursor: pointer;
205
- transition: all 0.3s;
206
- text-transform: uppercase;
207
- letter-spacing: 1px;
208
- border-radius: 50px;
209
- display: flex;
210
- align-items: center;
211
- justify-content: center;
212
- gap: 10px;
213
- box-shadow: 0 4px 10px rgba(111, 66, 193, 0.3);
214
- }
215
- button:hover {
216
- transform: translateY(-2px);
217
- box-shadow: 0 6px 15px rgba(111, 66, 193, 0.5);
218
- background: linear-gradient(135deg, var(--hover-color), var(--primary-color));
219
- }
220
- button:active {
221
- transform: translateY(1px);
222
- }
223
- #audio-output {
224
- display: none;
225
- }
226
- .icon-with-spinner {
227
- display: flex;
228
- align-items: center;
229
- justify-content: center;
230
- gap: 12px;
231
- min-width: 180px;
232
- }
233
- .spinner {
234
- width: 20px;
235
- height: 20px;
236
- border: 2px solid #ffffff;
237
- border-top-color: transparent;
238
- border-radius: 50%;
239
- animation: spin 1s linear infinite;
240
- flex-shrink: 0;
241
- }
242
- @keyframes spin {
243
- to {
244
- transform: rotate(360deg);
245
- }
246
- }
247
- .audio-visualizer {
248
- display: flex;
249
- align-items: center;
250
- justify-content: center;
251
- gap: 5px;
252
- min-width: 80px;
253
- height: 25px;
254
- }
255
- .visualizer-bar {
256
- width: 4px;
257
- height: 100%;
258
- background-color: rgba(255, 255, 255, 0.7);
259
- border-radius: 2px;
260
- transform-origin: bottom;
261
- transform: scaleY(0.1);
262
- transition: transform 0.1s ease;
263
- }
264
- .toast {
265
- position: fixed;
266
- top: 20px;
267
- left: 50%;
268
- transform: translateX(-50%);
269
- padding: 16px 24px;
270
- border-radius: 8px;
271
- font-size: 14px;
272
- z-index: 1000;
273
- display: none;
274
- box-shadow: 0 4px 12px rgba(0, 0, 0, 0.3);
275
- }
276
- .toast.error {
277
- background-color: #f44336;
278
- color: white;
279
- }
280
- .toast.warning {
281
- background-color: #ff9800;
282
- color: white;
283
- }
284
- .status-indicator {
285
- display: inline-flex;
286
- align-items: center;
287
- margin-top: 10px;
288
- font-size: 14px;
289
- color: #aaa;
290
- }
291
- .status-dot {
292
- width: 8px;
293
- height: 8px;
294
- border-radius: 50%;
295
- margin-right: 8px;
296
- }
297
- .status-dot.connected {
298
- background-color: #4caf50;
299
- }
300
- .status-dot.disconnected {
301
- background-color: #f44336;
302
- }
303
- .status-dot.connecting {
304
- background-color: #ff9800;
305
- animation: pulse 1.5s infinite;
306
- }
307
- @keyframes pulse {
308
- 0% {
309
- opacity: 0.6;
310
- }
311
- 50% {
312
- opacity: 1;
313
- }
314
- 100% {
315
- opacity: 0.6;
316
- }
317
- }
318
- .mouse-logo {
319
- position: relative;
320
- width: 40px;
321
- height: 40px;
322
- }
323
- .mouse-ears {
324
- position: absolute;
325
- width: 15px;
326
- height: 15px;
327
- background-color: var(--primary-color);
328
- border-radius: 50%;
329
- }
330
- .mouse-ear-left {
331
- top: 0;
332
- left: 5px;
333
- }
334
- .mouse-ear-right {
335
- top: 0;
336
- right: 5px;
337
- }
338
- .mouse-face {
339
- position: absolute;
340
- top: 10px;
341
- left: 5px;
342
- width: 30px;
343
- height: 30px;
344
- background-color: var(--secondary-color);
345
- border-radius: 50%;
346
- }
347
- </style>
348
- </head>
349
-
350
- <body>
351
- <div id="error-toast" class="toast"></div>
352
- <div class="container">
353
- <div class="header">
354
- <div class="logo">
355
- <div class="mouse-logo">
356
- <div class="mouse-ears mouse-ear-left"></div>
357
- <div class="mouse-ears mouse-ear-right"></div>
358
- <div class="mouse-face"></div>
359
- </div>
360
- <h1>MOUSE 음성 챗</h1>
361
- </div>
362
- <div class="search-toggle">
363
- <span class="search-label">웹 검색</span>
364
- <div id="search-toggle" class="toggle-switch">
365
- <div class="toggle-slider"></div>
366
- </div>
367
- </div>
368
- <div class="status-indicator">
369
- <div id="status-dot" class="status-dot disconnected"></div>
370
- <span id="status-text">연결 대기 중</span>
371
- </div>
372
- </div>
373
- <div class="chat-container">
374
- <div class="chat-messages" id="chat-messages"></div>
375
- </div>
376
- <div class="controls">
377
- <button id="start-button">대화 시작</button>
378
- </div>
379
- </div>
380
- <audio id="audio-output"></audio>
381
-
382
- <script>
383
- let peerConnection;
384
- let webrtc_id;
385
- let webSearchEnabled = false;
386
- const audioOutput = document.getElementById('audio-output');
387
- const startButton = document.getElementById('start-button');
388
- const chatMessages = document.getElementById('chat-messages');
389
- const statusDot = document.getElementById('status-dot');
390
- const statusText = document.getElementById('status-text');
391
- const searchToggle = document.getElementById('search-toggle');
392
- let audioLevel = 0;
393
- let animationFrame;
394
- let audioContext, analyser, audioSource;
395
-
396
- // Web search toggle functionality
397
- searchToggle.addEventListener('click', () => {
398
- webSearchEnabled = !webSearchEnabled;
399
- searchToggle.classList.toggle('active', webSearchEnabled);
400
- console.log('Web search enabled:', webSearchEnabled);
401
- });
402
-
403
- function updateStatus(state) {
404
- statusDot.className = 'status-dot ' + state;
405
- if (state === 'connected') {
406
- statusText.textContent = '연결됨';
407
- } else if (state === 'connecting') {
408
- statusText.textContent = '연결 중...';
409
- } else {
410
- statusText.textContent = '연결 대기 중';
411
- }
412
- }
413
- function updateButtonState() {
414
- const button = document.getElementById('start-button');
415
- if (peerConnection && (peerConnection.connectionState === 'connecting' || peerConnection.connectionState === 'new')) {
416
- button.innerHTML = `
417
- <div class="icon-with-spinner">
418
- <div class="spinner"></div>
419
- <span>연결 중...</span>
420
- </div>
421
- `;
422
- updateStatus('connecting');
423
- } else if (peerConnection && peerConnection.connectionState === 'connected') {
424
- button.innerHTML = `
425
- <div class="icon-with-spinner">
426
- <div class="audio-visualizer" id="audio-visualizer">
427
- <div class="visualizer-bar"></div>
428
- <div class="visualizer-bar"></div>
429
- <div class="visualizer-bar"></div>
430
- <div class="visualizer-bar"></div>
431
- <div class="visualizer-bar"></div>
432
- </div>
433
- <span>대화 종료</span>
434
- </div>
435
- `;
436
- updateStatus('connected');
437
- } else {
438
- button.innerHTML = '대화 시작';
439
- updateStatus('disconnected');
440
- }
441
- }
442
- function setupAudioVisualization(stream) {
443
- audioContext = new (window.AudioContext || window.webkitAudioContext)();
444
- analyser = audioContext.createAnalyser();
445
- audioSource = audioContext.createMediaStreamSource(stream);
446
- audioSource.connect(analyser);
447
- analyser.fftSize = 256;
448
- const bufferLength = analyser.frequencyBinCount;
449
- const dataArray = new Uint8Array(bufferLength);
450
-
451
- const visualizerBars = document.querySelectorAll('.visualizer-bar');
452
- const barCount = visualizerBars.length;
453
-
454
- function updateAudioLevel() {
455
- analyser.getByteFrequencyData(dataArray);
456
-
457
- for (let i = 0; i < barCount; i++) {
458
- const start = Math.floor(i * (bufferLength / barCount));
459
- const end = Math.floor((i + 1) * (bufferLength / barCount));
460
-
461
- let sum = 0;
462
- for (let j = start; j < end; j++) {
463
- sum += dataArray[j];
464
- }
465
-
466
- const average = sum / (end - start) / 255;
467
- const scaleY = 0.1 + average * 0.9;
468
- visualizerBars[i].style.transform = `scaleY(${scaleY})`;
469
- }
470
-
471
- animationFrame = requestAnimationFrame(updateAudioLevel);
472
- }
473
-
474
- updateAudioLevel();
475
- }
476
- function showError(message) {
477
- const toast = document.getElementById('error-toast');
478
- toast.textContent = message;
479
- toast.className = 'toast error';
480
- toast.style.display = 'block';
481
- setTimeout(() => {
482
- toast.style.display = 'none';
483
- }, 5000);
484
- }
485
- async function setupWebRTC() {
486
- const config = __RTC_CONFIGURATION__;
487
- peerConnection = new RTCPeerConnection(config);
488
- const timeoutId = setTimeout(() => {
489
- const toast = document.getElementById('error-toast');
490
- toast.textContent = "연결이 평소보다 오래 걸리고 있습니다. VPN을 사용 중이신가요?";
491
- toast.className = 'toast warning';
492
- toast.style.display = 'block';
493
- setTimeout(() => {
494
- toast.style.display = 'none';
495
- }, 5000);
496
- }, 5000);
497
- try {
498
- const stream = await navigator.mediaDevices.getUserMedia({
499
- audio: true
500
- });
501
- setupAudioVisualization(stream);
502
- stream.getTracks().forEach(track => {
503
- peerConnection.addTrack(track, stream);
504
- });
505
- peerConnection.addEventListener('track', (evt) => {
506
- if (audioOutput.srcObject !== evt.streams[0]) {
507
- audioOutput.srcObject = evt.streams[0];
508
- audioOutput.play();
509
- }
510
- });
511
- const dataChannel = peerConnection.createDataChannel('text');
512
- dataChannel.onmessage = (event) => {
513
- const eventJson = JSON.parse(event.data);
514
- if (eventJson.type === "error") {
515
- showError(eventJson.message);
516
- }
517
- };
518
- const offer = await peerConnection.createOffer();
519
- await peerConnection.setLocalDescription(offer);
520
- await new Promise((resolve) => {
521
- if (peerConnection.iceGatheringState === "complete") {
522
- resolve();
523
- } else {
524
- const checkState = () => {
525
- if (peerConnection.iceGatheringState === "complete") {
526
- peerConnection.removeEventListener("icegatheringstatechange", checkState);
527
- resolve();
528
- }
529
- };
530
- peerConnection.addEventListener("icegatheringstatechange", checkState);
531
- }
532
- });
533
- peerConnection.addEventListener('connectionstatechange', () => {
534
- console.log('connectionstatechange', peerConnection.connectionState);
535
- if (peerConnection.connectionState === 'connected') {
536
- clearTimeout(timeoutId);
537
- const toast = document.getElementById('error-toast');
538
- toast.style.display = 'none';
539
- }
540
- updateButtonState();
541
- });
542
- webrtc_id = Math.random().toString(36).substring(7);
543
- const response = await fetch('/webrtc/offer', {
544
- method: 'POST',
545
- headers: { 'Content-Type': 'application/json' },
546
- body: JSON.stringify({
547
- sdp: peerConnection.localDescription.sdp,
548
- type: peerConnection.localDescription.type,
549
- webrtc_id: webrtc_id,
550
- web_search_enabled: webSearchEnabled
551
- })
552
- });
553
- const serverResponse = await response.json();
554
- if (serverResponse.status === 'failed') {
555
- showError(serverResponse.meta.error === 'concurrency_limit_reached'
556
- ? `너무 많은 연결입니다. 최대 한도는 ${serverResponse.meta.limit} 입니다.`
557
- : serverResponse.meta.error);
558
- stop();
559
- return;
560
- }
561
- await peerConnection.setRemoteDescription(serverResponse);
562
- const eventSource = new EventSource('/outputs?webrtc_id=' + webrtc_id);
563
- eventSource.addEventListener("output", (event) => {
564
- const eventJson = JSON.parse(event.data);
565
- addMessage("assistant", eventJson.content);
566
- });
567
- eventSource.addEventListener("search", (event) => {
568
- const eventJson = JSON.parse(event.data);
569
- if (eventJson.query) {
570
- addMessage("search-result", `웹 검색 중: "${eventJson.query}"`);
571
- }
572
- });
573
- } catch (err) {
574
- clearTimeout(timeoutId);
575
- console.error('Error setting up WebRTC:', err);
576
- showError('연결을 설정하지 못했습니다. 다시 시도해 주세요.');
577
- stop();
578
- }
579
- }
580
- function addMessage(role, content) {
581
- const messageDiv = document.createElement('div');
582
- messageDiv.classList.add('message', role);
583
- messageDiv.textContent = content;
584
- chatMessages.appendChild(messageDiv);
585
- chatMessages.scrollTop = chatMessages.scrollHeight;
586
- }
587
- function stop() {
588
- if (animationFrame) {
589
- cancelAnimationFrame(animationFrame);
590
- }
591
- if (audioContext) {
592
- audioContext.close();
593
- audioContext = null;
594
- analyser = null;
595
- audioSource = null;
596
- }
597
- if (peerConnection) {
598
- if (peerConnection.getTransceivers) {
599
- peerConnection.getTransceivers().forEach(transceiver => {
600
- if (transceiver.stop) {
601
- transceiver.stop();
602
- }
603
- });
604
- }
605
- if (peerConnection.getSenders) {
606
- peerConnection.getSenders().forEach(sender => {
607
- if (sender.track && sender.track.stop) sender.track.stop();
608
- });
609
- }
610
- console.log('closing');
611
- peerConnection.close();
612
- }
613
- updateButtonState();
614
- audioLevel = 0;
615
- }
616
- startButton.addEventListener('click', () => {
617
- console.log('clicked');
618
- console.log(peerConnection, peerConnection?.connectionState);
619
- if (!peerConnection || peerConnection.connectionState !== 'connected') {
620
- setupWebRTC();
621
- } else {
622
- console.log('stopping');
623
- stop();
624
- }
625
- });
626
- </script>
627
- </body>
628
-
629
- </html>"""
630
-
631
-
632
class BraveSearchClient:
    """Small async client for the Brave Web Search API."""

    # Largest ``count`` the Brave web search endpoint documents as valid.
    _MAX_COUNT = 20

    def __init__(self, api_key: str):
        self.api_key = api_key
        self.base_url = "https://api.search.brave.com/res/v1/web/search"

    async def search(self, query: str, count: int = 10) -> List[Dict]:
        """Run a web search and return simplified result records.

        Args:
            query: Search terms. Surrounding whitespace is ignored.
            count: Maximum number of results wanted; capped at ``_MAX_COUNT``.

        Returns:
            A list of ``{"title", "url", "description"}`` dicts. The list is
            empty when no API key is configured, when ``query`` is blank,
            when ``count`` is not positive, when the request fails, or when
            the response contains no web results.
        """
        query = (query or "").strip()
        # Nothing useful can come back for these inputs, so skip the request.
        if not self.api_key or not query or count <= 0:
            return []
        count = min(count, self._MAX_COUNT)

        headers = {
            "Accept": "application/json",
            "X-Subscription-Token": self.api_key,
        }
        # NOTE(review): Brave documents `search_lang` / `ui_lang`; confirm
        # that the `lang` key below is actually honoured by the endpoint.
        params = {
            "q": query,
            "count": count,
            "lang": "ko",
        }

        try:
            async with httpx.AsyncClient() as client:
                response = await client.get(self.base_url, headers=headers, params=params)
                response.raise_for_status()
                data = response.json()
        except (httpx.HTTPError, ValueError) as e:
            # Transport errors, non-2xx statuses and undecodable JSON are all
            # treated as "no results"; search is a best-effort feature.
            print(f"Brave Search error: {e}")
            return []

        # Parse defensively: an unexpected payload shape yields no results
        # instead of an exception.
        web = data.get("web") if isinstance(data, dict) else None
        hits = web.get("results") if isinstance(web, dict) else None
        if not isinstance(hits, list):
            return []

        return [
            {
                "title": hit.get("title", ""),
                "url": hit.get("url", ""),
                "description": hit.get("description", ""),
            }
            for hit in hits[:count]
            if isinstance(hit, dict)
        ]
671
-
672
-
673
# Build the shared Brave client once at import time. It stays None when the
# BSEARCH_API environment variable is unset or empty.
brave_api_key = os.getenv("BSEARCH_API")
if brave_api_key:
    search_client = BraveSearchClient(brave_api_key)
else:
    search_client = None
print(f"Search client initialized: {search_client is not None}, API key present: {bool(brave_api_key)}")

# Web search preferences captured from incoming offers, keyed by webrtc_id.
web_search_settings = {}
680
-
681
def update_chatbot(chatbot: list[dict], response: ResponseAudioTranscriptDoneEvent):
    """Append the assistant transcript carried by ``response`` to ``chatbot``.

    Args:
        chatbot: Message history in ``{"role", "content"}`` form; it is
            modified in place.
        response: The additional output produced by the handler. Outputs
            without a ``transcript`` attribute (the handler also emits plain
            dicts for search notifications) are ignored rather than raising.

    Returns:
        The same ``chatbot`` list that was passed in.
    """
    transcript = getattr(response, "transcript", None)
    if transcript is not None:
        chatbot.append({"role": "assistant", "content": transcript})
    return chatbot
684
-
685
-
686
class OpenAIHandler(AsyncStreamHandler):
    """Relay audio between a fastrtc stream and the OpenAI Realtime API.

    Incoming frames are forwarded to the realtime connection; audio deltas
    and transcripts coming back are queued for ``emit``. When web search is
    enabled and a search client is configured, a ``web_search`` tool is
    registered and its calls are answered with Brave search results.
    """

    def __init__(self, web_search_enabled: bool = False, webrtc_id: str = None) -> None:
        super().__init__(
            expected_layout="mono",
            output_sample_rate=SAMPLE_RATE,
            output_frame_size=480,
            input_sample_rate=SAMPLE_RATE,
        )
        self.connection = None
        self.output_queue = asyncio.Queue()
        self.search_client = search_client
        # Function-call bookkeeping, filled from argument delta events and
        # reset once a call has been answered.
        self.function_call_in_progress = False
        self.current_function_args = ""
        self.current_call_id = None
        self.webrtc_id = webrtc_id
        self.web_search_enabled = web_search_enabled
        print(f"Handler created with web_search_enabled={web_search_enabled}, webrtc_id={webrtc_id}")

    @staticmethod
    def _latest_settings():
        """Return ``(webrtc_id, settings)`` for the newest stored offer, or None."""
        if not web_search_settings:
            return None
        recent_id = max(
            web_search_settings,
            key=lambda k: web_search_settings[k].get('timestamp', 0),
        )
        return recent_id, web_search_settings[recent_id]

    def copy(self):
        """Create a handler for a new connection.

        The new handler adopts the most recently stored offer settings when
        any exist; otherwise web search is disabled.
        """
        latest = self._latest_settings()
        if latest is not None:
            recent_id, settings = latest
            web_search_enabled = settings.get('enabled', False)
            print(f"Handler.copy() using recent settings - webrtc_id={recent_id}, web_search_enabled={web_search_enabled}")
            return OpenAIHandler(web_search_enabled=web_search_enabled, webrtc_id=recent_id)

        print("Handler.copy() called - creating new handler with default settings")
        return OpenAIHandler(web_search_enabled=False)

    async def search_web(self, query: str) -> str:
        """Search the web and return the results as a text block for the model.

        Returns a fixed Korean notice when search is disabled or no client is
        configured, and a "no results" message when the search comes back
        empty.
        """
        if not self.search_client or not self.web_search_enabled:
            return "웹 검색이 비활성화되어 있습니다."

        print(f"Searching web for: {query}")
        results = await self.search_client.search(query)
        if not results:
            return f"'{query}'에 대한 검색 결과를 찾을 수 없습니다."

        formatted_results = [
            f"{i}. {result['title']}\n"
            f"   URL: {result['url']}\n"
            f"   {result['description']}\n"
            for i, result in enumerate(results, 1)
        ]
        return f"웹 검색 결과 '{query}':\n\n" + "\n".join(formatted_results)

    def _build_session_config(self) -> dict:
        """Build the ``session.update`` payload for the current settings."""
        tools = []
        instructions = "You are a helpful assistant. Respond in Korean when the user speaks Korean."

        if self.web_search_enabled and self.search_client:
            # Realtime sessions take function tools in a flat layout: name,
            # description and parameters sit next to "type", not under a
            # nested "function" key as in Chat Completions.
            tools = [{
                "type": "function",
                "name": "web_search",
                "description": "Search the web for current information. Use this for weather, news, prices, current events, or any time-sensitive topics.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "query": {
                            "type": "string",
                            "description": "The search query in Korean or English"
                        }
                    },
                    "required": ["query"]
                },
            }]
            print("Web search function added to tools")

            instructions = (
                "You are a helpful assistant with web search capabilities. "
                "IMPORTANT: You MUST use the web_search function for ANY of these topics:\n"
                "- Weather (날씨, 기온, 비, 눈)\n"
                "- News (뉴스, 소식)\n"
                "- Current events (현재, 최근, 오늘, 지금)\n"
                "- Prices (가격, 환율, 주가)\n"
                "- Sports scores or results\n"
                "- Any question about 2024 or 2025\n"
                "- Any time-sensitive information\n\n"
                "When in doubt, USE web_search. It's better to search and provide accurate information "
                "than to guess or use outdated information. Always respond in Korean when the user speaks Korean."
            )

        return {
            "turn_detection": {"type": "server_vad"},
            "instructions": instructions,
            "tools": tools,
            "tool_choice": "auto" if tools else "none",
        }

    async def start_up(self):
        """Connect to the realtime API and process server events until it closes."""
        # Settings may have been stored after this handler was copied, so
        # re-read the newest entry before configuring the session.
        latest = self._latest_settings()
        if latest is not None:
            self.webrtc_id, settings = latest
            self.web_search_enabled = settings.get('enabled', False)
            print(f"start_up: Updated settings from storage - webrtc_id={self.webrtc_id}, web_search_enabled={self.web_search_enabled}")

        print(f"Starting up handler with web_search_enabled={self.web_search_enabled}")
        self.client = openai.AsyncOpenAI()
        session_update = self._build_session_config()

        try:
            async with self.client.beta.realtime.connect(
                model="gpt-4o-mini-realtime-preview-2024-12-17"
            ) as conn:
                await conn.session.update(session=session_update)
                self.connection = conn
                print(f"Connected with tools: {len(session_update['tools'])} functions")

                async for event in self.connection:
                    await self._handle_event(event)
        finally:
            # The context manager has closed the socket; drop the reference
            # so receive() stops trying to write to it.
            self.connection = None

    async def _handle_event(self, event) -> None:
        """Route one realtime server event to the output queue or tool logic."""
        if event.type.startswith("response.function_call"):
            print(f"Function event: {event.type}")

        if event.type == "response.audio_transcript.done":
            await self.output_queue.put(AdditionalOutputs(event))

        elif event.type == "response.audio.delta":
            await self.output_queue.put(
                (
                    self.output_sample_rate,
                    np.frombuffer(
                        base64.b64decode(event.delta), dtype=np.int16
                    ).reshape(1, -1),
                ),
            )

        elif event.type == "response.function_call_arguments.delta":
            # There is no separate "start" event for function-call arguments,
            # so the first delta marks the beginning of a call.
            self.function_call_in_progress = True
            self.current_function_args += event.delta
            self.current_call_id = getattr(event, 'call_id', None) or self.current_call_id

        elif event.type == "response.function_call_arguments.done":
            await self._answer_function_call(event)

    async def _answer_function_call(self, event) -> None:
        """Run the requested web search and send its output back to the model."""
        # Prefer the complete values carried by the done event; fall back to
        # what was accumulated from the deltas.
        raw_args = getattr(event, 'arguments', None) or self.current_function_args
        call_id = getattr(event, 'call_id', None) or self.current_call_id
        print(f"Function call done, args: {raw_args}")
        try:
            args = json.loads(raw_args) if raw_args else {}
            query = args.get("query", "")

            # Tell the browser that a search is happening.
            await self.output_queue.put(AdditionalOutputs({
                "type": "search",
                "query": query
            }))

            search_results = await self.search_web(query)
            print(f"Search results length: {len(search_results)}")

            if self.connection and call_id:
                await self.connection.conversation.item.create(
                    item={
                        "type": "function_call_output",
                        "call_id": call_id,
                        "output": search_results
                    }
                )
                # Ask the model to continue now that the tool output exists.
                await self.connection.response.create()

        except Exception as e:
            # A failed tool call must not end the event loop for the whole
            # conversation, so report it and carry on.
            print(f"Function call error: {e}")
        finally:
            self.function_call_in_progress = False
            self.current_function_args = ""
            self.current_call_id = None

    async def receive(self, frame: tuple[int, np.ndarray]) -> None:
        """Forward one microphone frame to the realtime input audio buffer."""
        if not self.connection:
            return
        try:
            _, array = frame
            array = array.squeeze()
            audio_message = base64.b64encode(array.tobytes()).decode("utf-8")
            await self.connection.input_audio_buffer.append(audio=audio_message)
        except Exception as e:
            # The connection might already be closed; report and drop the frame.
            print(f"Error in receive: {e}")

    async def emit(self) -> tuple[int, np.ndarray] | AdditionalOutputs | None:
        """Return the next queued audio chunk or additional output."""
        return await wait_for_item(self.output_queue)

    async def shutdown(self) -> None:
        """Close the realtime connection and release this connection's settings."""
        conn, self.connection = self.connection, None
        if conn:
            await conn.close()
        # Stored offer settings are only needed while the connection is being
        # set up; removing them keeps the module-level dict from growing.
        if self.webrtc_id is not None:
            web_search_settings.pop(self.webrtc_id, None)
895
-
896
-
897
# Template handler: fastrtc is given this instance (not a factory) and the
# class provides copy() for producing per-connection handlers.
handler = OpenAIHandler(web_search_enabled=False)

# Chat history component shared as both additional input and output.
chatbot = gr.Chatbot(type="messages")

# TURN credentials and usage limits are only applied when get_space()
# reports a value; otherwise each of them is left as None.
_on_space = get_space()
stream = Stream(
    handler,
    mode="send-receive",
    modality="audio",
    additional_inputs=[chatbot],
    additional_outputs=[chatbot],
    additional_outputs_handler=update_chatbot,
    rtc_configuration=get_twilio_turn_credentials() if _on_space else None,
    concurrency_limit=5 if _on_space else None,
    time_limit=300 if _on_space else None,
)

app = FastAPI()

# Register the stream's routes on the FastAPI app.
stream.mount(app)
920
-
921
# Capture per-connection settings from the offer before handing it to fastrtc.
# NOTE(review): stream.mount(app) has already been called by the time this
# route is registered; if it registers its own POST /webrtc/offer, that
# earlier route would likely match first and this one would never run.
# Confirm the intended precedence.
@app.post("/webrtc/offer", include_in_schema=False)
async def custom_offer(request: Request):
    """Record the client's web search preference, then forward the offer.

    The JSON body is expected to carry ``webrtc_id`` and, optionally,
    ``web_search_enabled``. The preference is stored in
    ``web_search_settings`` together with a timestamp so that the handler
    created for this connection can pick it up.
    """
    body = await request.json()

    webrtc_id = body.get("webrtc_id")
    web_search_enabled = body.get("web_search_enabled", False)

    print(f"Custom offer - webrtc_id: {webrtc_id}, web_search_enabled: {web_search_enabled}")

    if webrtc_id:
        web_search_settings[webrtc_id] = {
            'enabled': web_search_enabled,
            # Loop time is only ever compared with other entries' loop time.
            'timestamp': asyncio.get_running_loop().time(),
        }

    # stream.offer() is called directly and does not go through the router,
    # so app.routes is left untouched; editing the route table per request
    # would only race with concurrent requests.
    # NOTE(review): presumably stream.offer accepts the raw dict body;
    # verify against the installed fastrtc version.
    return await stream.offer(body)
954
-
955
-
956
@app.get("/outputs")
async def outputs(webrtc_id: str):
    """Stream a connection's additional outputs as server-sent events.

    Search notifications are sent as ``search`` events and assistant
    transcripts as ``output`` events; anything else is skipped.
    """
    async def output_stream():
        async for output in stream.output_stream(webrtc_id):
            payload = getattr(output, 'args', None)
            if not payload:
                continue

            first = payload[0]
            if isinstance(first, dict) and first.get('type') == 'search':
                # Dict payloads tagged "search" are forwarded unchanged.
                yield f"event: search\ndata: {json.dumps(first)}\n\n"
            elif hasattr(first, 'transcript'):
                # Transcript events become assistant chat messages.
                s = json.dumps({"role": "assistant", "content": first.transcript})
                yield f"event: output\ndata: {s}\n\n"

    return StreamingResponse(output_stream(), media_type="text/event-stream")
971
-
972
-
973
@app.get("/")
async def index():
    """Return the chat page with the RTC configuration substituted in."""
    if get_space():
        rtc_config = get_twilio_turn_credentials()
    else:
        rtc_config = None

    # The page embeds the configuration as a JSON literal in its script.
    page = HTML_CONTENT.replace("__RTC_CONFIGURATION__", json.dumps(rtc_config))
    return HTMLResponse(content=page)
979
-
980
-
981
if __name__ == "__main__":
    import uvicorn

    # The MODE environment variable selects the front end: the Gradio UI,
    # the fastphone entry point, or (by default) the FastAPI app.
    mode = os.getenv("MODE")
    if mode == "PHONE":
        stream.fastphone(host="0.0.0.0", port=7860)
    elif mode == "UI":
        stream.ui.launch(server_port=7860)
    else:
        uvicorn.run(app, host="0.0.0.0", port=7860)