Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -6,7 +6,7 @@ import os
|
|
6 |
import numpy as np
|
7 |
import openai
|
8 |
from dotenv import load_dotenv
|
9 |
-
from fastapi import FastAPI
|
10 |
from fastapi.responses import HTMLResponse, StreamingResponse
|
11 |
from fastrtc import (
|
12 |
AdditionalOutputs,
|
@@ -672,6 +672,8 @@ class BraveSearchClient:
|
|
672 |
|
673 |
# Global state for web search settings
|
674 |
web_search_settings = {}
|
|
|
|
|
675 |
|
676 |
|
677 |
class OpenAIHandler(AsyncStreamHandler):
|
@@ -684,15 +686,21 @@ class OpenAIHandler(AsyncStreamHandler):
|
|
684 |
)
|
685 |
self.connection = None
|
686 |
self.output_queue = asyncio.Queue()
|
687 |
-
self.search_client =
|
688 |
self.function_call_in_progress = False
|
689 |
self.current_function_args = ""
|
690 |
self.current_call_id = None
|
691 |
self.webrtc_id = None
|
692 |
self.web_search_enabled = False
|
|
|
693 |
|
694 |
def copy(self):
|
695 |
-
|
|
|
|
|
|
|
|
|
|
|
696 |
|
697 |
async def search_web(self, query: str) -> str:
|
698 |
"""Perform web search and return formatted results"""
|
@@ -717,9 +725,10 @@ class OpenAIHandler(AsyncStreamHandler):
|
|
717 |
|
718 |
async def start_up(self):
|
719 |
"""Connect to realtime API with function calling enabled"""
|
720 |
-
#
|
721 |
if self.webrtc_id and self.webrtc_id in web_search_settings:
|
722 |
self.web_search_enabled = web_search_settings[self.webrtc_id]
|
|
|
723 |
|
724 |
print(f"Starting up handler with web_search_enabled={self.web_search_enabled}")
|
725 |
self.client = openai.AsyncOpenAI()
|
@@ -731,7 +740,7 @@ class OpenAIHandler(AsyncStreamHandler):
|
|
731 |
"type": "function",
|
732 |
"function": {
|
733 |
"name": "web_search",
|
734 |
-
"description": "Search the web for information when the user asks questions about current events, news, or topics that require up-to-date information",
|
735 |
"parameters": {
|
736 |
"type": "object",
|
737 |
"properties": {
|
@@ -760,15 +769,16 @@ class OpenAIHandler(AsyncStreamHandler):
|
|
760 |
if self.web_search_enabled:
|
761 |
session_update["instructions"] = (
|
762 |
"You are a helpful assistant. When users ask about current events, "
|
763 |
-
"news, recent information, or topics that might have changed recently, "
|
764 |
"use the web_search function to find the most up-to-date information. "
|
765 |
-
"Always search when the user asks about: weather, news, current
|
766 |
-
"recent
|
767 |
)
|
768 |
|
769 |
await conn.session.update(session=session_update)
|
770 |
self.connection = conn
|
771 |
print(f"Connected with tools: {len(tools)} functions")
|
|
|
772 |
|
773 |
async for event in self.connection:
|
774 |
if event.type == "response.audio_transcript.done":
|
@@ -876,47 +886,46 @@ stream = Stream(
|
|
876 |
)
|
877 |
|
878 |
app = FastAPI()
|
|
|
|
|
879 |
stream.mount(app)
|
880 |
|
881 |
-
#
|
882 |
-
|
|
|
|
|
883 |
|
884 |
-
|
885 |
-
|
886 |
-
|
887 |
-
|
888 |
-
|
889 |
-
|
890 |
-
|
891 |
-
|
892 |
-
|
893 |
-
|
894 |
-
|
895 |
-
|
896 |
-
|
897 |
-
|
898 |
-
|
899 |
-
|
900 |
-
|
901 |
-
|
902 |
-
|
903 |
-
|
904 |
-
|
905 |
-
|
906 |
-
|
907 |
-
|
908 |
-
|
909 |
-
|
910 |
-
|
911 |
-
|
912 |
-
await handler.shutdown()
|
913 |
-
await handler.start_up()
|
914 |
-
|
915 |
-
# Call original offer handler
|
916 |
-
return await original_offer_handler(request)
|
917 |
|
918 |
-
#
|
919 |
-
|
920 |
|
921 |
@app.get("/outputs")
|
922 |
async def outputs(webrtc_id: str):
|
|
|
6 |
import numpy as np
|
7 |
import openai
|
8 |
from dotenv import load_dotenv
|
9 |
+
from fastapi import FastAPI, Request
|
10 |
from fastapi.responses import HTMLResponse, StreamingResponse
|
11 |
from fastrtc import (
|
12 |
AdditionalOutputs,
|
|
|
672 |
|
673 |
# Global state for web search settings
|
674 |
web_search_settings = {}
|
675 |
+
current_webrtc_id = None
|
676 |
+
current_web_search_enabled = False
|
677 |
|
678 |
|
679 |
class OpenAIHandler(AsyncStreamHandler):
|
|
|
686 |
)
|
687 |
self.connection = None
|
688 |
self.output_queue = asyncio.Queue()
|
689 |
+
self.search_client = search_client # Use global search client
|
690 |
self.function_call_in_progress = False
|
691 |
self.current_function_args = ""
|
692 |
self.current_call_id = None
|
693 |
self.webrtc_id = None
|
694 |
self.web_search_enabled = False
|
695 |
+
self.startup_called = False
|
696 |
|
697 |
def copy(self):
|
698 |
+
# Create a new instance that inherits settings
|
699 |
+
new_handler = OpenAIHandler()
|
700 |
+
new_handler.search_client = self.search_client
|
701 |
+
new_handler.webrtc_id = current_webrtc_id # Use global variable
|
702 |
+
new_handler.web_search_enabled = current_web_search_enabled # Use global variable
|
703 |
+
return new_handler
|
704 |
|
705 |
async def search_web(self, query: str) -> str:
|
706 |
"""Perform web search and return formatted results"""
|
|
|
725 |
|
726 |
async def start_up(self):
|
727 |
"""Connect to realtime API with function calling enabled"""
|
728 |
+
# Check web search setting before connecting
|
729 |
if self.webrtc_id and self.webrtc_id in web_search_settings:
|
730 |
self.web_search_enabled = web_search_settings[self.webrtc_id]
|
731 |
+
print(f"Retrieved web_search_enabled={self.web_search_enabled} for webrtc_id={self.webrtc_id}")
|
732 |
|
733 |
print(f"Starting up handler with web_search_enabled={self.web_search_enabled}")
|
734 |
self.client = openai.AsyncOpenAI()
|
|
|
740 |
"type": "function",
|
741 |
"function": {
|
742 |
"name": "web_search",
|
743 |
+
"description": "Search the web for information when the user asks questions about current events, news, weather, prices, or topics that require up-to-date information",
|
744 |
"parameters": {
|
745 |
"type": "object",
|
746 |
"properties": {
|
|
|
769 |
if self.web_search_enabled:
|
770 |
session_update["instructions"] = (
|
771 |
"You are a helpful assistant. When users ask about current events, "
|
772 |
+
"news, recent information, weather, or topics that might have changed recently, "
|
773 |
"use the web_search function to find the most up-to-date information. "
|
774 |
+
"Always search when the user asks about: 날씨(weather), 뉴스(news), 현재(current), "
|
775 |
+
"최근(recent), 최신(latest), 오늘(today), or any time-sensitive information."
|
776 |
)
|
777 |
|
778 |
await conn.session.update(session=session_update)
|
779 |
self.connection = conn
|
780 |
print(f"Connected with tools: {len(tools)} functions")
|
781 |
+
self.startup_called = True
|
782 |
|
783 |
async for event in self.connection:
|
784 |
if event.type == "response.audio_transcript.done":
|
|
|
886 |
)
|
887 |
|
888 |
app = FastAPI()
|
889 |
+
|
890 |
+
# Mount the stream
|
891 |
stream.mount(app)
|
892 |
|
893 |
+
# Intercept the POST body before it reaches the stream
|
894 |
+
from starlette.middleware.base import BaseHTTPMiddleware
|
895 |
+
from starlette.requests import Request as StarletteRequest
|
896 |
+
import io
|
897 |
|
898 |
+
class WebSearchMiddleware(BaseHTTPMiddleware):
|
899 |
+
async def dispatch(self, request: StarletteRequest, call_next):
|
900 |
+
if request.url.path == "/webrtc/offer" and request.method == "POST":
|
901 |
+
# Read the body
|
902 |
+
body = await request.body()
|
903 |
+
|
904 |
+
# Parse it
|
905 |
+
try:
|
906 |
+
data = json.loads(body)
|
907 |
+
webrtc_id = data.get("webrtc_id")
|
908 |
+
web_search_enabled = data.get("web_search_enabled", False)
|
909 |
+
|
910 |
+
print(f"Middleware intercepted - webrtc_id: {webrtc_id}, web_search_enabled: {web_search_enabled}")
|
911 |
+
|
912 |
+
# Store settings globally
|
913 |
+
global current_webrtc_id, current_web_search_enabled
|
914 |
+
if webrtc_id:
|
915 |
+
current_webrtc_id = webrtc_id
|
916 |
+
current_web_search_enabled = web_search_enabled
|
917 |
+
web_search_settings[webrtc_id] = web_search_enabled
|
918 |
+
except:
|
919 |
+
pass
|
920 |
+
|
921 |
+
# Recreate the request with the body
|
922 |
+
request._body = body
|
923 |
+
|
924 |
+
response = await call_next(request)
|
925 |
+
return response
|
|
|
|
|
|
|
|
|
|
|
926 |
|
927 |
+
# Add middleware
|
928 |
+
app.add_middleware(WebSearchMiddleware)
|
929 |
|
930 |
@app.get("/outputs")
|
931 |
async def outputs(webrtc_id: str):
|