seawolf2357 committed on
Commit
6c6de8f
·
verified ·
1 Parent(s): 4d831f5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +110 -82
app.py CHANGED
@@ -670,14 +670,17 @@ class BraveSearchClient:
670
  return []
671
 
672
 
673
- # Global state for web search settings
 
 
 
 
 
674
  web_search_settings = {}
675
- current_webrtc_id = None
676
- current_web_search_enabled = False
677
 
678
 
679
  class OpenAIHandler(AsyncStreamHandler):
680
- def __init__(self) -> None:
681
  super().__init__(
682
  expected_layout="mono",
683
  output_sample_rate=SAMPLE_RATE,
@@ -686,21 +689,23 @@ class OpenAIHandler(AsyncStreamHandler):
686
  )
687
  self.connection = None
688
  self.output_queue = asyncio.Queue()
689
- self.search_client = search_client # Use global search client
690
  self.function_call_in_progress = False
691
  self.current_function_args = ""
692
  self.current_call_id = None
693
- self.webrtc_id = None
694
- self.web_search_enabled = False
695
- self.startup_called = False
696
 
697
  def copy(self):
698
- # Create a new instance that inherits settings
699
- new_handler = OpenAIHandler()
700
- new_handler.search_client = self.search_client
701
- new_handler.webrtc_id = current_webrtc_id # Use global variable
702
- new_handler.web_search_enabled = current_web_search_enabled # Use global variable
703
- return new_handler
 
 
704
 
705
  async def search_web(self, query: str) -> str:
706
  """Perform web search and return formatted results"""
@@ -725,28 +730,25 @@ class OpenAIHandler(AsyncStreamHandler):
725
 
726
  async def start_up(self):
727
  """Connect to realtime API with function calling enabled"""
728
- # Check web search setting before connecting
729
- if self.webrtc_id and self.webrtc_id in web_search_settings:
730
- self.web_search_enabled = web_search_settings[self.webrtc_id]
731
- print(f"Retrieved web_search_enabled={self.web_search_enabled} for webrtc_id={self.webrtc_id}")
732
-
733
  print(f"Starting up handler with web_search_enabled={self.web_search_enabled}")
734
  self.client = openai.AsyncOpenAI()
735
 
736
  # Define the web search function
737
  tools = []
 
 
738
  if self.web_search_enabled and self.search_client:
739
  tools = [{
740
  "type": "function",
741
  "function": {
742
  "name": "web_search",
743
- "description": "Search the web for information when the user asks questions about current events, news, weather, prices, or topics that require up-to-date information",
744
  "parameters": {
745
  "type": "object",
746
  "properties": {
747
  "query": {
748
  "type": "string",
749
- "description": "The search query"
750
  }
751
  },
752
  "required": ["query"]
@@ -754,6 +756,20 @@ class OpenAIHandler(AsyncStreamHandler):
754
  }
755
  }]
756
  print("Web search function added to tools")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
757
 
758
  async with self.client.beta.realtime.connect(
759
  model="gpt-4o-mini-realtime-preview-2024-12-17"
@@ -761,26 +777,20 @@ class OpenAIHandler(AsyncStreamHandler):
761
  # Update session with tools
762
  session_update = {
763
  "turn_detection": {"type": "server_vad"},
 
764
  "tools": tools,
765
  "tool_choice": "auto" if tools else "none"
766
  }
767
 
768
- # Add instructions to use web search
769
- if self.web_search_enabled:
770
- session_update["instructions"] = (
771
- "You are a helpful assistant. When users ask about current events, "
772
- "news, recent information, weather, or topics that might have changed recently, "
773
- "use the web_search function to find the most up-to-date information. "
774
- "Always search when the user asks about: 날씨(weather), 뉴스(news), 현재(current), "
775
- "최근(recent), 최신(latest), 오늘(today), or any time-sensitive information."
776
- )
777
-
778
  await conn.session.update(session=session_update)
779
  self.connection = conn
780
  print(f"Connected with tools: {len(tools)} functions")
781
- self.startup_called = True
782
 
783
  async for event in self.connection:
 
 
 
 
784
  if event.type == "response.audio_transcript.done":
785
  await self.output_queue.put(AdditionalOutputs(event))
786
 
@@ -843,10 +853,14 @@ class OpenAIHandler(AsyncStreamHandler):
843
  async def receive(self, frame: tuple[int, np.ndarray]) -> None:
844
  if not self.connection:
845
  return
846
- _, array = frame
847
- array = array.squeeze()
848
- audio_message = base64.b64encode(array.tobytes()).decode("utf-8")
849
- await self.connection.input_audio_buffer.append(audio=audio_message)
 
 
 
 
850
 
851
  async def emit(self) -> tuple[int, np.ndarray] | AdditionalOutputs | None:
852
  return await wait_for_item(self.output_queue)
@@ -857,24 +871,23 @@ class OpenAIHandler(AsyncStreamHandler):
857
  self.connection = None
858
 
859
 
860
- # Initialize search client
861
- brave_api_key = os.getenv("BSEARCH_API")
862
- search_client = BraveSearchClient(brave_api_key) if brave_api_key else None
863
- print(f"Search client initialized: {search_client is not None}, API key present: {bool(brave_api_key)}")
 
864
 
865
  def update_chatbot(chatbot: list[dict], response: ResponseAudioTranscriptDoneEvent):
866
  chatbot.append({"role": "assistant", "content": response.transcript})
867
  return chatbot
868
 
869
- # Create chatbot component
870
- chatbot = gr.Chatbot(type="messages")
871
 
872
- # Create base handler and stream
873
- handler = OpenAIHandler()
874
- handler.search_client = search_client
875
 
 
876
  stream = Stream(
877
- handler,
878
  mode="send-receive",
879
  modality="audio",
880
  additional_inputs=[chatbot],
@@ -887,45 +900,60 @@ stream = Stream(
887
 
888
  app = FastAPI()
889
 
890
- # Mount the stream
891
- stream.mount(app)
 
 
 
 
 
 
 
 
 
 
 
892
 
893
- # Intercept the POST body before it reaches the stream
894
- from starlette.middleware.base import BaseHTTPMiddleware
895
- from starlette.requests import Request as StarletteRequest
896
- import io
897
 
898
- class WebSearchMiddleware(BaseHTTPMiddleware):
899
- async def dispatch(self, request: StarletteRequest, call_next):
900
- if request.url.path == "/webrtc/offer" and request.method == "POST":
901
- # Read the body
902
- body = await request.body()
903
-
904
- # Parse it
905
- try:
906
- data = json.loads(body)
907
- webrtc_id = data.get("webrtc_id")
908
- web_search_enabled = data.get("web_search_enabled", False)
909
-
910
- print(f"Middleware intercepted - webrtc_id: {webrtc_id}, web_search_enabled: {web_search_enabled}")
911
-
912
- # Store settings globally
913
- global current_webrtc_id, current_web_search_enabled
914
- if webrtc_id:
915
- current_webrtc_id = webrtc_id
916
- current_web_search_enabled = web_search_enabled
917
- web_search_settings[webrtc_id] = web_search_enabled
918
- except:
919
- pass
920
-
921
- # Recreate the request with the body
922
- request._body = body
923
-
924
- response = await call_next(request)
925
- return response
 
 
 
 
 
 
 
 
926
 
927
- # Add middleware
928
- app.add_middleware(WebSearchMiddleware)
929
 
930
  @app.get("/outputs")
931
  async def outputs(webrtc_id: str):
 
670
  return []
671
 
672
 
673
+ # Initialize search client globally
674
+ brave_api_key = os.getenv("BSEARCH_API")
675
+ search_client = BraveSearchClient(brave_api_key) if brave_api_key else None
676
+ print(f"Search client initialized: {search_client is not None}, API key present: {bool(brave_api_key)}")
677
+
678
+ # Store web search settings by connection
679
  web_search_settings = {}
 
 
680
 
681
 
682
  class OpenAIHandler(AsyncStreamHandler):
683
+ def __init__(self, web_search_enabled: bool = False, webrtc_id: str = None) -> None:
684
  super().__init__(
685
  expected_layout="mono",
686
  output_sample_rate=SAMPLE_RATE,
 
689
  )
690
  self.connection = None
691
  self.output_queue = asyncio.Queue()
692
+ self.search_client = search_client
693
  self.function_call_in_progress = False
694
  self.current_function_args = ""
695
  self.current_call_id = None
696
+ self.webrtc_id = webrtc_id
697
+ self.web_search_enabled = web_search_enabled
698
+ print(f"Handler created with web_search_enabled={web_search_enabled}, webrtc_id={webrtc_id}")
699
 
700
  def copy(self):
701
+ # Check for stored settings when copying
702
+ if self.webrtc_id and self.webrtc_id in web_search_settings:
703
+ web_search_enabled = web_search_settings[self.webrtc_id]
704
+ else:
705
+ web_search_enabled = self.web_search_enabled
706
+
707
+ print(f"Handler.copy() called - creating new handler with web_search_enabled={web_search_enabled}")
708
+ return OpenAIHandler(web_search_enabled=web_search_enabled, webrtc_id=self.webrtc_id)
709
 
710
  async def search_web(self, query: str) -> str:
711
  """Perform web search and return formatted results"""
 
730
 
731
  async def start_up(self):
732
  """Connect to realtime API with function calling enabled"""
 
 
 
 
 
733
  print(f"Starting up handler with web_search_enabled={self.web_search_enabled}")
734
  self.client = openai.AsyncOpenAI()
735
 
736
  # Define the web search function
737
  tools = []
738
+ instructions = "You are a helpful assistant. Respond in Korean when the user speaks Korean."
739
+
740
  if self.web_search_enabled and self.search_client:
741
  tools = [{
742
  "type": "function",
743
  "function": {
744
  "name": "web_search",
745
+ "description": "Search the web for current information. Use this for weather, news, prices, current events, or any time-sensitive topics.",
746
  "parameters": {
747
  "type": "object",
748
  "properties": {
749
  "query": {
750
  "type": "string",
751
+ "description": "The search query in Korean or English"
752
  }
753
  },
754
  "required": ["query"]
 
756
  }
757
  }]
758
  print("Web search function added to tools")
759
+
760
+ instructions = (
761
+ "You are a helpful assistant with web search capabilities. "
762
+ "IMPORTANT: You MUST use the web_search function for ANY of these topics:\n"
763
+ "- Weather (날씨, 기온, 비, 눈)\n"
764
+ "- News (뉴스, 소식)\n"
765
+ "- Current events (현재, 최근, 오늘, 지금)\n"
766
+ "- Prices (가격, 환율, 주가)\n"
767
+ "- Sports scores or results\n"
768
+ "- Any question about 2024 or 2025\n"
769
+ "- Any time-sensitive information\n\n"
770
+ "When in doubt, USE web_search. It's better to search and provide accurate information "
771
+ "than to guess or use outdated information. Always respond in Korean when the user speaks Korean."
772
+ )
773
 
774
  async with self.client.beta.realtime.connect(
775
  model="gpt-4o-mini-realtime-preview-2024-12-17"
 
777
  # Update session with tools
778
  session_update = {
779
  "turn_detection": {"type": "server_vad"},
780
+ "instructions": instructions,
781
  "tools": tools,
782
  "tool_choice": "auto" if tools else "none"
783
  }
784
 
 
 
 
 
 
 
 
 
 
 
785
  await conn.session.update(session=session_update)
786
  self.connection = conn
787
  print(f"Connected with tools: {len(tools)} functions")
 
788
 
789
  async for event in self.connection:
790
+ # Debug logging for function calls
791
+ if event.type.startswith("response.function_call"):
792
+ print(f"Function event: {event.type}")
793
+
794
  if event.type == "response.audio_transcript.done":
795
  await self.output_queue.put(AdditionalOutputs(event))
796
 
 
853
  async def receive(self, frame: tuple[int, np.ndarray]) -> None:
854
  if not self.connection:
855
  return
856
+ try:
857
+ _, array = frame
858
+ array = array.squeeze()
859
+ audio_message = base64.b64encode(array.tobytes()).decode("utf-8")
860
+ await self.connection.input_audio_buffer.append(audio=audio_message)
861
+ except Exception as e:
862
+ print(f"Error in receive: {e}")
863
+ # Connection might be closed, ignore the error
864
 
865
  async def emit(self) -> tuple[int, np.ndarray] | AdditionalOutputs | None:
866
  return await wait_for_item(self.output_queue)
 
871
  self.connection = None
872
 
873
 
874
+ # Factory function to create handler with settings
875
+ def create_handler():
876
+ """Factory function to create handler - will be customized per connection"""
877
+ return OpenAIHandler(web_search_enabled=False)
878
+
879
 
880
  def update_chatbot(chatbot: list[dict], response: ResponseAudioTranscriptDoneEvent):
881
  chatbot.append({"role": "assistant", "content": response.transcript})
882
  return chatbot
883
 
 
 
884
 
885
+ # Create components
886
+ chatbot = gr.Chatbot(type="messages")
 
887
 
888
+ # Create initial stream with factory function
889
  stream = Stream(
890
+ create_handler, # Pass factory function instead of instance
891
  mode="send-receive",
892
  modality="audio",
893
  additional_inputs=[chatbot],
 
900
 
901
  app = FastAPI()
902
 
903
+ # Custom handler factory that reads settings
904
+ def custom_handler_factory():
905
+ """Custom factory that checks for web search settings"""
906
+ # Try to get the most recent settings
907
+ if web_search_settings:
908
+ # Get the most recent webrtc_id
909
+ recent_id = max(web_search_settings.keys(), key=lambda k: web_search_settings[k].get('timestamp', 0))
910
+ settings = web_search_settings[recent_id]
911
+ return OpenAIHandler(
912
+ web_search_enabled=settings.get('enabled', False),
913
+ webrtc_id=recent_id
914
+ )
915
+ return OpenAIHandler(web_search_enabled=False)
916
 
917
+ # Replace the handler factory
918
+ stream.handler = custom_handler_factory
 
 
919
 
920
+ # Mount stream
921
+ stream.mount(app)
922
+
923
+ # Intercept offer to capture settings
924
+ @app.post("/webrtc/offer", include_in_schema=False)
925
+ async def custom_offer(request: Request):
926
+ """Intercept offer to capture web search settings"""
927
+ body = await request.json()
928
+
929
+ webrtc_id = body.get("webrtc_id")
930
+ web_search_enabled = body.get("web_search_enabled", False)
931
+
932
+ print(f"Custom offer - webrtc_id: {webrtc_id}, web_search_enabled: {web_search_enabled}")
933
+
934
+ # Store settings with timestamp
935
+ if webrtc_id:
936
+ web_search_settings[webrtc_id] = {
937
+ 'enabled': web_search_enabled,
938
+ 'timestamp': asyncio.get_event_loop().time()
939
+ }
940
+
941
+ # Remove our custom route temporarily
942
+ custom_route = None
943
+ for i, route in enumerate(app.routes):
944
+ if hasattr(route, 'path') and route.path == "/webrtc/offer" and route.endpoint == custom_offer:
945
+ custom_route = app.routes.pop(i)
946
+ break
947
+
948
+ # Forward to stream's offer handler
949
+ response = await stream.offer(body)
950
+
951
+ # Re-add our custom route
952
+ if custom_route:
953
+ app.routes.insert(0, custom_route)
954
+
955
+ return response
956
 
 
 
957
 
958
  @app.get("/outputs")
959
  async def outputs(webrtc_id: str):