seawolf2357 commited on
Commit
e4a5e76
ยท
verified ยท
1 Parent(s): a3f3e10

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +98 -74
app.py CHANGED
@@ -397,18 +397,7 @@ HTML_CONTENT = """<!DOCTYPE html>
397
  searchToggle.addEventListener('click', () => {
398
  webSearchEnabled = !webSearchEnabled;
399
  searchToggle.classList.toggle('active', webSearchEnabled);
400
-
401
- // Update server-side settings if connected
402
- if (webrtc_id) {
403
- fetch('/update-search-setting', {
404
- method: 'POST',
405
- headers: { 'Content-Type': 'application/json' },
406
- body: JSON.stringify({
407
- webrtc_id: webrtc_id,
408
- web_search_enabled: webSearchEnabled
409
- })
410
- });
411
- }
412
  });
413
 
414
  function updateStatus(state) {
@@ -577,7 +566,7 @@ HTML_CONTENT = """<!DOCTYPE html>
577
  });
578
  eventSource.addEventListener("search", (event) => {
579
  const eventJson = JSON.parse(event.data);
580
- if (eventJson.results) {
581
  addMessage("search-result", `์›น ๊ฒ€์ƒ‰ ์ค‘: "${eventJson.query}"`);
582
  }
583
  });
@@ -682,7 +671,7 @@ class BraveSearchClient:
682
 
683
 
684
  class OpenAIHandler(AsyncStreamHandler):
685
- def __init__(self) -> None:
686
  super().__init__(
687
  expected_layout="mono",
688
  output_sample_rate=SAMPLE_RATE,
@@ -691,20 +680,23 @@ class OpenAIHandler(AsyncStreamHandler):
691
  )
692
  self.connection = None
693
  self.output_queue = asyncio.Queue()
694
- self.web_search_enabled = False
695
- self.search_client = None
696
  self.function_call_in_progress = False
697
  self.current_function_args = ""
698
- self.call_id = None
 
699
 
700
  def copy(self):
701
- return OpenAIHandler()
 
702
 
703
  async def search_web(self, query: str) -> str:
704
  """Perform web search and return formatted results"""
705
  if not self.search_client or not self.web_search_enabled:
706
  return "์›น ๊ฒ€์ƒ‰์ด ๋น„ํ™œ์„ฑํ™”๋˜์–ด ์žˆ์Šต๋‹ˆ๋‹ค."
707
 
 
708
  results = await self.search_client.search(query)
709
  if not results:
710
  return f"'{query}'์— ๋Œ€ํ•œ ๊ฒ€์ƒ‰ ๊ฒฐ๊ณผ๋ฅผ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค."
@@ -722,6 +714,7 @@ class OpenAIHandler(AsyncStreamHandler):
722
 
723
  async def start_up(self):
724
  """Connect to realtime API with function calling enabled"""
 
725
  self.client = openai.AsyncOpenAI()
726
 
727
  # Define the web search function
@@ -744,6 +737,7 @@ class OpenAIHandler(AsyncStreamHandler):
744
  }
745
  }
746
  }]
 
747
 
748
  async with self.client.beta.realtime.connect(
749
  model="gpt-4o-mini-realtime-preview-2024-12-17"
@@ -757,6 +751,7 @@ class OpenAIHandler(AsyncStreamHandler):
757
 
758
  await conn.session.update(session=session_update)
759
  self.connection = conn
 
760
 
761
  async for event in self.connection:
762
  if event.type == "response.audio_transcript.done":
@@ -774,9 +769,10 @@ class OpenAIHandler(AsyncStreamHandler):
774
 
775
  # Handle function calls
776
  elif event.type == "response.function_call_arguments.start":
 
777
  self.function_call_in_progress = True
778
  self.current_function_args = ""
779
- self.call_id = event.call_id if hasattr(event, 'call_id') else None
780
 
781
  elif event.type == "response.function_call_arguments.delta":
782
  if self.function_call_in_progress:
@@ -784,6 +780,7 @@ class OpenAIHandler(AsyncStreamHandler):
784
 
785
  elif event.type == "response.function_call_arguments.done":
786
  if self.function_call_in_progress:
 
787
  try:
788
  args = json.loads(self.current_function_args)
789
  query = args.get("query", "")
@@ -791,19 +788,19 @@ class OpenAIHandler(AsyncStreamHandler):
791
  # Emit search event to client
792
  await self.output_queue.put(AdditionalOutputs({
793
  "type": "search",
794
- "query": query,
795
- "results": True
796
  }))
797
 
798
  # Perform the search
799
  search_results = await self.search_web(query)
 
800
 
801
  # Send function result back to the model
802
- if self.connection and self.call_id:
803
  await self.connection.conversation.item.create(
804
  item={
805
  "type": "function_call_output",
806
- "call_id": self.call_id,
807
  "output": search_results
808
  }
809
  )
@@ -814,7 +811,7 @@ class OpenAIHandler(AsyncStreamHandler):
814
  finally:
815
  self.function_call_in_progress = False
816
  self.current_function_args = ""
817
- self.call_id = None
818
 
819
  async def receive(self, frame: tuple[int, np.ndarray]) -> None:
820
  if not self.connection:
@@ -833,80 +830,77 @@ class OpenAIHandler(AsyncStreamHandler):
833
  self.connection = None
834
 
835
 
836
- # Store connection settings
837
- connection_settings = {}
838
 
839
  # Initialize search client
840
  brave_api_key = os.getenv("BSEARCH_API")
841
  search_client = BraveSearchClient(brave_api_key) if brave_api_key else None
 
842
 
843
  def update_chatbot(chatbot: list[dict], response: ResponseAudioTranscriptDoneEvent):
844
  chatbot.append({"role": "assistant", "content": response.transcript})
845
  return chatbot
846
 
847
- # Create chatbot component
848
- chatbot = gr.Chatbot(type="messages")
849
-
850
- # Create handler
851
- handler = OpenAIHandler()
852
- handler.search_client = search_client
853
-
854
- # Create stream
855
- stream = Stream(
856
- handler,
857
- mode="send-receive",
858
- modality="audio",
859
- additional_inputs=[chatbot],
860
- additional_outputs=[chatbot],
861
- additional_outputs_handler=update_chatbot,
862
- rtc_configuration=get_twilio_turn_credentials() if get_space() else None,
863
- concurrency_limit=5 if get_space() else None,
864
- time_limit=90 if get_space() else None,
865
- )
866
 
 
867
  app = FastAPI()
868
- stream.mount(app)
869
 
870
- @app.post("/update-search-setting")
871
- async def update_search_setting(request: dict):
872
- """Update web search setting for a connection"""
873
  webrtc_id = request.get("webrtc_id")
874
  web_search_enabled = request.get("web_search_enabled", False)
875
 
876
- if webrtc_id:
877
- connection_settings[webrtc_id] = {"web_search_enabled": web_search_enabled}
878
- # Update handler setting
879
- handler.web_search_enabled = web_search_enabled
880
-
881
- # Reconnect with updated settings if needed
882
- if handler.connection:
883
- await handler.shutdown()
884
- await handler.start_up()
885
 
886
- return {"status": "updated"}
887
-
888
- # Override the stream's offer handler
889
- original_offer = stream.offer
890
-
891
- async def custom_offer(request: dict):
892
- """Handle WebRTC offer with web search preference"""
893
- webrtc_id = request.get("webrtc_id")
894
- web_search_enabled = request.get("web_search_enabled", False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
895
 
896
- if webrtc_id:
897
- connection_settings[webrtc_id] = {"web_search_enabled": web_search_enabled}
898
- handler.web_search_enabled = web_search_enabled
899
 
900
- return await original_offer(request)
 
 
 
 
 
 
901
 
902
- # Replace the offer method
903
- stream.offer = custom_offer
904
 
905
  @app.get("/outputs")
906
  async def outputs(webrtc_id: str):
907
  """Stream outputs including search events"""
908
  async def output_stream():
909
- async for output in stream.output_stream(webrtc_id):
 
 
 
 
910
  if hasattr(output, 'args') and output.args:
911
  # Check if it's a search event
912
  if isinstance(output.args[0], dict) and output.args[0].get('type') == 'search':
@@ -932,8 +926,38 @@ if __name__ == "__main__":
932
 
933
  mode = os.getenv("MODE")
934
  if mode == "UI":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
935
  stream.ui.launch(server_port=7860)
936
  elif mode == "PHONE":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
937
  stream.fastphone(host="0.0.0.0", port=7860)
938
  else:
939
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
397
  searchToggle.addEventListener('click', () => {
398
  webSearchEnabled = !webSearchEnabled;
399
  searchToggle.classList.toggle('active', webSearchEnabled);
400
+ console.log('Web search enabled:', webSearchEnabled);
 
 
 
 
 
 
 
 
 
 
 
401
  });
402
 
403
  function updateStatus(state) {
 
566
  });
567
  eventSource.addEventListener("search", (event) => {
568
  const eventJson = JSON.parse(event.data);
569
+ if (eventJson.query) {
570
  addMessage("search-result", `์›น ๊ฒ€์ƒ‰ ์ค‘: "${eventJson.query}"`);
571
  }
572
  });
 
671
 
672
 
673
  class OpenAIHandler(AsyncStreamHandler):
674
+ def __init__(self, web_search_enabled: bool = False, search_client: Optional[BraveSearchClient] = None, webrtc_id: str = None) -> None:
675
  super().__init__(
676
  expected_layout="mono",
677
  output_sample_rate=SAMPLE_RATE,
 
680
  )
681
  self.connection = None
682
  self.output_queue = asyncio.Queue()
683
+ self.web_search_enabled = web_search_enabled
684
+ self.search_client = search_client
685
  self.function_call_in_progress = False
686
  self.current_function_args = ""
687
+ self.current_call_id = None
688
+ self.webrtc_id = webrtc_id
689
 
690
  def copy(self):
691
+ # Return a new instance with the same settings
692
+ return OpenAIHandler(self.web_search_enabled, self.search_client, self.webrtc_id)
693
 
694
  async def search_web(self, query: str) -> str:
695
  """Perform web search and return formatted results"""
696
  if not self.search_client or not self.web_search_enabled:
697
  return "์›น ๊ฒ€์ƒ‰์ด ๋น„ํ™œ์„ฑํ™”๋˜์–ด ์žˆ์Šต๋‹ˆ๋‹ค."
698
 
699
+ print(f"Searching web for: {query}")
700
  results = await self.search_client.search(query)
701
  if not results:
702
  return f"'{query}'์— ๋Œ€ํ•œ ๊ฒ€์ƒ‰ ๊ฒฐ๊ณผ๋ฅผ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค."
 
714
 
715
  async def start_up(self):
716
  """Connect to realtime API with function calling enabled"""
717
+ print(f"Starting up handler with web_search_enabled={self.web_search_enabled}")
718
  self.client = openai.AsyncOpenAI()
719
 
720
  # Define the web search function
 
737
  }
738
  }
739
  }]
740
+ print("Web search function added to tools")
741
 
742
  async with self.client.beta.realtime.connect(
743
  model="gpt-4o-mini-realtime-preview-2024-12-17"
 
751
 
752
  await conn.session.update(session=session_update)
753
  self.connection = conn
754
+ print(f"Connected with tools: {len(tools)} functions")
755
 
756
  async for event in self.connection:
757
  if event.type == "response.audio_transcript.done":
 
769
 
770
  # Handle function calls
771
  elif event.type == "response.function_call_arguments.start":
772
+ print(f"Function call started, call_id: {getattr(event, 'call_id', 'unknown')}")
773
  self.function_call_in_progress = True
774
  self.current_function_args = ""
775
+ self.current_call_id = getattr(event, 'call_id', None)
776
 
777
  elif event.type == "response.function_call_arguments.delta":
778
  if self.function_call_in_progress:
 
780
 
781
  elif event.type == "response.function_call_arguments.done":
782
  if self.function_call_in_progress:
783
+ print(f"Function call done, args: {self.current_function_args}")
784
  try:
785
  args = json.loads(self.current_function_args)
786
  query = args.get("query", "")
 
788
  # Emit search event to client
789
  await self.output_queue.put(AdditionalOutputs({
790
  "type": "search",
791
+ "query": query
 
792
  }))
793
 
794
  # Perform the search
795
  search_results = await self.search_web(query)
796
+ print(f"Search results length: {len(search_results)}")
797
 
798
  # Send function result back to the model
799
+ if self.connection and self.current_call_id:
800
  await self.connection.conversation.item.create(
801
  item={
802
  "type": "function_call_output",
803
+ "call_id": self.current_call_id,
804
  "output": search_results
805
  }
806
  )
 
811
  finally:
812
  self.function_call_in_progress = False
813
  self.current_function_args = ""
814
+ self.current_call_id = None
815
 
816
  async def receive(self, frame: tuple[int, np.ndarray]) -> None:
817
  if not self.connection:
 
830
  self.connection = None
831
 
832
 
833
+ # Store handlers by webrtc_id
834
+ handlers = {}
835
 
836
  # Initialize search client
837
  brave_api_key = os.getenv("BSEARCH_API")
838
  search_client = BraveSearchClient(brave_api_key) if brave_api_key else None
839
+ print(f"Search client initialized: {search_client is not None}, API key present: {bool(brave_api_key)}")
840
 
841
  def update_chatbot(chatbot: list[dict], response: ResponseAudioTranscriptDoneEvent):
842
  chatbot.append({"role": "assistant", "content": response.transcript})
843
  return chatbot
844
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
845
 
846
+ # Create base stream without handler for mounting
847
  app = FastAPI()
 
848
 
849
+ @app.post("/webrtc/offer")
850
+ async def handle_offer(request: dict):
851
+ """Handle WebRTC offer and create a new handler for this connection"""
852
  webrtc_id = request.get("webrtc_id")
853
  web_search_enabled = request.get("web_search_enabled", False)
854
 
855
+ print(f"Received offer - webrtc_id: {webrtc_id}, web_search_enabled: {web_search_enabled}")
 
 
 
 
 
 
 
 
856
 
857
+ # Create a new handler for this connection
858
+ handler = OpenAIHandler(
859
+ web_search_enabled=web_search_enabled,
860
+ search_client=search_client,
861
+ webrtc_id=webrtc_id
862
+ )
863
+
864
+ # Store the handler
865
+ handlers[webrtc_id] = handler
866
+
867
+ # Create chatbot instance for this connection
868
+ chatbot = gr.Chatbot(type="messages")
869
+
870
+ # Create stream for this connection
871
+ stream = Stream(
872
+ handler,
873
+ mode="send-receive",
874
+ modality="audio",
875
+ additional_inputs=[chatbot],
876
+ additional_outputs=[chatbot],
877
+ additional_outputs_handler=update_chatbot,
878
+ rtc_configuration=get_twilio_turn_credentials() if get_space() else None,
879
+ concurrency_limit=5 if get_space() else None,
880
+ time_limit=90 if get_space() else None,
881
+ )
882
 
883
+ # Mount stream temporarily
884
+ stream.mount(app)
 
885
 
886
+ # Store stream reference
887
+ handler.stream = stream
888
+
889
+ # Process the offer
890
+ response = await stream.offer(request)
891
+
892
+ return response
893
 
 
 
894
 
895
  @app.get("/outputs")
896
  async def outputs(webrtc_id: str):
897
  """Stream outputs including search events"""
898
  async def output_stream():
899
+ handler = handlers.get(webrtc_id)
900
+ if not handler or not hasattr(handler, 'stream'):
901
+ return
902
+
903
+ async for output in handler.stream.output_stream(webrtc_id):
904
  if hasattr(output, 'args') and output.args:
905
  # Check if it's a search event
906
  if isinstance(output.args[0], dict) and output.args[0].get('type') == 'search':
 
926
 
927
  mode = os.getenv("MODE")
928
  if mode == "UI":
929
+ # For UI mode, we need a base handler and stream
930
+ base_handler = OpenAIHandler()
931
+ chatbot = gr.Chatbot(type="messages")
932
+ stream = Stream(
933
+ base_handler,
934
+ mode="send-receive",
935
+ modality="audio",
936
+ additional_inputs=[chatbot],
937
+ additional_outputs=[chatbot],
938
+ additional_outputs_handler=update_chatbot,
939
+ rtc_configuration=get_twilio_turn_credentials() if get_space() else None,
940
+ concurrency_limit=5 if get_space() else None,
941
+ time_limit=90 if get_space() else None,
942
+ )
943
+ stream.mount(app)
944
  stream.ui.launch(server_port=7860)
945
  elif mode == "PHONE":
946
+ # Similar for phone mode
947
+ base_handler = OpenAIHandler()
948
+ chatbot = gr.Chatbot(type="messages")
949
+ stream = Stream(
950
+ base_handler,
951
+ mode="send-receive",
952
+ modality="audio",
953
+ additional_inputs=[chatbot],
954
+ additional_outputs=[chatbot],
955
+ additional_outputs_handler=update_chatbot,
956
+ rtc_configuration=get_twilio_turn_credentials() if get_space() else None,
957
+ concurrency_limit=5 if get_space() else None,
958
+ time_limit=90 if get_space() else None,
959
+ )
960
+ stream.mount(app)
961
  stream.fastphone(host="0.0.0.0", port=7860)
962
  else:
963
  uvicorn.run(app, host="0.0.0.0", port=7860)