Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -19,6 +19,7 @@ from gradio.utils import get_space
|
|
19 |
from openai.types.beta.realtime import ResponseAudioTranscriptDoneEvent
|
20 |
import httpx
|
21 |
from typing import Optional, List, Dict
|
|
|
22 |
|
23 |
load_dotenv()
|
24 |
|
@@ -396,6 +397,18 @@ HTML_CONTENT = """<!DOCTYPE html>
|
|
396 |
searchToggle.addEventListener('click', () => {
|
397 |
webSearchEnabled = !webSearchEnabled;
|
398 |
searchToggle.classList.toggle('active', webSearchEnabled);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
399 |
});
|
400 |
|
401 |
function updateStatus(state) {
|
@@ -669,7 +682,7 @@ class BraveSearchClient:
|
|
669 |
|
670 |
|
671 |
class OpenAIHandler(AsyncStreamHandler):
|
672 |
-
def __init__(self
|
673 |
super().__init__(
|
674 |
expected_layout="mono",
|
675 |
output_sample_rate=SAMPLE_RATE,
|
@@ -678,13 +691,14 @@ class OpenAIHandler(AsyncStreamHandler):
|
|
678 |
)
|
679 |
self.connection = None
|
680 |
self.output_queue = asyncio.Queue()
|
681 |
-
self.web_search_enabled =
|
682 |
-
self.search_client =
|
683 |
self.function_call_in_progress = False
|
684 |
self.current_function_args = ""
|
|
|
685 |
|
686 |
def copy(self):
|
687 |
-
return OpenAIHandler(
|
688 |
|
689 |
async def search_web(self, query: str) -> str:
|
690 |
"""Perform web search and return formatted results"""
|
@@ -762,6 +776,7 @@ class OpenAIHandler(AsyncStreamHandler):
|
|
762 |
elif event.type == "response.function_call_arguments.start":
|
763 |
self.function_call_in_progress = True
|
764 |
self.current_function_args = ""
|
|
|
765 |
|
766 |
elif event.type == "response.function_call_arguments.delta":
|
767 |
if self.function_call_in_progress:
|
@@ -784,11 +799,11 @@ class OpenAIHandler(AsyncStreamHandler):
|
|
784 |
search_results = await self.search_web(query)
|
785 |
|
786 |
# Send function result back to the model
|
787 |
-
if self.connection:
|
788 |
await self.connection.conversation.item.create(
|
789 |
item={
|
790 |
"type": "function_call_output",
|
791 |
-
"call_id":
|
792 |
"output": search_results
|
793 |
}
|
794 |
)
|
@@ -799,6 +814,7 @@ class OpenAIHandler(AsyncStreamHandler):
|
|
799 |
finally:
|
800 |
self.function_call_in_progress = False
|
801 |
self.current_function_args = ""
|
|
|
802 |
|
803 |
async def receive(self, frame: tuple[int, np.ndarray]) -> None:
|
804 |
if not self.connection:
|
@@ -817,63 +833,80 @@ class OpenAIHandler(AsyncStreamHandler):
|
|
817 |
self.connection = None
|
818 |
|
819 |
|
820 |
-
# Store
|
821 |
-
|
822 |
|
823 |
# Initialize search client
|
824 |
brave_api_key = os.getenv("BSEARCH_API")
|
825 |
search_client = BraveSearchClient(brave_api_key) if brave_api_key else None
|
826 |
|
827 |
-
app = FastAPI()
|
828 |
-
|
829 |
-
|
830 |
def update_chatbot(chatbot: list[dict], response: ResponseAudioTranscriptDoneEvent):
|
831 |
chatbot.append({"role": "assistant", "content": response.transcript})
|
832 |
return chatbot
|
833 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
834 |
|
835 |
-
|
836 |
-
|
837 |
-
|
838 |
-
|
|
|
|
|
839 |
webrtc_id = request.get("webrtc_id")
|
|
|
840 |
|
841 |
-
|
842 |
-
|
843 |
-
|
844 |
-
|
845 |
-
|
846 |
-
|
847 |
-
handler
|
848 |
-
|
849 |
-
|
850 |
-
additional_inputs=[[]], # Empty chatbot state
|
851 |
-
additional_outputs=[[]],
|
852 |
-
additional_outputs_handler=update_chatbot,
|
853 |
-
rtc_configuration=get_twilio_turn_credentials() if get_space() else None,
|
854 |
-
concurrency_limit=5 if get_space() else None,
|
855 |
-
time_limit=90 if get_space() else None,
|
856 |
-
)
|
857 |
|
858 |
-
|
859 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
860 |
|
861 |
-
|
862 |
-
|
|
|
863 |
|
864 |
-
|
865 |
-
return await stream.offer(request)
|
866 |
|
|
|
|
|
867 |
|
868 |
@app.get("/outputs")
|
869 |
async def outputs(webrtc_id: str):
|
870 |
"""Stream outputs including search events"""
|
871 |
async def output_stream():
|
872 |
-
|
873 |
-
if not handler or not hasattr(handler, 'stream'):
|
874 |
-
return
|
875 |
-
|
876 |
-
async for output in handler.stream.output_stream(webrtc_id):
|
877 |
if hasattr(output, 'args') and output.args:
|
878 |
# Check if it's a search event
|
879 |
if isinstance(output.args[0], dict) and output.args[0].get('type') == 'search':
|
@@ -896,4 +929,11 @@ async def index():
|
|
896 |
|
897 |
if __name__ == "__main__":
|
898 |
import uvicorn
|
899 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
from openai.types.beta.realtime import ResponseAudioTranscriptDoneEvent
|
20 |
import httpx
|
21 |
from typing import Optional, List, Dict
|
22 |
+
import gradio as gr
|
23 |
|
24 |
load_dotenv()
|
25 |
|
|
|
397 |
searchToggle.addEventListener('click', () => {
|
398 |
webSearchEnabled = !webSearchEnabled;
|
399 |
searchToggle.classList.toggle('active', webSearchEnabled);
|
400 |
+
|
401 |
+
// Update server-side settings if connected
|
402 |
+
if (webrtc_id) {
|
403 |
+
fetch('/update-search-setting', {
|
404 |
+
method: 'POST',
|
405 |
+
headers: { 'Content-Type': 'application/json' },
|
406 |
+
body: JSON.stringify({
|
407 |
+
webrtc_id: webrtc_id,
|
408 |
+
web_search_enabled: webSearchEnabled
|
409 |
+
})
|
410 |
+
});
|
411 |
+
}
|
412 |
});
|
413 |
|
414 |
function updateStatus(state) {
|
|
|
682 |
|
683 |
|
684 |
class OpenAIHandler(AsyncStreamHandler):
|
685 |
+
def __init__(self) -> None:
|
686 |
super().__init__(
|
687 |
expected_layout="mono",
|
688 |
output_sample_rate=SAMPLE_RATE,
|
|
|
691 |
)
|
692 |
self.connection = None
|
693 |
self.output_queue = asyncio.Queue()
|
694 |
+
self.web_search_enabled = False
|
695 |
+
self.search_client = None
|
696 |
self.function_call_in_progress = False
|
697 |
self.current_function_args = ""
|
698 |
+
self.call_id = None
|
699 |
|
700 |
def copy(self):
|
701 |
+
return OpenAIHandler()
|
702 |
|
703 |
async def search_web(self, query: str) -> str:
|
704 |
"""Perform web search and return formatted results"""
|
|
|
776 |
elif event.type == "response.function_call_arguments.start":
|
777 |
self.function_call_in_progress = True
|
778 |
self.current_function_args = ""
|
779 |
+
self.call_id = event.call_id if hasattr(event, 'call_id') else None
|
780 |
|
781 |
elif event.type == "response.function_call_arguments.delta":
|
782 |
if self.function_call_in_progress:
|
|
|
799 |
search_results = await self.search_web(query)
|
800 |
|
801 |
# Send function result back to the model
|
802 |
+
if self.connection and self.call_id:
|
803 |
await self.connection.conversation.item.create(
|
804 |
item={
|
805 |
"type": "function_call_output",
|
806 |
+
"call_id": self.call_id,
|
807 |
"output": search_results
|
808 |
}
|
809 |
)
|
|
|
814 |
finally:
|
815 |
self.function_call_in_progress = False
|
816 |
self.current_function_args = ""
|
817 |
+
self.call_id = None
|
818 |
|
819 |
async def receive(self, frame: tuple[int, np.ndarray]) -> None:
|
820 |
if not self.connection:
|
|
|
833 |
self.connection = None
|
834 |
|
835 |
|
836 |
+
# Store connection settings
|
837 |
+
connection_settings = {}
|
838 |
|
839 |
# Initialize search client
|
840 |
brave_api_key = os.getenv("BSEARCH_API")
|
841 |
search_client = BraveSearchClient(brave_api_key) if brave_api_key else None
|
842 |
|
|
|
|
|
|
|
843 |
def update_chatbot(chatbot: list[dict], response: ResponseAudioTranscriptDoneEvent):
|
844 |
chatbot.append({"role": "assistant", "content": response.transcript})
|
845 |
return chatbot
|
846 |
|
847 |
+
# Create chatbot component
|
848 |
+
chatbot = gr.Chatbot(type="messages")
|
849 |
+
|
850 |
+
# Create handler
|
851 |
+
handler = OpenAIHandler()
|
852 |
+
handler.search_client = search_client
|
853 |
+
|
854 |
+
# Create stream
|
855 |
+
stream = Stream(
|
856 |
+
handler,
|
857 |
+
mode="send-receive",
|
858 |
+
modality="audio",
|
859 |
+
additional_inputs=[chatbot],
|
860 |
+
additional_outputs=[chatbot],
|
861 |
+
additional_outputs_handler=update_chatbot,
|
862 |
+
rtc_configuration=get_twilio_turn_credentials() if get_space() else None,
|
863 |
+
concurrency_limit=5 if get_space() else None,
|
864 |
+
time_limit=90 if get_space() else None,
|
865 |
+
)
|
866 |
|
867 |
+
app = FastAPI()
|
868 |
+
stream.mount(app)
|
869 |
+
|
870 |
+
@app.post("/update-search-setting")
|
871 |
+
async def update_search_setting(request: dict):
|
872 |
+
"""Update web search setting for a connection"""
|
873 |
webrtc_id = request.get("webrtc_id")
|
874 |
+
web_search_enabled = request.get("web_search_enabled", False)
|
875 |
|
876 |
+
if webrtc_id:
|
877 |
+
connection_settings[webrtc_id] = {"web_search_enabled": web_search_enabled}
|
878 |
+
# Update handler setting
|
879 |
+
handler.web_search_enabled = web_search_enabled
|
880 |
+
|
881 |
+
# Reconnect with updated settings if needed
|
882 |
+
if handler.connection:
|
883 |
+
await handler.shutdown()
|
884 |
+
await handler.start_up()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
885 |
|
886 |
+
return {"status": "updated"}
|
887 |
+
|
888 |
+
# Override the stream's offer handler
|
889 |
+
original_offer = stream.offer
|
890 |
+
|
891 |
+
async def custom_offer(request: dict):
|
892 |
+
"""Handle WebRTC offer with web search preference"""
|
893 |
+
webrtc_id = request.get("webrtc_id")
|
894 |
+
web_search_enabled = request.get("web_search_enabled", False)
|
895 |
|
896 |
+
if webrtc_id:
|
897 |
+
connection_settings[webrtc_id] = {"web_search_enabled": web_search_enabled}
|
898 |
+
handler.web_search_enabled = web_search_enabled
|
899 |
|
900 |
+
return await original_offer(request)
|
|
|
901 |
|
902 |
+
# Replace the offer method
|
903 |
+
stream.offer = custom_offer
|
904 |
|
905 |
@app.get("/outputs")
|
906 |
async def outputs(webrtc_id: str):
|
907 |
"""Stream outputs including search events"""
|
908 |
async def output_stream():
|
909 |
+
async for output in stream.output_stream(webrtc_id):
|
|
|
|
|
|
|
|
|
910 |
if hasattr(output, 'args') and output.args:
|
911 |
# Check if it's a search event
|
912 |
if isinstance(output.args[0], dict) and output.args[0].get('type') == 'search':
|
|
|
929 |
|
930 |
if __name__ == "__main__":
|
931 |
import uvicorn
|
932 |
+
|
933 |
+
mode = os.getenv("MODE")
|
934 |
+
if mode == "UI":
|
935 |
+
stream.ui.launch(server_port=7860)
|
936 |
+
elif mode == "PHONE":
|
937 |
+
stream.fastphone(host="0.0.0.0", port=7860)
|
938 |
+
else:
|
939 |
+
uvicorn.run(app, host="0.0.0.0", port=7860)
|