Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -397,18 +397,7 @@ HTML_CONTENT = """<!DOCTYPE html>
|
|
397 |
searchToggle.addEventListener('click', () => {
|
398 |
webSearchEnabled = !webSearchEnabled;
|
399 |
searchToggle.classList.toggle('active', webSearchEnabled);
|
400 |
-
|
401 |
-
// Update server-side settings if connected
|
402 |
-
if (webrtc_id) {
|
403 |
-
fetch('/update-search-setting', {
|
404 |
-
method: 'POST',
|
405 |
-
headers: { 'Content-Type': 'application/json' },
|
406 |
-
body: JSON.stringify({
|
407 |
-
webrtc_id: webrtc_id,
|
408 |
-
web_search_enabled: webSearchEnabled
|
409 |
-
})
|
410 |
-
});
|
411 |
-
}
|
412 |
});
|
413 |
|
414 |
function updateStatus(state) {
|
@@ -577,7 +566,7 @@ HTML_CONTENT = """<!DOCTYPE html>
|
|
577 |
});
|
578 |
eventSource.addEventListener("search", (event) => {
|
579 |
const eventJson = JSON.parse(event.data);
|
580 |
-
if (eventJson.
|
581 |
addMessage("search-result", `์น ๊ฒ์ ์ค: "${eventJson.query}"`);
|
582 |
}
|
583 |
});
|
@@ -682,7 +671,7 @@ class BraveSearchClient:
|
|
682 |
|
683 |
|
684 |
class OpenAIHandler(AsyncStreamHandler):
|
685 |
-
def __init__(self) -> None:
|
686 |
super().__init__(
|
687 |
expected_layout="mono",
|
688 |
output_sample_rate=SAMPLE_RATE,
|
@@ -691,20 +680,23 @@ class OpenAIHandler(AsyncStreamHandler):
|
|
691 |
)
|
692 |
self.connection = None
|
693 |
self.output_queue = asyncio.Queue()
|
694 |
-
self.web_search_enabled =
|
695 |
-
self.search_client =
|
696 |
self.function_call_in_progress = False
|
697 |
self.current_function_args = ""
|
698 |
-
self.
|
|
|
699 |
|
700 |
def copy(self):
|
701 |
-
|
|
|
702 |
|
703 |
async def search_web(self, query: str) -> str:
|
704 |
"""Perform web search and return formatted results"""
|
705 |
if not self.search_client or not self.web_search_enabled:
|
706 |
return "์น ๊ฒ์์ด ๋นํ์ฑํ๋์ด ์์ต๋๋ค."
|
707 |
|
|
|
708 |
results = await self.search_client.search(query)
|
709 |
if not results:
|
710 |
return f"'{query}'์ ๋ํ ๊ฒ์ ๊ฒฐ๊ณผ๋ฅผ ์ฐพ์ ์ ์์ต๋๋ค."
|
@@ -722,6 +714,7 @@ class OpenAIHandler(AsyncStreamHandler):
|
|
722 |
|
723 |
async def start_up(self):
|
724 |
"""Connect to realtime API with function calling enabled"""
|
|
|
725 |
self.client = openai.AsyncOpenAI()
|
726 |
|
727 |
# Define the web search function
|
@@ -744,6 +737,7 @@ class OpenAIHandler(AsyncStreamHandler):
|
|
744 |
}
|
745 |
}
|
746 |
}]
|
|
|
747 |
|
748 |
async with self.client.beta.realtime.connect(
|
749 |
model="gpt-4o-mini-realtime-preview-2024-12-17"
|
@@ -757,6 +751,7 @@ class OpenAIHandler(AsyncStreamHandler):
|
|
757 |
|
758 |
await conn.session.update(session=session_update)
|
759 |
self.connection = conn
|
|
|
760 |
|
761 |
async for event in self.connection:
|
762 |
if event.type == "response.audio_transcript.done":
|
@@ -774,9 +769,10 @@ class OpenAIHandler(AsyncStreamHandler):
|
|
774 |
|
775 |
# Handle function calls
|
776 |
elif event.type == "response.function_call_arguments.start":
|
|
|
777 |
self.function_call_in_progress = True
|
778 |
self.current_function_args = ""
|
779 |
-
self.
|
780 |
|
781 |
elif event.type == "response.function_call_arguments.delta":
|
782 |
if self.function_call_in_progress:
|
@@ -784,6 +780,7 @@ class OpenAIHandler(AsyncStreamHandler):
|
|
784 |
|
785 |
elif event.type == "response.function_call_arguments.done":
|
786 |
if self.function_call_in_progress:
|
|
|
787 |
try:
|
788 |
args = json.loads(self.current_function_args)
|
789 |
query = args.get("query", "")
|
@@ -791,19 +788,19 @@ class OpenAIHandler(AsyncStreamHandler):
|
|
791 |
# Emit search event to client
|
792 |
await self.output_queue.put(AdditionalOutputs({
|
793 |
"type": "search",
|
794 |
-
"query": query
|
795 |
-
"results": True
|
796 |
}))
|
797 |
|
798 |
# Perform the search
|
799 |
search_results = await self.search_web(query)
|
|
|
800 |
|
801 |
# Send function result back to the model
|
802 |
-
if self.connection and self.
|
803 |
await self.connection.conversation.item.create(
|
804 |
item={
|
805 |
"type": "function_call_output",
|
806 |
-
"call_id": self.
|
807 |
"output": search_results
|
808 |
}
|
809 |
)
|
@@ -814,7 +811,7 @@ class OpenAIHandler(AsyncStreamHandler):
|
|
814 |
finally:
|
815 |
self.function_call_in_progress = False
|
816 |
self.current_function_args = ""
|
817 |
-
self.
|
818 |
|
819 |
async def receive(self, frame: tuple[int, np.ndarray]) -> None:
|
820 |
if not self.connection:
|
@@ -833,80 +830,77 @@ class OpenAIHandler(AsyncStreamHandler):
|
|
833 |
self.connection = None
|
834 |
|
835 |
|
836 |
-
# Store
|
837 |
-
|
838 |
|
839 |
# Initialize search client
|
840 |
brave_api_key = os.getenv("BSEARCH_API")
|
841 |
search_client = BraveSearchClient(brave_api_key) if brave_api_key else None
|
|
|
842 |
|
843 |
def update_chatbot(chatbot: list[dict], response: ResponseAudioTranscriptDoneEvent):
|
844 |
chatbot.append({"role": "assistant", "content": response.transcript})
|
845 |
return chatbot
|
846 |
|
847 |
-
# Create chatbot component
|
848 |
-
chatbot = gr.Chatbot(type="messages")
|
849 |
-
|
850 |
-
# Create handler
|
851 |
-
handler = OpenAIHandler()
|
852 |
-
handler.search_client = search_client
|
853 |
-
|
854 |
-
# Create stream
|
855 |
-
stream = Stream(
|
856 |
-
handler,
|
857 |
-
mode="send-receive",
|
858 |
-
modality="audio",
|
859 |
-
additional_inputs=[chatbot],
|
860 |
-
additional_outputs=[chatbot],
|
861 |
-
additional_outputs_handler=update_chatbot,
|
862 |
-
rtc_configuration=get_twilio_turn_credentials() if get_space() else None,
|
863 |
-
concurrency_limit=5 if get_space() else None,
|
864 |
-
time_limit=90 if get_space() else None,
|
865 |
-
)
|
866 |
|
|
|
867 |
app = FastAPI()
|
868 |
-
stream.mount(app)
|
869 |
|
870 |
-
@app.post("/
|
871 |
-
async def
|
872 |
-
"""
|
873 |
webrtc_id = request.get("webrtc_id")
|
874 |
web_search_enabled = request.get("web_search_enabled", False)
|
875 |
|
876 |
-
|
877 |
-
connection_settings[webrtc_id] = {"web_search_enabled": web_search_enabled}
|
878 |
-
# Update handler setting
|
879 |
-
handler.web_search_enabled = web_search_enabled
|
880 |
-
|
881 |
-
# Reconnect with updated settings if needed
|
882 |
-
if handler.connection:
|
883 |
-
await handler.shutdown()
|
884 |
-
await handler.start_up()
|
885 |
|
886 |
-
|
887 |
-
|
888 |
-
|
889 |
-
|
890 |
-
|
891 |
-
|
892 |
-
|
893 |
-
|
894 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
895 |
|
896 |
-
|
897 |
-
|
898 |
-
handler.web_search_enabled = web_search_enabled
|
899 |
|
900 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
901 |
|
902 |
-
# Replace the offer method
|
903 |
-
stream.offer = custom_offer
|
904 |
|
905 |
@app.get("/outputs")
|
906 |
async def outputs(webrtc_id: str):
|
907 |
"""Stream outputs including search events"""
|
908 |
async def output_stream():
|
909 |
-
|
|
|
|
|
|
|
|
|
910 |
if hasattr(output, 'args') and output.args:
|
911 |
# Check if it's a search event
|
912 |
if isinstance(output.args[0], dict) and output.args[0].get('type') == 'search':
|
@@ -932,8 +926,38 @@ if __name__ == "__main__":
|
|
932 |
|
933 |
mode = os.getenv("MODE")
|
934 |
if mode == "UI":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
935 |
stream.ui.launch(server_port=7860)
|
936 |
elif mode == "PHONE":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
937 |
stream.fastphone(host="0.0.0.0", port=7860)
|
938 |
else:
|
939 |
uvicorn.run(app, host="0.0.0.0", port=7860)
|
|
|
397 |
searchToggle.addEventListener('click', () => {
|
398 |
webSearchEnabled = !webSearchEnabled;
|
399 |
searchToggle.classList.toggle('active', webSearchEnabled);
|
400 |
+
console.log('Web search enabled:', webSearchEnabled);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
401 |
});
|
402 |
|
403 |
function updateStatus(state) {
|
|
|
566 |
});
|
567 |
eventSource.addEventListener("search", (event) => {
|
568 |
const eventJson = JSON.parse(event.data);
|
569 |
+
if (eventJson.query) {
|
570 |
addMessage("search-result", `์น ๊ฒ์ ์ค: "${eventJson.query}"`);
|
571 |
}
|
572 |
});
|
|
|
671 |
|
672 |
|
673 |
class OpenAIHandler(AsyncStreamHandler):
|
674 |
+
def __init__(self, web_search_enabled: bool = False, search_client: Optional[BraveSearchClient] = None, webrtc_id: str = None) -> None:
|
675 |
super().__init__(
|
676 |
expected_layout="mono",
|
677 |
output_sample_rate=SAMPLE_RATE,
|
|
|
680 |
)
|
681 |
self.connection = None
|
682 |
self.output_queue = asyncio.Queue()
|
683 |
+
self.web_search_enabled = web_search_enabled
|
684 |
+
self.search_client = search_client
|
685 |
self.function_call_in_progress = False
|
686 |
self.current_function_args = ""
|
687 |
+
self.current_call_id = None
|
688 |
+
self.webrtc_id = webrtc_id
|
689 |
|
690 |
def copy(self):
|
691 |
+
# Return a new instance with the same settings
|
692 |
+
return OpenAIHandler(self.web_search_enabled, self.search_client, self.webrtc_id)
|
693 |
|
694 |
async def search_web(self, query: str) -> str:
|
695 |
"""Perform web search and return formatted results"""
|
696 |
if not self.search_client or not self.web_search_enabled:
|
697 |
return "์น ๊ฒ์์ด ๋นํ์ฑํ๋์ด ์์ต๋๋ค."
|
698 |
|
699 |
+
print(f"Searching web for: {query}")
|
700 |
results = await self.search_client.search(query)
|
701 |
if not results:
|
702 |
return f"'{query}'์ ๋ํ ๊ฒ์ ๊ฒฐ๊ณผ๋ฅผ ์ฐพ์ ์ ์์ต๋๋ค."
|
|
|
714 |
|
715 |
async def start_up(self):
|
716 |
"""Connect to realtime API with function calling enabled"""
|
717 |
+
print(f"Starting up handler with web_search_enabled={self.web_search_enabled}")
|
718 |
self.client = openai.AsyncOpenAI()
|
719 |
|
720 |
# Define the web search function
|
|
|
737 |
}
|
738 |
}
|
739 |
}]
|
740 |
+
print("Web search function added to tools")
|
741 |
|
742 |
async with self.client.beta.realtime.connect(
|
743 |
model="gpt-4o-mini-realtime-preview-2024-12-17"
|
|
|
751 |
|
752 |
await conn.session.update(session=session_update)
|
753 |
self.connection = conn
|
754 |
+
print(f"Connected with tools: {len(tools)} functions")
|
755 |
|
756 |
async for event in self.connection:
|
757 |
if event.type == "response.audio_transcript.done":
|
|
|
769 |
|
770 |
# Handle function calls
|
771 |
elif event.type == "response.function_call_arguments.start":
|
772 |
+
print(f"Function call started, call_id: {getattr(event, 'call_id', 'unknown')}")
|
773 |
self.function_call_in_progress = True
|
774 |
self.current_function_args = ""
|
775 |
+
self.current_call_id = getattr(event, 'call_id', None)
|
776 |
|
777 |
elif event.type == "response.function_call_arguments.delta":
|
778 |
if self.function_call_in_progress:
|
|
|
780 |
|
781 |
elif event.type == "response.function_call_arguments.done":
|
782 |
if self.function_call_in_progress:
|
783 |
+
print(f"Function call done, args: {self.current_function_args}")
|
784 |
try:
|
785 |
args = json.loads(self.current_function_args)
|
786 |
query = args.get("query", "")
|
|
|
788 |
# Emit search event to client
|
789 |
await self.output_queue.put(AdditionalOutputs({
|
790 |
"type": "search",
|
791 |
+
"query": query
|
|
|
792 |
}))
|
793 |
|
794 |
# Perform the search
|
795 |
search_results = await self.search_web(query)
|
796 |
+
print(f"Search results length: {len(search_results)}")
|
797 |
|
798 |
# Send function result back to the model
|
799 |
+
if self.connection and self.current_call_id:
|
800 |
await self.connection.conversation.item.create(
|
801 |
item={
|
802 |
"type": "function_call_output",
|
803 |
+
"call_id": self.current_call_id,
|
804 |
"output": search_results
|
805 |
}
|
806 |
)
|
|
|
811 |
finally:
|
812 |
self.function_call_in_progress = False
|
813 |
self.current_function_args = ""
|
814 |
+
self.current_call_id = None
|
815 |
|
816 |
async def receive(self, frame: tuple[int, np.ndarray]) -> None:
|
817 |
if not self.connection:
|
|
|
830 |
self.connection = None
|
831 |
|
832 |
|
833 |
+
# Store handlers by webrtc_id
|
834 |
+
handlers = {}
|
835 |
|
836 |
# Initialize search client
|
837 |
brave_api_key = os.getenv("BSEARCH_API")
|
838 |
search_client = BraveSearchClient(brave_api_key) if brave_api_key else None
|
839 |
+
print(f"Search client initialized: {search_client is not None}, API key present: {bool(brave_api_key)}")
|
840 |
|
841 |
def update_chatbot(chatbot: list[dict], response: ResponseAudioTranscriptDoneEvent):
|
842 |
chatbot.append({"role": "assistant", "content": response.transcript})
|
843 |
return chatbot
|
844 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
845 |
|
846 |
+
# Create base stream without handler for mounting
|
847 |
app = FastAPI()
|
|
|
848 |
|
849 |
+
@app.post("/webrtc/offer")
|
850 |
+
async def handle_offer(request: dict):
|
851 |
+
"""Handle WebRTC offer and create a new handler for this connection"""
|
852 |
webrtc_id = request.get("webrtc_id")
|
853 |
web_search_enabled = request.get("web_search_enabled", False)
|
854 |
|
855 |
+
print(f"Received offer - webrtc_id: {webrtc_id}, web_search_enabled: {web_search_enabled}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
856 |
|
857 |
+
# Create a new handler for this connection
|
858 |
+
handler = OpenAIHandler(
|
859 |
+
web_search_enabled=web_search_enabled,
|
860 |
+
search_client=search_client,
|
861 |
+
webrtc_id=webrtc_id
|
862 |
+
)
|
863 |
+
|
864 |
+
# Store the handler
|
865 |
+
handlers[webrtc_id] = handler
|
866 |
+
|
867 |
+
# Create chatbot instance for this connection
|
868 |
+
chatbot = gr.Chatbot(type="messages")
|
869 |
+
|
870 |
+
# Create stream for this connection
|
871 |
+
stream = Stream(
|
872 |
+
handler,
|
873 |
+
mode="send-receive",
|
874 |
+
modality="audio",
|
875 |
+
additional_inputs=[chatbot],
|
876 |
+
additional_outputs=[chatbot],
|
877 |
+
additional_outputs_handler=update_chatbot,
|
878 |
+
rtc_configuration=get_twilio_turn_credentials() if get_space() else None,
|
879 |
+
concurrency_limit=5 if get_space() else None,
|
880 |
+
time_limit=90 if get_space() else None,
|
881 |
+
)
|
882 |
|
883 |
+
# Mount stream temporarily
|
884 |
+
stream.mount(app)
|
|
|
885 |
|
886 |
+
# Store stream reference
|
887 |
+
handler.stream = stream
|
888 |
+
|
889 |
+
# Process the offer
|
890 |
+
response = await stream.offer(request)
|
891 |
+
|
892 |
+
return response
|
893 |
|
|
|
|
|
894 |
|
895 |
@app.get("/outputs")
|
896 |
async def outputs(webrtc_id: str):
|
897 |
"""Stream outputs including search events"""
|
898 |
async def output_stream():
|
899 |
+
handler = handlers.get(webrtc_id)
|
900 |
+
if not handler or not hasattr(handler, 'stream'):
|
901 |
+
return
|
902 |
+
|
903 |
+
async for output in handler.stream.output_stream(webrtc_id):
|
904 |
if hasattr(output, 'args') and output.args:
|
905 |
# Check if it's a search event
|
906 |
if isinstance(output.args[0], dict) and output.args[0].get('type') == 'search':
|
|
|
926 |
|
927 |
mode = os.getenv("MODE")
|
928 |
if mode == "UI":
|
929 |
+
# For UI mode, we need a base handler and stream
|
930 |
+
base_handler = OpenAIHandler()
|
931 |
+
chatbot = gr.Chatbot(type="messages")
|
932 |
+
stream = Stream(
|
933 |
+
base_handler,
|
934 |
+
mode="send-receive",
|
935 |
+
modality="audio",
|
936 |
+
additional_inputs=[chatbot],
|
937 |
+
additional_outputs=[chatbot],
|
938 |
+
additional_outputs_handler=update_chatbot,
|
939 |
+
rtc_configuration=get_twilio_turn_credentials() if get_space() else None,
|
940 |
+
concurrency_limit=5 if get_space() else None,
|
941 |
+
time_limit=90 if get_space() else None,
|
942 |
+
)
|
943 |
+
stream.mount(app)
|
944 |
stream.ui.launch(server_port=7860)
|
945 |
elif mode == "PHONE":
|
946 |
+
# Similar for phone mode
|
947 |
+
base_handler = OpenAIHandler()
|
948 |
+
chatbot = gr.Chatbot(type="messages")
|
949 |
+
stream = Stream(
|
950 |
+
base_handler,
|
951 |
+
mode="send-receive",
|
952 |
+
modality="audio",
|
953 |
+
additional_inputs=[chatbot],
|
954 |
+
additional_outputs=[chatbot],
|
955 |
+
additional_outputs_handler=update_chatbot,
|
956 |
+
rtc_configuration=get_twilio_turn_credentials() if get_space() else None,
|
957 |
+
concurrency_limit=5 if get_space() else None,
|
958 |
+
time_limit=90 if get_space() else None,
|
959 |
+
)
|
960 |
+
stream.mount(app)
|
961 |
stream.fastphone(host="0.0.0.0", port=7860)
|
962 |
else:
|
963 |
uvicorn.run(app, host="0.0.0.0", port=7860)
|