seawolf2357 commited on
Commit
27bbb1b
·
verified ·
1 Parent(s): a5470c9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -45
app.py CHANGED
@@ -6,7 +6,7 @@ import os
6
  import numpy as np
7
  import openai
8
  from dotenv import load_dotenv
9
- from fastapi import FastAPI
10
  from fastapi.responses import HTMLResponse, StreamingResponse
11
  from fastrtc import (
12
  AdditionalOutputs,
@@ -672,6 +672,8 @@ class BraveSearchClient:
672
 
673
  # Global state for web search settings
674
  web_search_settings = {}
 
 
675
 
676
 
677
  class OpenAIHandler(AsyncStreamHandler):
@@ -684,15 +686,21 @@ class OpenAIHandler(AsyncStreamHandler):
684
  )
685
  self.connection = None
686
  self.output_queue = asyncio.Queue()
687
- self.search_client = None
688
  self.function_call_in_progress = False
689
  self.current_function_args = ""
690
  self.current_call_id = None
691
  self.webrtc_id = None
692
  self.web_search_enabled = False
 
693
 
694
  def copy(self):
695
- return OpenAIHandler()
 
 
 
 
 
696
 
697
  async def search_web(self, query: str) -> str:
698
  """Perform web search and return formatted results"""
@@ -717,9 +725,10 @@ class OpenAIHandler(AsyncStreamHandler):
717
 
718
  async def start_up(self):
719
  """Connect to realtime API with function calling enabled"""
720
- # Get web search setting from global state
721
  if self.webrtc_id and self.webrtc_id in web_search_settings:
722
  self.web_search_enabled = web_search_settings[self.webrtc_id]
 
723
 
724
  print(f"Starting up handler with web_search_enabled={self.web_search_enabled}")
725
  self.client = openai.AsyncOpenAI()
@@ -731,7 +740,7 @@ class OpenAIHandler(AsyncStreamHandler):
731
  "type": "function",
732
  "function": {
733
  "name": "web_search",
734
- "description": "Search the web for information when the user asks questions about current events, news, or topics that require up-to-date information",
735
  "parameters": {
736
  "type": "object",
737
  "properties": {
@@ -760,15 +769,16 @@ class OpenAIHandler(AsyncStreamHandler):
760
  if self.web_search_enabled:
761
  session_update["instructions"] = (
762
  "You are a helpful assistant. When users ask about current events, "
763
- "news, recent information, or topics that might have changed recently, "
764
  "use the web_search function to find the most up-to-date information. "
765
- "Always search when the user asks about: weather, news, current events, "
766
- "recent updates, prices, or any time-sensitive information."
767
  )
768
 
769
  await conn.session.update(session=session_update)
770
  self.connection = conn
771
  print(f"Connected with tools: {len(tools)} functions")
 
772
 
773
  async for event in self.connection:
774
  if event.type == "response.audio_transcript.done":
@@ -876,47 +886,46 @@ stream = Stream(
876
  )
877
 
878
  app = FastAPI()
 
 
879
  stream.mount(app)
880
 
881
- # Override the offer handler to capture web search settings
882
- original_offer_handler = stream.offer
 
 
883
 
884
- async def custom_offer_handler(request):
885
- """Custom offer handler to process web search settings"""
886
- # Extract web search setting from request
887
- if hasattr(request, 'web_search_enabled'):
888
- web_search_enabled = request.web_search_enabled
889
- elif isinstance(request, dict) and 'web_search_enabled' in request:
890
- web_search_enabled = request['web_search_enabled']
891
- else:
892
- web_search_enabled = False
893
-
894
- # Extract webrtc_id
895
- if hasattr(request, 'webrtc_id'):
896
- webrtc_id = request.webrtc_id
897
- elif isinstance(request, dict) and 'webrtc_id' in request:
898
- webrtc_id = request['webrtc_id']
899
- else:
900
- webrtc_id = None
901
-
902
- print(f"Received offer - webrtc_id: {webrtc_id}, web_search_enabled: {web_search_enabled}")
903
-
904
- # Store web search setting
905
- if webrtc_id:
906
- web_search_settings[webrtc_id] = web_search_enabled
907
- handler.webrtc_id = webrtc_id
908
- handler.web_search_enabled = web_search_enabled
909
-
910
- # Restart handler if already connected to update settings
911
- if handler.connection:
912
- await handler.shutdown()
913
- await handler.start_up()
914
-
915
- # Call original offer handler
916
- return await original_offer_handler(request)
917
 
918
- # Replace the offer method
919
- stream.offer = custom_offer_handler
920
 
921
  @app.get("/outputs")
922
  async def outputs(webrtc_id: str):
 
6
  import numpy as np
7
  import openai
8
  from dotenv import load_dotenv
9
+ from fastapi import FastAPI, Request
10
  from fastapi.responses import HTMLResponse, StreamingResponse
11
  from fastrtc import (
12
  AdditionalOutputs,
 
672
 
673
  # Global state for web search settings
674
  web_search_settings = {}
675
+ current_webrtc_id = None
676
+ current_web_search_enabled = False
677
 
678
 
679
  class OpenAIHandler(AsyncStreamHandler):
 
686
  )
687
  self.connection = None
688
  self.output_queue = asyncio.Queue()
689
+ self.search_client = search_client # Use global search client
690
  self.function_call_in_progress = False
691
  self.current_function_args = ""
692
  self.current_call_id = None
693
  self.webrtc_id = None
694
  self.web_search_enabled = False
695
+ self.startup_called = False
696
 
697
  def copy(self):
698
+ # Create a new instance that inherits settings
699
+ new_handler = OpenAIHandler()
700
+ new_handler.search_client = self.search_client
701
+ new_handler.webrtc_id = current_webrtc_id # Use global variable
702
+ new_handler.web_search_enabled = current_web_search_enabled # Use global variable
703
+ return new_handler
704
 
705
  async def search_web(self, query: str) -> str:
706
  """Perform web search and return formatted results"""
 
725
 
726
  async def start_up(self):
727
  """Connect to realtime API with function calling enabled"""
728
+ # Check web search setting before connecting
729
  if self.webrtc_id and self.webrtc_id in web_search_settings:
730
  self.web_search_enabled = web_search_settings[self.webrtc_id]
731
+ print(f"Retrieved web_search_enabled={self.web_search_enabled} for webrtc_id={self.webrtc_id}")
732
 
733
  print(f"Starting up handler with web_search_enabled={self.web_search_enabled}")
734
  self.client = openai.AsyncOpenAI()
 
740
  "type": "function",
741
  "function": {
742
  "name": "web_search",
743
+ "description": "Search the web for information when the user asks questions about current events, news, weather, prices, or topics that require up-to-date information",
744
  "parameters": {
745
  "type": "object",
746
  "properties": {
 
769
  if self.web_search_enabled:
770
  session_update["instructions"] = (
771
  "You are a helpful assistant. When users ask about current events, "
772
+ "news, recent information, weather, or topics that might have changed recently, "
773
  "use the web_search function to find the most up-to-date information. "
774
+ "Always search when the user asks about: 날씨(weather), 뉴스(news), 현재(current), "
775
+ "최근(recent), 최신(latest), 오늘(today), or any time-sensitive information."
776
  )
777
 
778
  await conn.session.update(session=session_update)
779
  self.connection = conn
780
  print(f"Connected with tools: {len(tools)} functions")
781
+ self.startup_called = True
782
 
783
  async for event in self.connection:
784
  if event.type == "response.audio_transcript.done":
 
886
  )
887
 
888
  app = FastAPI()
889
+
890
+ # Mount the stream
891
  stream.mount(app)
892
 
893
+ # Intercept the POST body before it reaches the stream
894
+ from starlette.middleware.base import BaseHTTPMiddleware
895
+ from starlette.requests import Request as StarletteRequest
896
+ import io
897
 
898
+ class WebSearchMiddleware(BaseHTTPMiddleware):
899
+ async def dispatch(self, request: StarletteRequest, call_next):
900
+ if request.url.path == "/webrtc/offer" and request.method == "POST":
901
+ # Read the body
902
+ body = await request.body()
903
+
904
+ # Parse it
905
+ try:
906
+ data = json.loads(body)
907
+ webrtc_id = data.get("webrtc_id")
908
+ web_search_enabled = data.get("web_search_enabled", False)
909
+
910
+ print(f"Middleware intercepted - webrtc_id: {webrtc_id}, web_search_enabled: {web_search_enabled}")
911
+
912
+ # Store settings globally
913
+ global current_webrtc_id, current_web_search_enabled
914
+ if webrtc_id:
915
+ current_webrtc_id = webrtc_id
916
+ current_web_search_enabled = web_search_enabled
917
+ web_search_settings[webrtc_id] = web_search_enabled
918
+ except:
919
+ pass
920
+
921
+ # Recreate the request with the body
922
+ request._body = body
923
+
924
+ response = await call_next(request)
925
+ return response
 
 
 
 
 
926
 
927
+ # Add middleware
928
+ app.add_middleware(WebSearchMiddleware)
929
 
930
  @app.get("/outputs")
931
  async def outputs(webrtc_id: str):