seawolf2357 committed on
Commit
a5470c9
·
verified ·
1 Parent(s): e4a5e76

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +79 -87
app.py CHANGED
@@ -670,8 +670,12 @@ class BraveSearchClient:
670
  return []
671
 
672
 
 
 
 
 
673
  class OpenAIHandler(AsyncStreamHandler):
674
- def __init__(self, web_search_enabled: bool = False, search_client: Optional[BraveSearchClient] = None, webrtc_id: str = None) -> None:
675
  super().__init__(
676
  expected_layout="mono",
677
  output_sample_rate=SAMPLE_RATE,
@@ -680,16 +684,15 @@ class OpenAIHandler(AsyncStreamHandler):
680
  )
681
  self.connection = None
682
  self.output_queue = asyncio.Queue()
683
- self.web_search_enabled = web_search_enabled
684
- self.search_client = search_client
685
  self.function_call_in_progress = False
686
  self.current_function_args = ""
687
  self.current_call_id = None
688
- self.webrtc_id = webrtc_id
 
689
 
690
  def copy(self):
691
- # Return a new instance with the same settings
692
- return OpenAIHandler(self.web_search_enabled, self.search_client, self.webrtc_id)
693
 
694
  async def search_web(self, query: str) -> str:
695
  """Perform web search and return formatted results"""
@@ -714,6 +717,10 @@ class OpenAIHandler(AsyncStreamHandler):
714
 
715
  async def start_up(self):
716
  """Connect to realtime API with function calling enabled"""
 
 
 
 
717
  print(f"Starting up handler with web_search_enabled={self.web_search_enabled}")
718
  self.client = openai.AsyncOpenAI()
719
 
@@ -724,7 +731,7 @@ class OpenAIHandler(AsyncStreamHandler):
724
  "type": "function",
725
  "function": {
726
  "name": "web_search",
727
- "description": "Search the web for information",
728
  "parameters": {
729
  "type": "object",
730
  "properties": {
@@ -749,6 +756,16 @@ class OpenAIHandler(AsyncStreamHandler):
749
  "tool_choice": "auto" if tools else "none"
750
  }
751
 
 
 
 
 
 
 
 
 
 
 
752
  await conn.session.update(session=session_update)
753
  self.connection = conn
754
  print(f"Connected with tools: {len(tools)} functions")
@@ -769,7 +786,7 @@ class OpenAIHandler(AsyncStreamHandler):
769
 
770
  # Handle function calls
771
  elif event.type == "response.function_call_arguments.start":
772
- print(f"Function call started, call_id: {getattr(event, 'call_id', 'unknown')}")
773
  self.function_call_in_progress = True
774
  self.current_function_args = ""
775
  self.current_call_id = getattr(event, 'call_id', None)
@@ -830,9 +847,6 @@ class OpenAIHandler(AsyncStreamHandler):
830
  self.connection = None
831
 
832
 
833
- # Store handlers by webrtc_id
834
- handlers = {}
835
-
836
  # Initialize search client
837
  brave_api_key = os.getenv("BSEARCH_API")
838
  search_client = BraveSearchClient(brave_api_key) if brave_api_key else None
@@ -842,65 +856,73 @@ def update_chatbot(chatbot: list[dict], response: ResponseAudioTranscriptDoneEve
842
  chatbot.append({"role": "assistant", "content": response.transcript})
843
  return chatbot
844
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
845
 
846
- # Create base stream without handler for mounting
847
  app = FastAPI()
 
848
 
849
- @app.post("/webrtc/offer")
850
- async def handle_offer(request: dict):
851
- """Handle WebRTC offer and create a new handler for this connection"""
852
- webrtc_id = request.get("webrtc_id")
853
- web_search_enabled = request.get("web_search_enabled", False)
854
-
855
- print(f"Received offer - webrtc_id: {webrtc_id}, web_search_enabled: {web_search_enabled}")
856
-
857
- # Create a new handler for this connection
858
- handler = OpenAIHandler(
859
- web_search_enabled=web_search_enabled,
860
- search_client=search_client,
861
- webrtc_id=webrtc_id
862
- )
863
-
864
- # Store the handler
865
- handlers[webrtc_id] = handler
866
-
867
- # Create chatbot instance for this connection
868
- chatbot = gr.Chatbot(type="messages")
869
-
870
- # Create stream for this connection
871
- stream = Stream(
872
- handler,
873
- mode="send-receive",
874
- modality="audio",
875
- additional_inputs=[chatbot],
876
- additional_outputs=[chatbot],
877
- additional_outputs_handler=update_chatbot,
878
- rtc_configuration=get_twilio_turn_credentials() if get_space() else None,
879
- concurrency_limit=5 if get_space() else None,
880
- time_limit=90 if get_space() else None,
881
- )
882
 
883
- # Mount stream temporarily
884
- stream.mount(app)
 
 
 
 
 
885
 
886
- # Store stream reference
887
- handler.stream = stream
888
 
889
- # Process the offer
890
- response = await stream.offer(request)
 
 
 
 
 
 
 
 
891
 
892
- return response
 
893
 
 
 
894
 
895
  @app.get("/outputs")
896
  async def outputs(webrtc_id: str):
897
  """Stream outputs including search events"""
898
  async def output_stream():
899
- handler = handlers.get(webrtc_id)
900
- if not handler or not hasattr(handler, 'stream'):
901
- return
902
-
903
- async for output in handler.stream.output_stream(webrtc_id):
904
  if hasattr(output, 'args') and output.args:
905
  # Check if it's a search event
906
  if isinstance(output.args[0], dict) and output.args[0].get('type') == 'search':
@@ -926,38 +948,8 @@ if __name__ == "__main__":
926
 
927
  mode = os.getenv("MODE")
928
  if mode == "UI":
929
- # For UI mode, we need a base handler and stream
930
- base_handler = OpenAIHandler()
931
- chatbot = gr.Chatbot(type="messages")
932
- stream = Stream(
933
- base_handler,
934
- mode="send-receive",
935
- modality="audio",
936
- additional_inputs=[chatbot],
937
- additional_outputs=[chatbot],
938
- additional_outputs_handler=update_chatbot,
939
- rtc_configuration=get_twilio_turn_credentials() if get_space() else None,
940
- concurrency_limit=5 if get_space() else None,
941
- time_limit=90 if get_space() else None,
942
- )
943
- stream.mount(app)
944
  stream.ui.launch(server_port=7860)
945
  elif mode == "PHONE":
946
- # Similar for phone mode
947
- base_handler = OpenAIHandler()
948
- chatbot = gr.Chatbot(type="messages")
949
- stream = Stream(
950
- base_handler,
951
- mode="send-receive",
952
- modality="audio",
953
- additional_inputs=[chatbot],
954
- additional_outputs=[chatbot],
955
- additional_outputs_handler=update_chatbot,
956
- rtc_configuration=get_twilio_turn_credentials() if get_space() else None,
957
- concurrency_limit=5 if get_space() else None,
958
- time_limit=90 if get_space() else None,
959
- )
960
- stream.mount(app)
961
  stream.fastphone(host="0.0.0.0", port=7860)
962
  else:
963
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
670
  return []
671
 
672
 
673
+ # Global state for web search settings
674
+ web_search_settings = {}
675
+
676
+
677
  class OpenAIHandler(AsyncStreamHandler):
678
+ def __init__(self) -> None:
679
  super().__init__(
680
  expected_layout="mono",
681
  output_sample_rate=SAMPLE_RATE,
 
684
  )
685
  self.connection = None
686
  self.output_queue = asyncio.Queue()
687
+ self.search_client = None
 
688
  self.function_call_in_progress = False
689
  self.current_function_args = ""
690
  self.current_call_id = None
691
+ self.webrtc_id = None
692
+ self.web_search_enabled = False
693
 
694
  def copy(self):
695
+ return OpenAIHandler()
 
696
 
697
  async def search_web(self, query: str) -> str:
698
  """Perform web search and return formatted results"""
 
717
 
718
  async def start_up(self):
719
  """Connect to realtime API with function calling enabled"""
720
+ # Get web search setting from global state
721
+ if self.webrtc_id and self.webrtc_id in web_search_settings:
722
+ self.web_search_enabled = web_search_settings[self.webrtc_id]
723
+
724
  print(f"Starting up handler with web_search_enabled={self.web_search_enabled}")
725
  self.client = openai.AsyncOpenAI()
726
 
 
731
  "type": "function",
732
  "function": {
733
  "name": "web_search",
734
+ "description": "Search the web for information when the user asks questions about current events, news, or topics that require up-to-date information",
735
  "parameters": {
736
  "type": "object",
737
  "properties": {
 
756
  "tool_choice": "auto" if tools else "none"
757
  }
758
 
759
+ # Add instructions to use web search
760
+ if self.web_search_enabled:
761
+ session_update["instructions"] = (
762
+ "You are a helpful assistant. When users ask about current events, "
763
+ "news, recent information, or topics that might have changed recently, "
764
+ "use the web_search function to find the most up-to-date information. "
765
+ "Always search when the user asks about: weather, news, current events, "
766
+ "recent updates, prices, or any time-sensitive information."
767
+ )
768
+
769
  await conn.session.update(session=session_update)
770
  self.connection = conn
771
  print(f"Connected with tools: {len(tools)} functions")
 
786
 
787
  # Handle function calls
788
  elif event.type == "response.function_call_arguments.start":
789
+ print(f"Function call started")
790
  self.function_call_in_progress = True
791
  self.current_function_args = ""
792
  self.current_call_id = getattr(event, 'call_id', None)
 
847
  self.connection = None
848
 
849
 
 
 
 
850
  # Initialize search client
851
  brave_api_key = os.getenv("BSEARCH_API")
852
  search_client = BraveSearchClient(brave_api_key) if brave_api_key else None
 
856
  chatbot.append({"role": "assistant", "content": response.transcript})
857
  return chatbot
858
 
859
+ # Create chatbot component
860
+ chatbot = gr.Chatbot(type="messages")
861
+
862
+ # Create base handler and stream
863
+ handler = OpenAIHandler()
864
+ handler.search_client = search_client
865
+
866
+ stream = Stream(
867
+ handler,
868
+ mode="send-receive",
869
+ modality="audio",
870
+ additional_inputs=[chatbot],
871
+ additional_outputs=[chatbot],
872
+ additional_outputs_handler=update_chatbot,
873
+ rtc_configuration=get_twilio_turn_credentials() if get_space() else None,
874
+ concurrency_limit=5 if get_space() else None,
875
+ time_limit=90 if get_space() else None,
876
+ )
877
 
 
878
  app = FastAPI()
879
+ stream.mount(app)
880
 
881
+ # Override the offer handler to capture web search settings
882
+ original_offer_handler = stream.offer
883
+
884
+ async def custom_offer_handler(request):
885
+ """Custom offer handler to process web search settings"""
886
+ # Extract web search setting from request
887
+ if hasattr(request, 'web_search_enabled'):
888
+ web_search_enabled = request.web_search_enabled
889
+ elif isinstance(request, dict) and 'web_search_enabled' in request:
890
+ web_search_enabled = request['web_search_enabled']
891
+ else:
892
+ web_search_enabled = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
893
 
894
+ # Extract webrtc_id
895
+ if hasattr(request, 'webrtc_id'):
896
+ webrtc_id = request.webrtc_id
897
+ elif isinstance(request, dict) and 'webrtc_id' in request:
898
+ webrtc_id = request['webrtc_id']
899
+ else:
900
+ webrtc_id = None
901
 
902
+ print(f"Received offer - webrtc_id: {webrtc_id}, web_search_enabled: {web_search_enabled}")
 
903
 
904
+ # Store web search setting
905
+ if webrtc_id:
906
+ web_search_settings[webrtc_id] = web_search_enabled
907
+ handler.webrtc_id = webrtc_id
908
+ handler.web_search_enabled = web_search_enabled
909
+
910
+ # Restart handler if already connected to update settings
911
+ if handler.connection:
912
+ await handler.shutdown()
913
+ await handler.start_up()
914
 
915
+ # Call original offer handler
916
+ return await original_offer_handler(request)
917
 
918
+ # Replace the offer method
919
+ stream.offer = custom_offer_handler
920
 
921
  @app.get("/outputs")
922
  async def outputs(webrtc_id: str):
923
  """Stream outputs including search events"""
924
  async def output_stream():
925
+ async for output in stream.output_stream(webrtc_id):
 
 
 
 
926
  if hasattr(output, 'args') and output.args:
927
  # Check if it's a search event
928
  if isinstance(output.args[0], dict) and output.args[0].get('type') == 'search':
 
948
 
949
  mode = os.getenv("MODE")
950
  if mode == "UI":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
951
  stream.ui.launch(server_port=7860)
952
  elif mode == "PHONE":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
953
  stream.fastphone(host="0.0.0.0", port=7860)
954
  else:
955
  uvicorn.run(app, host="0.0.0.0", port=7860)