Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -670,8 +670,12 @@ class BraveSearchClient:
|
|
670 |
return []
|
671 |
|
672 |
|
|
|
|
|
|
|
|
|
673 |
class OpenAIHandler(AsyncStreamHandler):
|
674 |
-
def __init__(self
|
675 |
super().__init__(
|
676 |
expected_layout="mono",
|
677 |
output_sample_rate=SAMPLE_RATE,
|
@@ -680,16 +684,15 @@ class OpenAIHandler(AsyncStreamHandler):
|
|
680 |
)
|
681 |
self.connection = None
|
682 |
self.output_queue = asyncio.Queue()
|
683 |
-
self.
|
684 |
-
self.search_client = search_client
|
685 |
self.function_call_in_progress = False
|
686 |
self.current_function_args = ""
|
687 |
self.current_call_id = None
|
688 |
-
self.webrtc_id =
|
|
|
689 |
|
690 |
def copy(self):
|
691 |
-
|
692 |
-
return OpenAIHandler(self.web_search_enabled, self.search_client, self.webrtc_id)
|
693 |
|
694 |
async def search_web(self, query: str) -> str:
|
695 |
"""Perform web search and return formatted results"""
|
@@ -714,6 +717,10 @@ class OpenAIHandler(AsyncStreamHandler):
|
|
714 |
|
715 |
async def start_up(self):
|
716 |
"""Connect to realtime API with function calling enabled"""
|
|
|
|
|
|
|
|
|
717 |
print(f"Starting up handler with web_search_enabled={self.web_search_enabled}")
|
718 |
self.client = openai.AsyncOpenAI()
|
719 |
|
@@ -724,7 +731,7 @@ class OpenAIHandler(AsyncStreamHandler):
|
|
724 |
"type": "function",
|
725 |
"function": {
|
726 |
"name": "web_search",
|
727 |
-
"description": "Search the web for information",
|
728 |
"parameters": {
|
729 |
"type": "object",
|
730 |
"properties": {
|
@@ -749,6 +756,16 @@ class OpenAIHandler(AsyncStreamHandler):
|
|
749 |
"tool_choice": "auto" if tools else "none"
|
750 |
}
|
751 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
752 |
await conn.session.update(session=session_update)
|
753 |
self.connection = conn
|
754 |
print(f"Connected with tools: {len(tools)} functions")
|
@@ -769,7 +786,7 @@ class OpenAIHandler(AsyncStreamHandler):
|
|
769 |
|
770 |
# Handle function calls
|
771 |
elif event.type == "response.function_call_arguments.start":
|
772 |
-
print(f"Function call started
|
773 |
self.function_call_in_progress = True
|
774 |
self.current_function_args = ""
|
775 |
self.current_call_id = getattr(event, 'call_id', None)
|
@@ -830,9 +847,6 @@ class OpenAIHandler(AsyncStreamHandler):
|
|
830 |
self.connection = None
|
831 |
|
832 |
|
833 |
-
# Store handlers by webrtc_id
|
834 |
-
handlers = {}
|
835 |
-
|
836 |
# Initialize search client
|
837 |
brave_api_key = os.getenv("BSEARCH_API")
|
838 |
search_client = BraveSearchClient(brave_api_key) if brave_api_key else None
|
@@ -842,65 +856,73 @@ def update_chatbot(chatbot: list[dict], response: ResponseAudioTranscriptDoneEve
|
|
842 |
chatbot.append({"role": "assistant", "content": response.transcript})
|
843 |
return chatbot
|
844 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
845 |
|
846 |
-
# Create base stream without handler for mounting
|
847 |
app = FastAPI()
|
|
|
848 |
|
849 |
-
|
850 |
-
|
851 |
-
|
852 |
-
|
853 |
-
|
854 |
-
|
855 |
-
|
856 |
-
|
857 |
-
|
858 |
-
|
859 |
-
|
860 |
-
|
861 |
-
webrtc_id=webrtc_id
|
862 |
-
)
|
863 |
-
|
864 |
-
# Store the handler
|
865 |
-
handlers[webrtc_id] = handler
|
866 |
-
|
867 |
-
# Create chatbot instance for this connection
|
868 |
-
chatbot = gr.Chatbot(type="messages")
|
869 |
-
|
870 |
-
# Create stream for this connection
|
871 |
-
stream = Stream(
|
872 |
-
handler,
|
873 |
-
mode="send-receive",
|
874 |
-
modality="audio",
|
875 |
-
additional_inputs=[chatbot],
|
876 |
-
additional_outputs=[chatbot],
|
877 |
-
additional_outputs_handler=update_chatbot,
|
878 |
-
rtc_configuration=get_twilio_turn_credentials() if get_space() else None,
|
879 |
-
concurrency_limit=5 if get_space() else None,
|
880 |
-
time_limit=90 if get_space() else None,
|
881 |
-
)
|
882 |
|
883 |
-
#
|
884 |
-
|
|
|
|
|
|
|
|
|
|
|
885 |
|
886 |
-
|
887 |
-
handler.stream = stream
|
888 |
|
889 |
-
#
|
890 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
891 |
|
892 |
-
|
|
|
893 |
|
|
|
|
|
894 |
|
895 |
@app.get("/outputs")
|
896 |
async def outputs(webrtc_id: str):
|
897 |
"""Stream outputs including search events"""
|
898 |
async def output_stream():
|
899 |
-
|
900 |
-
if not handler or not hasattr(handler, 'stream'):
|
901 |
-
return
|
902 |
-
|
903 |
-
async for output in handler.stream.output_stream(webrtc_id):
|
904 |
if hasattr(output, 'args') and output.args:
|
905 |
# Check if it's a search event
|
906 |
if isinstance(output.args[0], dict) and output.args[0].get('type') == 'search':
|
@@ -926,38 +948,8 @@ if __name__ == "__main__":
|
|
926 |
|
927 |
mode = os.getenv("MODE")
|
928 |
if mode == "UI":
|
929 |
-
# For UI mode, we need a base handler and stream
|
930 |
-
base_handler = OpenAIHandler()
|
931 |
-
chatbot = gr.Chatbot(type="messages")
|
932 |
-
stream = Stream(
|
933 |
-
base_handler,
|
934 |
-
mode="send-receive",
|
935 |
-
modality="audio",
|
936 |
-
additional_inputs=[chatbot],
|
937 |
-
additional_outputs=[chatbot],
|
938 |
-
additional_outputs_handler=update_chatbot,
|
939 |
-
rtc_configuration=get_twilio_turn_credentials() if get_space() else None,
|
940 |
-
concurrency_limit=5 if get_space() else None,
|
941 |
-
time_limit=90 if get_space() else None,
|
942 |
-
)
|
943 |
-
stream.mount(app)
|
944 |
stream.ui.launch(server_port=7860)
|
945 |
elif mode == "PHONE":
|
946 |
-
# Similar for phone mode
|
947 |
-
base_handler = OpenAIHandler()
|
948 |
-
chatbot = gr.Chatbot(type="messages")
|
949 |
-
stream = Stream(
|
950 |
-
base_handler,
|
951 |
-
mode="send-receive",
|
952 |
-
modality="audio",
|
953 |
-
additional_inputs=[chatbot],
|
954 |
-
additional_outputs=[chatbot],
|
955 |
-
additional_outputs_handler=update_chatbot,
|
956 |
-
rtc_configuration=get_twilio_turn_credentials() if get_space() else None,
|
957 |
-
concurrency_limit=5 if get_space() else None,
|
958 |
-
time_limit=90 if get_space() else None,
|
959 |
-
)
|
960 |
-
stream.mount(app)
|
961 |
stream.fastphone(host="0.0.0.0", port=7860)
|
962 |
else:
|
963 |
uvicorn.run(app, host="0.0.0.0", port=7860)
|
|
|
670 |
return []
|
671 |
|
672 |
|
673 |
+
# Global state for web search settings
|
674 |
+
web_search_settings = {}
|
675 |
+
|
676 |
+
|
677 |
class OpenAIHandler(AsyncStreamHandler):
|
678 |
+
def __init__(self) -> None:
|
679 |
super().__init__(
|
680 |
expected_layout="mono",
|
681 |
output_sample_rate=SAMPLE_RATE,
|
|
|
684 |
)
|
685 |
self.connection = None
|
686 |
self.output_queue = asyncio.Queue()
|
687 |
+
self.search_client = None
|
|
|
688 |
self.function_call_in_progress = False
|
689 |
self.current_function_args = ""
|
690 |
self.current_call_id = None
|
691 |
+
self.webrtc_id = None
|
692 |
+
self.web_search_enabled = False
|
693 |
|
694 |
def copy(self):
|
695 |
+
return OpenAIHandler()
|
|
|
696 |
|
697 |
async def search_web(self, query: str) -> str:
|
698 |
"""Perform web search and return formatted results"""
|
|
|
717 |
|
718 |
async def start_up(self):
|
719 |
"""Connect to realtime API with function calling enabled"""
|
720 |
+
# Get web search setting from global state
|
721 |
+
if self.webrtc_id and self.webrtc_id in web_search_settings:
|
722 |
+
self.web_search_enabled = web_search_settings[self.webrtc_id]
|
723 |
+
|
724 |
print(f"Starting up handler with web_search_enabled={self.web_search_enabled}")
|
725 |
self.client = openai.AsyncOpenAI()
|
726 |
|
|
|
731 |
"type": "function",
|
732 |
"function": {
|
733 |
"name": "web_search",
|
734 |
+
"description": "Search the web for information when the user asks questions about current events, news, or topics that require up-to-date information",
|
735 |
"parameters": {
|
736 |
"type": "object",
|
737 |
"properties": {
|
|
|
756 |
"tool_choice": "auto" if tools else "none"
|
757 |
}
|
758 |
|
759 |
+
# Add instructions to use web search
|
760 |
+
if self.web_search_enabled:
|
761 |
+
session_update["instructions"] = (
|
762 |
+
"You are a helpful assistant. When users ask about current events, "
|
763 |
+
"news, recent information, or topics that might have changed recently, "
|
764 |
+
"use the web_search function to find the most up-to-date information. "
|
765 |
+
"Always search when the user asks about: weather, news, current events, "
|
766 |
+
"recent updates, prices, or any time-sensitive information."
|
767 |
+
)
|
768 |
+
|
769 |
await conn.session.update(session=session_update)
|
770 |
self.connection = conn
|
771 |
print(f"Connected with tools: {len(tools)} functions")
|
|
|
786 |
|
787 |
# Handle function calls
|
788 |
elif event.type == "response.function_call_arguments.start":
|
789 |
+
print(f"Function call started")
|
790 |
self.function_call_in_progress = True
|
791 |
self.current_function_args = ""
|
792 |
self.current_call_id = getattr(event, 'call_id', None)
|
|
|
847 |
self.connection = None
|
848 |
|
849 |
|
|
|
|
|
|
|
850 |
# Initialize search client
|
851 |
brave_api_key = os.getenv("BSEARCH_API")
|
852 |
search_client = BraveSearchClient(brave_api_key) if brave_api_key else None
|
|
|
856 |
chatbot.append({"role": "assistant", "content": response.transcript})
|
857 |
return chatbot
|
858 |
|
859 |
+
# Create chatbot component
|
860 |
+
chatbot = gr.Chatbot(type="messages")
|
861 |
+
|
862 |
+
# Create base handler and stream
|
863 |
+
handler = OpenAIHandler()
|
864 |
+
handler.search_client = search_client
|
865 |
+
|
866 |
+
stream = Stream(
|
867 |
+
handler,
|
868 |
+
mode="send-receive",
|
869 |
+
modality="audio",
|
870 |
+
additional_inputs=[chatbot],
|
871 |
+
additional_outputs=[chatbot],
|
872 |
+
additional_outputs_handler=update_chatbot,
|
873 |
+
rtc_configuration=get_twilio_turn_credentials() if get_space() else None,
|
874 |
+
concurrency_limit=5 if get_space() else None,
|
875 |
+
time_limit=90 if get_space() else None,
|
876 |
+
)
|
877 |
|
|
|
878 |
app = FastAPI()
|
879 |
+
stream.mount(app)
|
880 |
|
881 |
+
# Override the offer handler to capture web search settings
|
882 |
+
original_offer_handler = stream.offer
|
883 |
+
|
884 |
+
async def custom_offer_handler(request):
|
885 |
+
"""Custom offer handler to process web search settings"""
|
886 |
+
# Extract web search setting from request
|
887 |
+
if hasattr(request, 'web_search_enabled'):
|
888 |
+
web_search_enabled = request.web_search_enabled
|
889 |
+
elif isinstance(request, dict) and 'web_search_enabled' in request:
|
890 |
+
web_search_enabled = request['web_search_enabled']
|
891 |
+
else:
|
892 |
+
web_search_enabled = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
893 |
|
894 |
+
# Extract webrtc_id
|
895 |
+
if hasattr(request, 'webrtc_id'):
|
896 |
+
webrtc_id = request.webrtc_id
|
897 |
+
elif isinstance(request, dict) and 'webrtc_id' in request:
|
898 |
+
webrtc_id = request['webrtc_id']
|
899 |
+
else:
|
900 |
+
webrtc_id = None
|
901 |
|
902 |
+
print(f"Received offer - webrtc_id: {webrtc_id}, web_search_enabled: {web_search_enabled}")
|
|
|
903 |
|
904 |
+
# Store web search setting
|
905 |
+
if webrtc_id:
|
906 |
+
web_search_settings[webrtc_id] = web_search_enabled
|
907 |
+
handler.webrtc_id = webrtc_id
|
908 |
+
handler.web_search_enabled = web_search_enabled
|
909 |
+
|
910 |
+
# Restart handler if already connected to update settings
|
911 |
+
if handler.connection:
|
912 |
+
await handler.shutdown()
|
913 |
+
await handler.start_up()
|
914 |
|
915 |
+
# Call original offer handler
|
916 |
+
return await original_offer_handler(request)
|
917 |
|
918 |
+
# Replace the offer method
|
919 |
+
stream.offer = custom_offer_handler
|
920 |
|
921 |
@app.get("/outputs")
|
922 |
async def outputs(webrtc_id: str):
|
923 |
"""Stream outputs including search events"""
|
924 |
async def output_stream():
|
925 |
+
async for output in stream.output_stream(webrtc_id):
|
|
|
|
|
|
|
|
|
926 |
if hasattr(output, 'args') and output.args:
|
927 |
# Check if it's a search event
|
928 |
if isinstance(output.args[0], dict) and output.args[0].get('type') == 'search':
|
|
|
948 |
|
949 |
mode = os.getenv("MODE")
|
950 |
if mode == "UI":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
951 |
stream.ui.launch(server_port=7860)
|
952 |
elif mode == "PHONE":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
953 |
stream.fastphone(host="0.0.0.0", port=7860)
|
954 |
else:
|
955 |
uvicorn.run(app, host="0.0.0.0", port=7860)
|