seawolf2357 commited on
Commit
0d8a2ef
·
verified ·
1 Parent(s): e4b0406

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +268 -92
app.py CHANGED
@@ -20,6 +20,12 @@ from openai.types.beta.realtime import ResponseAudioTranscriptDoneEvent
20
  import httpx
21
  from typing import Optional, List, Dict
22
  import gradio as gr
 
 
 
 
 
 
23
 
24
  load_dotenv()
25
 
@@ -383,6 +389,9 @@ HTML_CONTENT = """<!DOCTYPE html>
383
  let peerConnection;
384
  let webrtc_id;
385
  let webSearchEnabled = false;
 
 
 
386
  const audioOutput = document.getElementById('audio-output');
387
  const startButton = document.getElementById('start-button');
388
  const chatMessages = document.getElementById('chat-messages');
@@ -410,6 +419,7 @@ HTML_CONTENT = """<!DOCTYPE html>
410
  statusText.textContent = '연결 대기 중';
411
  }
412
  }
 
413
  function updateButtonState() {
414
  const button = document.getElementById('start-button');
415
  if (peerConnection && (peerConnection.connectionState === 'connecting' || peerConnection.connectionState === 'new')) {
@@ -439,6 +449,7 @@ HTML_CONTENT = """<!DOCTYPE html>
439
  updateStatus('disconnected');
440
  }
441
  }
 
442
  function setupAudioVisualization(stream) {
443
  audioContext = new (window.AudioContext || window.webkitAudioContext)();
444
  analyser = audioContext.createAnalyser();
@@ -473,6 +484,7 @@ HTML_CONTENT = """<!DOCTYPE html>
473
 
474
  updateAudioLevel();
475
  }
 
476
  function showError(message) {
477
  const toast = document.getElementById('error-toast');
478
  toast.textContent = message;
@@ -482,9 +494,46 @@ HTML_CONTENT = """<!DOCTYPE html>
482
  toast.style.display = 'none';
483
  }, 5000);
484
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
485
  async function setupWebRTC() {
486
  const config = __RTC_CONFIGURATION__;
487
  peerConnection = new RTCPeerConnection(config);
 
488
  const timeoutId = setTimeout(() => {
489
  const toast = document.getElementById('error-toast');
490
  toast.textContent = "연결이 평소보다 오래 걸리고 있습니다. VPN을 사용 중이신가요?";
@@ -494,6 +543,7 @@ HTML_CONTENT = """<!DOCTYPE html>
494
  toast.style.display = 'none';
495
  }, 5000);
496
  }, 5000);
 
497
  try {
498
  const stream = await navigator.mediaDevices.getUserMedia({
499
  audio: true
@@ -502,21 +552,44 @@ HTML_CONTENT = """<!DOCTYPE html>
502
  stream.getTracks().forEach(track => {
503
  peerConnection.addTrack(track, stream);
504
  });
 
505
  peerConnection.addEventListener('track', (evt) => {
506
  if (audioOutput.srcObject !== evt.streams[0]) {
507
  audioOutput.srcObject = evt.streams[0];
508
  audioOutput.play();
509
  }
510
  });
 
511
  const dataChannel = peerConnection.createDataChannel('text');
 
 
 
 
 
 
 
 
 
 
 
 
512
  dataChannel.onmessage = (event) => {
513
  const eventJson = JSON.parse(event.data);
514
  if (eventJson.type === "error") {
515
  showError(eventJson.message);
 
 
516
  }
517
  };
 
 
 
 
 
 
518
  const offer = await peerConnection.createOffer();
519
  await peerConnection.setLocalDescription(offer);
 
520
  await new Promise((resolve) => {
521
  if (peerConnection.iceGatheringState === "complete") {
522
  resolve();
@@ -530,15 +603,31 @@ HTML_CONTENT = """<!DOCTYPE html>
530
  peerConnection.addEventListener("icegatheringstatechange", checkState);
531
  }
532
  });
 
 
533
  peerConnection.addEventListener('connectionstatechange', () => {
534
  console.log('connectionstatechange', peerConnection.connectionState);
535
  if (peerConnection.connectionState === 'connected') {
536
  clearTimeout(timeoutId);
537
  const toast = document.getElementById('error-toast');
538
  toast.style.display = 'none';
 
 
 
 
539
  }
540
  updateButtonState();
541
  });
 
 
 
 
 
 
 
 
 
 
542
  webrtc_id = Math.random().toString(36).substring(7);
543
  const response = await fetch('/webrtc/offer', {
544
  method: 'POST',
@@ -550,6 +639,7 @@ HTML_CONTENT = """<!DOCTYPE html>
550
  web_search_enabled: webSearchEnabled
551
  })
552
  });
 
553
  const serverResponse = await response.json();
554
  if (serverResponse.status === 'failed') {
555
  showError(serverResponse.meta.error === 'concurrency_limit_reached'
@@ -558,18 +648,27 @@ HTML_CONTENT = """<!DOCTYPE html>
558
  stop();
559
  return;
560
  }
 
561
  await peerConnection.setRemoteDescription(serverResponse);
 
562
  const eventSource = new EventSource('/outputs?webrtc_id=' + webrtc_id);
563
  eventSource.addEventListener("output", (event) => {
564
  const eventJson = JSON.parse(event.data);
565
  addMessage("assistant", eventJson.content);
566
  });
 
567
  eventSource.addEventListener("search", (event) => {
568
  const eventJson = JSON.parse(event.data);
569
  if (eventJson.query) {
570
  addMessage("search-result", `웹 검색 중: "${eventJson.query}"`);
571
  }
572
  });
 
 
 
 
 
 
573
  } catch (err) {
574
  clearTimeout(timeoutId);
575
  console.error('Error setting up WebRTC:', err);
@@ -577,6 +676,7 @@ HTML_CONTENT = """<!DOCTYPE html>
577
  stop();
578
  }
579
  }
 
580
  function addMessage(role, content) {
581
  const messageDiv = document.createElement('div');
582
  messageDiv.classList.add('message', role);
@@ -584,7 +684,14 @@ HTML_CONTENT = """<!DOCTYPE html>
584
  chatMessages.appendChild(messageDiv);
585
  chatMessages.scrollTop = chatMessages.scrollHeight;
586
  }
 
587
  function stop() {
 
 
 
 
 
 
588
  if (animationFrame) {
589
  cancelAnimationFrame(animationFrame);
590
  }
@@ -613,6 +720,7 @@ HTML_CONTENT = """<!DOCTYPE html>
613
  updateButtonState();
614
  audioLevel = 0;
615
  }
 
616
  startButton.addEventListener('click', () => {
617
  console.log('clicked');
618
  console.log(peerConnection, peerConnection?.connectionState);
@@ -666,14 +774,14 @@ class BraveSearchClient:
666
  })
667
  return results
668
  except Exception as e:
669
- print(f"Brave Search error: {e}")
670
  return []
671
 
672
 
673
  # Initialize search client globally
674
  brave_api_key = os.getenv("BSEARCH_API")
675
  search_client = BraveSearchClient(brave_api_key) if brave_api_key else None
676
- print(f"Search client initialized: {search_client is not None}, API key present: {bool(brave_api_key)}")
677
 
678
  # Store web search settings by connection
679
  web_search_settings = {}
@@ -699,7 +807,10 @@ class OpenAIHandler(AsyncStreamHandler):
699
  self.current_call_id = None
700
  self.webrtc_id = webrtc_id
701
  self.web_search_enabled = web_search_enabled
702
- print(f"Handler created with web_search_enabled={web_search_enabled}, webrtc_id={webrtc_id}")
 
 
 
703
 
704
  def copy(self):
705
  # Get the most recent settings
@@ -712,10 +823,10 @@ class OpenAIHandler(AsyncStreamHandler):
712
  recent_id = recent_ids[0]
713
  settings = web_search_settings[recent_id]
714
  web_search_enabled = settings.get('enabled', False)
715
- print(f"Handler.copy() using recent settings - webrtc_id={recent_id}, web_search_enabled={web_search_enabled}")
716
  return OpenAIHandler(web_search_enabled=web_search_enabled, webrtc_id=recent_id)
717
 
718
- print(f"Handler.copy() called - creating new handler with default settings")
719
  return OpenAIHandler(web_search_enabled=False)
720
 
721
  async def search_web(self, query: str) -> str:
@@ -723,7 +834,7 @@ class OpenAIHandler(AsyncStreamHandler):
723
  if not self.search_client or not self.web_search_enabled:
724
  return "웹 검색이 비활성화되어 있습니다."
725
 
726
- print(f"Searching web for: {query}")
727
  results = await self.search_client.search(query)
728
  if not results:
729
  return f"'{query}'에 대한 검색 결과를 찾을 수 없습니다."
@@ -739,6 +850,31 @@ class OpenAIHandler(AsyncStreamHandler):
739
 
740
  return f"웹 검색 결과 '{query}':\n\n" + "\n".join(formatted_results)
741
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
742
  async def start_up(self):
743
  """Connect to realtime API with function calling enabled"""
744
  # First check if we have the most recent settings
@@ -751,10 +887,11 @@ class OpenAIHandler(AsyncStreamHandler):
751
  settings = web_search_settings[recent_id]
752
  self.web_search_enabled = settings.get('enabled', False)
753
  self.webrtc_id = recent_id
754
- print(f"start_up: Updated settings from storage - webrtc_id={self.webrtc_id}, web_search_enabled={self.web_search_enabled}")
755
 
756
- print(f"Starting up handler with web_search_enabled={self.web_search_enabled}")
757
  self.client = openai.AsyncOpenAI()
 
758
 
759
  # Define the web search function
760
  tools = []
@@ -778,7 +915,7 @@ class OpenAIHandler(AsyncStreamHandler):
778
  }
779
  }
780
  }]
781
- print("Web search function added to tools")
782
 
783
  instructions = (
784
  "You are a helpful assistant with web search capabilities. "
@@ -794,101 +931,136 @@ class OpenAIHandler(AsyncStreamHandler):
794
  "than to guess or use outdated information. Always respond in Korean when the user speaks Korean."
795
  )
796
 
797
- async with self.client.beta.realtime.connect(
798
- model="gpt-4o-mini-realtime-preview-2024-12-17"
799
- ) as conn:
800
- # Update session with tools
801
- session_update = {
802
- "turn_detection": {"type": "server_vad"},
803
- "instructions": instructions,
804
- "tools": tools,
805
- "tool_choice": "auto" if tools else "none"
806
- }
807
-
808
- await conn.session.update(session=session_update)
809
- self.connection = conn
810
- print(f"Connected with tools: {len(tools)} functions")
811
-
812
- async for event in self.connection:
813
- # Debug logging for function calls
814
- if event.type.startswith("response.function_call"):
815
- print(f"Function event: {event.type}")
816
-
817
- if event.type == "response.audio_transcript.done":
818
- await self.output_queue.put(AdditionalOutputs(event))
819
-
820
- elif event.type == "response.audio.delta":
821
- await self.output_queue.put(
822
- (
823
- self.output_sample_rate,
824
- np.frombuffer(
825
- base64.b64decode(event.delta), dtype=np.int16
826
- ).reshape(1, -1),
827
- ),
828
- )
829
 
830
- # Handle function calls
831
- elif event.type == "response.function_call_arguments.start":
832
- print(f"Function call started")
833
- self.function_call_in_progress = True
834
- self.current_function_args = ""
835
- self.current_call_id = getattr(event, 'call_id', None)
836
 
837
- elif event.type == "response.function_call_arguments.delta":
838
- if self.function_call_in_progress:
839
- self.current_function_args += event.delta
840
 
841
- elif event.type == "response.function_call_arguments.done":
842
- if self.function_call_in_progress:
843
- print(f"Function call done, args: {self.current_function_args}")
844
- try:
845
- args = json.loads(self.current_function_args)
846
- query = args.get("query", "")
847
-
848
- # Emit search event to client
849
- await self.output_queue.put(AdditionalOutputs({
850
- "type": "search",
851
- "query": query
852
- }))
853
-
854
- # Perform the search
855
- search_results = await self.search_web(query)
856
- print(f"Search results length: {len(search_results)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
857
 
858
- # Send function result back to the model
859
- if self.connection and self.current_call_id:
860
- await self.connection.conversation.item.create(
861
- item={
862
- "type": "function_call_output",
863
- "call_id": self.current_call_id,
864
- "output": search_results
865
- }
866
- )
867
- await self.connection.response.create()
868
-
869
- except Exception as e:
870
- print(f"Function call error: {e}")
871
- finally:
872
- self.function_call_in_progress = False
873
- self.current_function_args = ""
874
- self.current_call_id = None
875
 
876
  async def receive(self, frame: tuple[int, np.ndarray]) -> None:
877
- if not self.connection:
878
  return
879
  try:
 
880
  _, array = frame
881
  array = array.squeeze()
882
  audio_message = base64.b64encode(array.tobytes()).decode("utf-8")
883
  await self.connection.input_audio_buffer.append(audio=audio_message)
884
  except Exception as e:
885
- print(f"Error in receive: {e}")
886
- # Connection might be closed, ignore the error
 
 
 
 
 
 
 
 
887
 
888
  async def emit(self) -> tuple[int, np.ndarray] | AdditionalOutputs | None:
889
  return await wait_for_item(self.output_queue)
890
 
891
  async def shutdown(self) -> None:
 
 
 
 
 
 
 
 
 
 
892
  if self.connection:
893
  await self.connection.close()
894
  self.connection = None
@@ -900,7 +1072,7 @@ handler = OpenAIHandler(web_search_enabled=False)
900
  # Create components
901
  chatbot = gr.Chatbot(type="messages")
902
 
903
- # Create stream with handler instance
904
  stream = Stream(
905
  handler, # Pass instance, not factory
906
  mode="send-receive",
@@ -910,7 +1082,7 @@ stream = Stream(
910
  additional_outputs_handler=update_chatbot,
911
  rtc_configuration=get_twilio_turn_credentials() if get_space() else None,
912
  concurrency_limit=5 if get_space() else None,
913
- time_limit=90 if get_space() else None,
914
  )
915
 
916
  app = FastAPI()
@@ -927,7 +1099,7 @@ async def custom_offer(request: Request):
927
  webrtc_id = body.get("webrtc_id")
928
  web_search_enabled = body.get("web_search_enabled", False)
929
 
930
- print(f"Custom offer - webrtc_id: {webrtc_id}, web_search_enabled: {web_search_enabled}")
931
 
932
  # Store settings with timestamp
933
  if webrtc_id:
@@ -959,9 +1131,13 @@ async def outputs(webrtc_id: str):
959
  async def output_stream():
960
  async for output in stream.output_stream(webrtc_id):
961
  if hasattr(output, 'args') and output.args:
962
- # Check if it's a search event
963
- if isinstance(output.args[0], dict) and output.args[0].get('type') == 'search':
964
- yield f"event: search\ndata: {json.dumps(output.args[0])}\n\n"
 
 
 
 
965
  # Regular transcript event
966
  elif hasattr(output.args[0], 'transcript'):
967
  s = json.dumps({"role": "assistant", "content": output.args[0].transcript})
 
20
  import httpx
21
  from typing import Optional, List, Dict
22
  import gradio as gr
23
+ import logging
24
+ from datetime import datetime
25
+
26
+ # 로깅 설정
27
+ logging.basicConfig(level=logging.INFO)
28
+ logger = logging.getLogger(__name__)
29
 
30
  load_dotenv()
31
 
 
389
  let peerConnection;
390
  let webrtc_id;
391
  let webSearchEnabled = false;
392
+ let reconnectAttempts = 0;
393
+ let heartbeatInterval;
394
+ let connectionHealthInterval;
395
  const audioOutput = document.getElementById('audio-output');
396
  const startButton = document.getElementById('start-button');
397
  const chatMessages = document.getElementById('chat-messages');
 
419
  statusText.textContent = '연결 대기 중';
420
  }
421
  }
422
+
423
  function updateButtonState() {
424
  const button = document.getElementById('start-button');
425
  if (peerConnection && (peerConnection.connectionState === 'connecting' || peerConnection.connectionState === 'new')) {
 
449
  updateStatus('disconnected');
450
  }
451
  }
452
+
453
  function setupAudioVisualization(stream) {
454
  audioContext = new (window.AudioContext || window.webkitAudioContext)();
455
  analyser = audioContext.createAnalyser();
 
484
 
485
  updateAudioLevel();
486
  }
487
+
488
  function showError(message) {
489
  const toast = document.getElementById('error-toast');
490
  toast.textContent = message;
 
494
  toast.style.display = 'none';
495
  }, 5000);
496
  }
497
+
498
+ // 연결 상태 모니터링 함수
499
+ function startConnectionHealthCheck() {
500
+ if (connectionHealthInterval) {
501
+ clearInterval(connectionHealthInterval);
502
+ }
503
+
504
+ connectionHealthInterval = setInterval(() => {
505
+ if (peerConnection) {
506
+ const state = peerConnection.connectionState;
507
+ const iceState = peerConnection.iceConnectionState;
508
+ console.log(`Connection state: ${state}, ICE state: ${iceState}`);
509
+
510
+ if (state === 'failed' || state === 'closed' || iceState === 'failed') {
511
+ console.log('Connection lost, attempting to reconnect...');
512
+ handleConnectionLoss();
513
+ }
514
+ }
515
+ }, 3000); // 3초마다 체크
516
+ }
517
+
518
+ // 연결 손실 처리
519
+ function handleConnectionLoss() {
520
+ if (reconnectAttempts < 3) {
521
+ reconnectAttempts++;
522
+ showError(`연결이 끊어졌습니다. 재연결 시도 중... (${reconnectAttempts}/3)`);
523
+ stop();
524
+ setTimeout(() => {
525
+ setupWebRTC();
526
+ }, 2000);
527
+ } else {
528
+ showError('연결을 복구할 수 없습니다. 새로고침 후 다시 시도해주세요.');
529
+ stop();
530
+ }
531
+ }
532
+
533
  async function setupWebRTC() {
534
  const config = __RTC_CONFIGURATION__;
535
  peerConnection = new RTCPeerConnection(config);
536
+
537
  const timeoutId = setTimeout(() => {
538
  const toast = document.getElementById('error-toast');
539
  toast.textContent = "연결이 평소보다 오래 걸리고 있습니다. VPN을 사용 중이신가요?";
 
543
  toast.style.display = 'none';
544
  }, 5000);
545
  }, 5000);
546
+
547
  try {
548
  const stream = await navigator.mediaDevices.getUserMedia({
549
  audio: true
 
552
  stream.getTracks().forEach(track => {
553
  peerConnection.addTrack(track, stream);
554
  });
555
+
556
  peerConnection.addEventListener('track', (evt) => {
557
  if (audioOutput.srcObject !== evt.streams[0]) {
558
  audioOutput.srcObject = evt.streams[0];
559
  audioOutput.play();
560
  }
561
  });
562
+
563
  const dataChannel = peerConnection.createDataChannel('text');
564
+
565
+ // Heartbeat 메시지 전송
566
+ dataChannel.onopen = () => {
567
+ console.log('Data channel opened');
568
+ if (heartbeatInterval) clearInterval(heartbeatInterval);
569
+ heartbeatInterval = setInterval(() => {
570
+ if (dataChannel.readyState === 'open') {
571
+ dataChannel.send(JSON.stringify({ type: 'heartbeat' }));
572
+ }
573
+ }, 30000); // 30초마다 heartbeat
574
+ };
575
+
576
  dataChannel.onmessage = (event) => {
577
  const eventJson = JSON.parse(event.data);
578
  if (eventJson.type === "error") {
579
  showError(eventJson.message);
580
+ } else if (eventJson.type === "connection_lost") {
581
+ handleConnectionLoss();
582
  }
583
  };
584
+
585
+ dataChannel.onclose = () => {
586
+ console.log('Data channel closed');
587
+ if (heartbeatInterval) clearInterval(heartbeatInterval);
588
+ };
589
+
590
  const offer = await peerConnection.createOffer();
591
  await peerConnection.setLocalDescription(offer);
592
+
593
  await new Promise((resolve) => {
594
  if (peerConnection.iceGatheringState === "complete") {
595
  resolve();
 
603
  peerConnection.addEventListener("icegatheringstatechange", checkState);
604
  }
605
  });
606
+
607
+ // 모든 연결 상태 이벤트 모니터링
608
  peerConnection.addEventListener('connectionstatechange', () => {
609
  console.log('connectionstatechange', peerConnection.connectionState);
610
  if (peerConnection.connectionState === 'connected') {
611
  clearTimeout(timeoutId);
612
  const toast = document.getElementById('error-toast');
613
  toast.style.display = 'none';
614
+ reconnectAttempts = 0;
615
+ startConnectionHealthCheck();
616
+ } else if (peerConnection.connectionState === 'failed') {
617
+ handleConnectionLoss();
618
  }
619
  updateButtonState();
620
  });
621
+
622
+ peerConnection.addEventListener('iceconnectionstatechange', () => {
623
+ console.log('ICE connection state:', peerConnection.iceConnectionState);
624
+ if (peerConnection.iceConnectionState === 'disconnected') {
625
+ showError('네트워크 연결이 불안정합니다');
626
+ } else if (peerConnection.iceConnectionState === 'failed') {
627
+ handleConnectionLoss();
628
+ }
629
+ });
630
+
631
  webrtc_id = Math.random().toString(36).substring(7);
632
  const response = await fetch('/webrtc/offer', {
633
  method: 'POST',
 
639
  web_search_enabled: webSearchEnabled
640
  })
641
  });
642
+
643
  const serverResponse = await response.json();
644
  if (serverResponse.status === 'failed') {
645
  showError(serverResponse.meta.error === 'concurrency_limit_reached'
 
648
  stop();
649
  return;
650
  }
651
+
652
  await peerConnection.setRemoteDescription(serverResponse);
653
+
654
  const eventSource = new EventSource('/outputs?webrtc_id=' + webrtc_id);
655
  eventSource.addEventListener("output", (event) => {
656
  const eventJson = JSON.parse(event.data);
657
  addMessage("assistant", eventJson.content);
658
  });
659
+
660
  eventSource.addEventListener("search", (event) => {
661
  const eventJson = JSON.parse(event.data);
662
  if (eventJson.query) {
663
  addMessage("search-result", `웹 검색 중: "${eventJson.query}"`);
664
  }
665
  });
666
+
667
+ eventSource.addEventListener("error", (event) => {
668
+ console.error('EventSource error:', event);
669
+ handleConnectionLoss();
670
+ });
671
+
672
  } catch (err) {
673
  clearTimeout(timeoutId);
674
  console.error('Error setting up WebRTC:', err);
 
676
  stop();
677
  }
678
  }
679
+
680
  function addMessage(role, content) {
681
  const messageDiv = document.createElement('div');
682
  messageDiv.classList.add('message', role);
 
684
  chatMessages.appendChild(messageDiv);
685
  chatMessages.scrollTop = chatMessages.scrollHeight;
686
  }
687
+
688
  function stop() {
689
+ if (heartbeatInterval) {
690
+ clearInterval(heartbeatInterval);
691
+ }
692
+ if (connectionHealthInterval) {
693
+ clearInterval(connectionHealthInterval);
694
+ }
695
  if (animationFrame) {
696
  cancelAnimationFrame(animationFrame);
697
  }
 
720
  updateButtonState();
721
  audioLevel = 0;
722
  }
723
+
724
  startButton.addEventListener('click', () => {
725
  console.log('clicked');
726
  console.log(peerConnection, peerConnection?.connectionState);
 
774
  })
775
  return results
776
  except Exception as e:
777
+ logger.error(f"Brave Search error: {e}")
778
  return []
779
 
780
 
781
  # Initialize search client globally
782
  brave_api_key = os.getenv("BSEARCH_API")
783
  search_client = BraveSearchClient(brave_api_key) if brave_api_key else None
784
+ logger.info(f"Search client initialized: {search_client is not None}, API key present: {bool(brave_api_key)}")
785
 
786
  # Store web search settings by connection
787
  web_search_settings = {}
 
807
  self.current_call_id = None
808
  self.webrtc_id = webrtc_id
809
  self.web_search_enabled = web_search_enabled
810
+ self.keep_alive_task = None
811
+ self.last_activity = datetime.now()
812
+ self.connection_active = True
813
+ logger.info(f"Handler created with web_search_enabled={web_search_enabled}, webrtc_id={webrtc_id}")
814
 
815
  def copy(self):
816
  # Get the most recent settings
 
823
  recent_id = recent_ids[0]
824
  settings = web_search_settings[recent_id]
825
  web_search_enabled = settings.get('enabled', False)
826
+ logger.info(f"Handler.copy() using recent settings - webrtc_id={recent_id}, web_search_enabled={web_search_enabled}")
827
  return OpenAIHandler(web_search_enabled=web_search_enabled, webrtc_id=recent_id)
828
 
829
+ logger.info(f"Handler.copy() called - creating new handler with default settings")
830
  return OpenAIHandler(web_search_enabled=False)
831
 
832
  async def search_web(self, query: str) -> str:
 
834
  if not self.search_client or not self.web_search_enabled:
835
  return "웹 검색이 비활성화되어 있습니다."
836
 
837
+ logger.info(f"Searching web for: {query}")
838
  results = await self.search_client.search(query)
839
  if not results:
840
  return f"'{query}'에 대한 검색 결과를 찾을 수 없습니다."
 
850
 
851
  return f"웹 검색 결과 '{query}':\n\n" + "\n".join(formatted_results)
852
 
853
+ async def keep_alive(self):
854
+ """Keep the connection alive with periodic activity checks"""
855
+ while self.connection_active:
856
+ try:
857
+ await asyncio.sleep(30) # 30초마다 체크
858
+
859
+ # 마지막 활동으로부터 5분이 지났는지 확인
860
+ inactive_time = (datetime.now() - self.last_activity).total_seconds()
861
+ if inactive_time > 300: # 5분
862
+ logger.warning(f"Connection inactive for {inactive_time} seconds")
863
+
864
+ # 연결 상태 확인
865
+ if self.connection:
866
+ logger.debug("Connection alive - sending keepalive")
867
+ # OpenAI 연결은 자동으로 유지됨
868
+ else:
869
+ logger.error("Connection lost in keep_alive")
870
+ self.connection_active = False
871
+ break
872
+
873
+ except Exception as e:
874
+ logger.error(f"Keep-alive error: {e}")
875
+ self.connection_active = False
876
+ break
877
+
878
  async def start_up(self):
879
  """Connect to realtime API with function calling enabled"""
880
  # First check if we have the most recent settings
 
887
  settings = web_search_settings[recent_id]
888
  self.web_search_enabled = settings.get('enabled', False)
889
  self.webrtc_id = recent_id
890
+ logger.info(f"start_up: Updated settings from storage - webrtc_id={self.webrtc_id}, web_search_enabled={self.web_search_enabled}")
891
 
892
+ logger.info(f"Starting up handler with web_search_enabled={self.web_search_enabled}")
893
  self.client = openai.AsyncOpenAI()
894
+ self.connection_active = True
895
 
896
  # Define the web search function
897
  tools = []
 
915
  }
916
  }
917
  }]
918
+ logger.info("Web search function added to tools")
919
 
920
  instructions = (
921
  "You are a helpful assistant with web search capabilities. "
 
931
  "than to guess or use outdated information. Always respond in Korean when the user speaks Korean."
932
  )
933
 
934
+ try:
935
+ async with self.client.beta.realtime.connect(
936
+ model="gpt-4o-mini-realtime-preview-2024-12-17"
937
+ ) as conn:
938
+ # Update session with tools
939
+ session_update = {
940
+ "turn_detection": {"type": "server_vad"},
941
+ "instructions": instructions,
942
+ "tools": tools,
943
+ "tool_choice": "auto" if tools else "none"
944
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
945
 
946
+ await conn.session.update(session=session_update)
947
+ self.connection = conn
948
+ self.last_activity = datetime.now()
949
+ logger.info(f"Connected with tools: {len(tools)} functions")
 
 
950
 
951
+ # Start keep-alive task
952
+ self.keep_alive_task = asyncio.create_task(self.keep_alive())
 
953
 
954
+ async for event in self.connection:
955
+ self.last_activity = datetime.now()
956
+
957
+ # Debug logging for function calls
958
+ if event.type.startswith("response.function_call"):
959
+ logger.debug(f"Function event: {event.type}")
960
+
961
+ if event.type == "response.audio_transcript.done":
962
+ await self.output_queue.put(AdditionalOutputs(event))
963
+
964
+ elif event.type == "response.audio.delta":
965
+ await self.output_queue.put(
966
+ (
967
+ self.output_sample_rate,
968
+ np.frombuffer(
969
+ base64.b64decode(event.delta), dtype=np.int16
970
+ ).reshape(1, -1),
971
+ ),
972
+ )
973
+
974
+ # Handle function calls
975
+ elif event.type == "response.function_call_arguments.start":
976
+ logger.info(f"Function call started")
977
+ self.function_call_in_progress = True
978
+ self.current_function_args = ""
979
+ self.current_call_id = getattr(event, 'call_id', None)
980
+
981
+ elif event.type == "response.function_call_arguments.delta":
982
+ if self.function_call_in_progress:
983
+ self.current_function_args += event.delta
984
+
985
+ elif event.type == "response.function_call_arguments.done":
986
+ if self.function_call_in_progress:
987
+ logger.info(f"Function call done, args: {self.current_function_args}")
988
+ try:
989
+ args = json.loads(self.current_function_args)
990
+ query = args.get("query", "")
991
+
992
+ # Emit search event to client
993
+ await self.output_queue.put(AdditionalOutputs({
994
+ "type": "search",
995
+ "query": query
996
+ }))
997
+
998
+ # Perform the search
999
+ search_results = await self.search_web(query)
1000
+ logger.info(f"Search results length: {len(search_results)}")
1001
+
1002
+ # Send function result back to the model
1003
+ if self.connection and self.current_call_id:
1004
+ await self.connection.conversation.item.create(
1005
+ item={
1006
+ "type": "function_call_output",
1007
+ "call_id": self.current_call_id,
1008
+ "output": search_results
1009
+ }
1010
+ )
1011
+ await self.connection.response.create()
1012
 
1013
+ except Exception as e:
1014
+ logger.error(f"Function call error: {e}")
1015
+ finally:
1016
+ self.function_call_in_progress = False
1017
+ self.current_function_args = ""
1018
+ self.current_call_id = None
1019
+
1020
+ except Exception as e:
1021
+ logger.error(f"Connection error in start_up: {e}")
1022
+ self.connection_active = False
1023
+ # 연결 오류를 클라이언트에 알림
1024
+ await self.output_queue.put(AdditionalOutputs({
1025
+ "type": "connection_lost",
1026
+ "message": "서버 연결이 끊어졌습니다"
1027
+ }))
 
 
1028
 
1029
  async def receive(self, frame: tuple[int, np.ndarray]) -> None:
1030
+ if not self.connection or not self.connection_active:
1031
  return
1032
  try:
1033
+ self.last_activity = datetime.now()
1034
  _, array = frame
1035
  array = array.squeeze()
1036
  audio_message = base64.b64encode(array.tobytes()).decode("utf-8")
1037
  await self.connection.input_audio_buffer.append(audio=audio_message)
1038
  except Exception as e:
1039
+ logger.error(f"Error in receive: {e}")
1040
+ # 연결이 끊어진 경우 상태 업데이트
1041
+ if "closed" in str(e).lower() or "connection" in str(e).lower():
1042
+ self.connection = None
1043
+ self.connection_active = False
1044
+ # 클라이언트에 연결 종료 알림
1045
+ await self.output_queue.put(AdditionalOutputs({
1046
+ "type": "connection_lost",
1047
+ "message": "연결이 종료되었습니다"
1048
+ }))
1049
 
1050
  async def emit(self) -> tuple[int, np.ndarray] | AdditionalOutputs | None:
1051
  return await wait_for_item(self.output_queue)
1052
 
1053
  async def shutdown(self) -> None:
1054
+ logger.info("Shutting down handler")
1055
+ self.connection_active = False
1056
+
1057
+ if self.keep_alive_task:
1058
+ self.keep_alive_task.cancel()
1059
+ try:
1060
+ await self.keep_alive_task
1061
+ except asyncio.CancelledError:
1062
+ pass
1063
+
1064
  if self.connection:
1065
  await self.connection.close()
1066
  self.connection = None
 
1072
  # Create components
1073
  chatbot = gr.Chatbot(type="messages")
1074
 
1075
+ # Create stream with handler instance - 시간 제한 제거
1076
  stream = Stream(
1077
  handler, # Pass instance, not factory
1078
  mode="send-receive",
 
1082
  additional_outputs_handler=update_chatbot,
1083
  rtc_configuration=get_twilio_turn_credentials() if get_space() else None,
1084
  concurrency_limit=5 if get_space() else None,
1085
+ time_limit=None, # 시간 제한 제거
1086
  )
1087
 
1088
  app = FastAPI()
 
1099
  webrtc_id = body.get("webrtc_id")
1100
  web_search_enabled = body.get("web_search_enabled", False)
1101
 
1102
+ logger.info(f"Custom offer - webrtc_id: {webrtc_id}, web_search_enabled: {web_search_enabled}")
1103
 
1104
  # Store settings with timestamp
1105
  if webrtc_id:
 
1131
  async def output_stream():
1132
  async for output in stream.output_stream(webrtc_id):
1133
  if hasattr(output, 'args') and output.args:
1134
+ # Check if it's a search event or connection lost event
1135
+ if isinstance(output.args[0], dict):
1136
+ event_type = output.args[0].get('type')
1137
+ if event_type == 'search':
1138
+ yield f"event: search\ndata: {json.dumps(output.args[0])}\n\n"
1139
+ elif event_type == 'connection_lost':
1140
+ yield f"event: error\ndata: {json.dumps(output.args[0])}\n\n"
1141
  # Regular transcript event
1142
  elif hasattr(output.args[0], 'transcript'):
1143
  s = json.dumps({"role": "assistant", "content": output.args[0].transcript})