seawolf2357 committed on
Commit
a3f3e10
·
verified ·
1 Parent(s): c4fe35d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +83 -43
app.py CHANGED
@@ -19,6 +19,7 @@ from gradio.utils import get_space
19
  from openai.types.beta.realtime import ResponseAudioTranscriptDoneEvent
20
  import httpx
21
  from typing import Optional, List, Dict
 
22
 
23
  load_dotenv()
24
 
@@ -396,6 +397,18 @@ HTML_CONTENT = """<!DOCTYPE html>
396
  searchToggle.addEventListener('click', () => {
397
  webSearchEnabled = !webSearchEnabled;
398
  searchToggle.classList.toggle('active', webSearchEnabled);
 
 
 
 
 
 
 
 
 
 
 
 
399
  });
400
 
401
  function updateStatus(state) {
@@ -669,7 +682,7 @@ class BraveSearchClient:
669
 
670
 
671
  class OpenAIHandler(AsyncStreamHandler):
672
- def __init__(self, web_search_enabled: bool = False, search_client: Optional[BraveSearchClient] = None) -> None:
673
  super().__init__(
674
  expected_layout="mono",
675
  output_sample_rate=SAMPLE_RATE,
@@ -678,13 +691,14 @@ class OpenAIHandler(AsyncStreamHandler):
678
  )
679
  self.connection = None
680
  self.output_queue = asyncio.Queue()
681
- self.web_search_enabled = web_search_enabled
682
- self.search_client = search_client
683
  self.function_call_in_progress = False
684
  self.current_function_args = ""
 
685
 
686
  def copy(self):
687
- return OpenAIHandler(self.web_search_enabled, self.search_client)
688
 
689
  async def search_web(self, query: str) -> str:
690
  """Perform web search and return formatted results"""
@@ -762,6 +776,7 @@ class OpenAIHandler(AsyncStreamHandler):
762
  elif event.type == "response.function_call_arguments.start":
763
  self.function_call_in_progress = True
764
  self.current_function_args = ""
 
765
 
766
  elif event.type == "response.function_call_arguments.delta":
767
  if self.function_call_in_progress:
@@ -784,11 +799,11 @@ class OpenAIHandler(AsyncStreamHandler):
784
  search_results = await self.search_web(query)
785
 
786
  # Send function result back to the model
787
- if self.connection:
788
  await self.connection.conversation.item.create(
789
  item={
790
  "type": "function_call_output",
791
- "call_id": event.call_id,
792
  "output": search_results
793
  }
794
  )
@@ -799,6 +814,7 @@ class OpenAIHandler(AsyncStreamHandler):
799
  finally:
800
  self.function_call_in_progress = False
801
  self.current_function_args = ""
 
802
 
803
  async def receive(self, frame: tuple[int, np.ndarray]) -> None:
804
  if not self.connection:
@@ -817,63 +833,80 @@ class OpenAIHandler(AsyncStreamHandler):
817
  self.connection = None
818
 
819
 
820
- # Store active handlers by webrtc_id
821
- active_handlers = {}
822
 
823
  # Initialize search client
824
  brave_api_key = os.getenv("BSEARCH_API")
825
  search_client = BraveSearchClient(brave_api_key) if brave_api_key else None
826
 
827
- app = FastAPI()
828
-
829
-
830
  def update_chatbot(chatbot: list[dict], response: ResponseAudioTranscriptDoneEvent):
831
  chatbot.append({"role": "assistant", "content": response.transcript})
832
  return chatbot
833
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
834
 
835
- @app.post("/webrtc/offer")
836
- async def webrtc_offer(request: dict):
837
- """Handle WebRTC offer with web search preference"""
838
- web_search_enabled = request.get("web_search_enabled", False)
 
 
839
  webrtc_id = request.get("webrtc_id")
 
840
 
841
- # Create handler with web search capability
842
- handler = OpenAIHandler(web_search_enabled=web_search_enabled, search_client=search_client)
843
- active_handlers[webrtc_id] = handler
844
-
845
- # Create stream for this connection
846
- stream = Stream(
847
- handler,
848
- mode="send-receive",
849
- modality="audio",
850
- additional_inputs=[[]], # Empty chatbot state
851
- additional_outputs=[[]],
852
- additional_outputs_handler=update_chatbot,
853
- rtc_configuration=get_twilio_turn_credentials() if get_space() else None,
854
- concurrency_limit=5 if get_space() else None,
855
- time_limit=90 if get_space() else None,
856
- )
857
 
858
- # Store stream reference
859
- handler.stream = stream
 
 
 
 
 
 
 
860
 
861
- # Mount and handle offer
862
- stream.mount(app)
 
863
 
864
- # Forward the WebRTC offer to the stream
865
- return await stream.offer(request)
866
 
 
 
867
 
868
  @app.get("/outputs")
869
  async def outputs(webrtc_id: str):
870
  """Stream outputs including search events"""
871
  async def output_stream():
872
- handler = active_handlers.get(webrtc_id)
873
- if not handler or not hasattr(handler, 'stream'):
874
- return
875
-
876
- async for output in handler.stream.output_stream(webrtc_id):
877
  if hasattr(output, 'args') and output.args:
878
  # Check if it's a search event
879
  if isinstance(output.args[0], dict) and output.args[0].get('type') == 'search':
@@ -896,4 +929,11 @@ async def index():
896
 
897
  if __name__ == "__main__":
898
  import uvicorn
899
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
 
 
 
 
 
 
 
19
  from openai.types.beta.realtime import ResponseAudioTranscriptDoneEvent
20
  import httpx
21
  from typing import Optional, List, Dict
22
+ import gradio as gr
23
 
24
  load_dotenv()
25
 
 
397
  searchToggle.addEventListener('click', () => {
398
  webSearchEnabled = !webSearchEnabled;
399
  searchToggle.classList.toggle('active', webSearchEnabled);
400
+
401
+ // Update server-side settings if connected
402
+ if (webrtc_id) {
403
+ fetch('/update-search-setting', {
404
+ method: 'POST',
405
+ headers: { 'Content-Type': 'application/json' },
406
+ body: JSON.stringify({
407
+ webrtc_id: webrtc_id,
408
+ web_search_enabled: webSearchEnabled
409
+ })
410
+ });
411
+ }
412
  });
413
 
414
  function updateStatus(state) {
 
682
 
683
 
684
  class OpenAIHandler(AsyncStreamHandler):
685
+ def __init__(self) -> None:
686
  super().__init__(
687
  expected_layout="mono",
688
  output_sample_rate=SAMPLE_RATE,
 
691
  )
692
  self.connection = None
693
  self.output_queue = asyncio.Queue()
694
+ self.web_search_enabled = False
695
+ self.search_client = None
696
  self.function_call_in_progress = False
697
  self.current_function_args = ""
698
+ self.call_id = None
699
 
700
  def copy(self):
701
+ return OpenAIHandler()
702
 
703
  async def search_web(self, query: str) -> str:
704
  """Perform web search and return formatted results"""
 
776
  elif event.type == "response.function_call_arguments.start":
777
  self.function_call_in_progress = True
778
  self.current_function_args = ""
779
+ self.call_id = event.call_id if hasattr(event, 'call_id') else None
780
 
781
  elif event.type == "response.function_call_arguments.delta":
782
  if self.function_call_in_progress:
 
799
  search_results = await self.search_web(query)
800
 
801
  # Send function result back to the model
802
+ if self.connection and self.call_id:
803
  await self.connection.conversation.item.create(
804
  item={
805
  "type": "function_call_output",
806
+ "call_id": self.call_id,
807
  "output": search_results
808
  }
809
  )
 
814
  finally:
815
  self.function_call_in_progress = False
816
  self.current_function_args = ""
817
+ self.call_id = None
818
 
819
  async def receive(self, frame: tuple[int, np.ndarray]) -> None:
820
  if not self.connection:
 
833
  self.connection = None
834
 
835
 
836
+ # Store connection settings
837
+ connection_settings = {}
838
 
839
  # Initialize search client
840
  brave_api_key = os.getenv("BSEARCH_API")
841
  search_client = BraveSearchClient(brave_api_key) if brave_api_key else None
842
 
 
 
 
843
  def update_chatbot(chatbot: list[dict], response: ResponseAudioTranscriptDoneEvent):
844
  chatbot.append({"role": "assistant", "content": response.transcript})
845
  return chatbot
846
 
847
+ # Create chatbot component
848
+ chatbot = gr.Chatbot(type="messages")
849
+
850
+ # Create handler
851
+ handler = OpenAIHandler()
852
+ handler.search_client = search_client
853
+
854
+ # Create stream
855
+ stream = Stream(
856
+ handler,
857
+ mode="send-receive",
858
+ modality="audio",
859
+ additional_inputs=[chatbot],
860
+ additional_outputs=[chatbot],
861
+ additional_outputs_handler=update_chatbot,
862
+ rtc_configuration=get_twilio_turn_credentials() if get_space() else None,
863
+ concurrency_limit=5 if get_space() else None,
864
+ time_limit=90 if get_space() else None,
865
+ )
866
 
867
+ app = FastAPI()
868
+ stream.mount(app)
869
+
870
+ @app.post("/update-search-setting")
871
+ async def update_search_setting(request: dict):
872
+ """Update web search setting for a connection"""
873
  webrtc_id = request.get("webrtc_id")
874
+ web_search_enabled = request.get("web_search_enabled", False)
875
 
876
+ if webrtc_id:
877
+ connection_settings[webrtc_id] = {"web_search_enabled": web_search_enabled}
878
+ # Update handler setting
879
+ handler.web_search_enabled = web_search_enabled
880
+
881
+ # Reconnect with updated settings if needed
882
+ if handler.connection:
883
+ await handler.shutdown()
884
+ await handler.start_up()
 
 
 
 
 
 
 
885
 
886
+ return {"status": "updated"}
887
+
888
+ # Override the stream's offer handler
889
+ original_offer = stream.offer
890
+
891
+ async def custom_offer(request: dict):
892
+ """Handle WebRTC offer with web search preference"""
893
+ webrtc_id = request.get("webrtc_id")
894
+ web_search_enabled = request.get("web_search_enabled", False)
895
 
896
+ if webrtc_id:
897
+ connection_settings[webrtc_id] = {"web_search_enabled": web_search_enabled}
898
+ handler.web_search_enabled = web_search_enabled
899
 
900
+ return await original_offer(request)
 
901
 
902
+ # Replace the offer method
903
+ stream.offer = custom_offer
904
 
905
  @app.get("/outputs")
906
  async def outputs(webrtc_id: str):
907
  """Stream outputs including search events"""
908
  async def output_stream():
909
+ async for output in stream.output_stream(webrtc_id):
 
 
 
 
910
  if hasattr(output, 'args') and output.args:
911
  # Check if it's a search event
912
  if isinstance(output.args[0], dict) and output.args[0].get('type') == 'search':
 
929
 
930
  if __name__ == "__main__":
931
  import uvicorn
932
+
933
+ mode = os.getenv("MODE")
934
+ if mode == "UI":
935
+ stream.ui.launch(server_port=7860)
936
+ elif mode == "PHONE":
937
+ stream.fastphone(host="0.0.0.0", port=7860)
938
+ else:
939
+ uvicorn.run(app, host="0.0.0.0", port=7860)