Upload 2 files
app.py CHANGED
@@ -6,6 +6,7 @@ from atomic_agents.lib.components.system_prompt_generator import SystemPromptGen
 from atomic_agents.agents.base_agent import BaseAgent, BaseAgentConfig, BaseAgentInputSchema, BaseAgentOutputSchema
 from dotenv import load_dotenv
 import asyncio
+import httpx

 # Load environment variables
 load_dotenv()
@@ -14,6 +15,15 @@ load_dotenv()
 st.title("Math Reasoning Chatbot")
 st.write("Select a provider and chat with the bot to solve math problems!")

+# Function to check if the Ollama server is running
+async def check_ollama_health():
+    try:
+        async with httpx.AsyncClient() as client:
+            response = await client.get("http://localhost:11434/v1")
+            return response.status_code == 200
+    except httpx.RequestError:
+        return False
+
 # Function to set up the client based on the chosen provider
 def setup_client(provider):
     if provider == "openai":
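A note on the probe above: depending on the Ollama build, a bare GET to /v1 can return 404 even when the server is healthy, since only subpaths such as /v1/models are routed. If that proves to be the case, a slightly looser variant (a sketch, not part of this commit) treats any HTTP response as liveness and only a transport error as failure:

import httpx

async def check_ollama_health():
    try:
        async with httpx.AsyncClient(timeout=2.0) as client:
            # Any status code means the server answered; only a transport
            # error (connection refused, timeout) indicates it is down.
            await client.get("http://localhost:11434/")
            return True
    except httpx.RequestError:
        return False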
@@ -27,23 +37,15 @@ def setup_client(provider):
         display_model = "OpenAI (gpt-4o-mini)"
     elif provider == "ollama":
         from openai import AsyncOpenAI as OllamaClient
+        # Check if Ollama server is running
+        if not asyncio.run(check_ollama_health()):
+            st.error("Ollama server is not running or not accessible at http://localhost:11434. Please try again later or select a different provider.")
+            return None, None, None
         client = instructor.from_openai(
             OllamaClient(base_url="http://localhost:11434/v1", api_key="ollama"), mode=instructor.Mode.JSON
         )
         model = "llama3"
         display_model = "Ollama (llama3)"
-    # elif provider == "gemini":
-    #     from openai import AsyncOpenAI
-    #     api_key = os.getenv("GEMINI_API_KEY")
-    #     if not api_key:
-    #         st.error("GEMINI_API_KEY not set in environment variables.")
-    #         return None, None, None
-    #     client = instructor.from_openai(
-    #         AsyncOpenAI(api_key=api_key, base_url="https://generativelanguage.googleapis.com/v1beta/openai/"),
-    #         mode=instructor.Mode.JSON,
-    #     )
-    #     model = "gemini-2.0-flash-exp"
-    #     display_model = "Gemini (gemini-2.0-flash-exp)"
     else:
         st.error(f"Unsupported provider: {provider}")
         return None, None, None
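setup_client is synchronous, so the commit bridges into the async health check with asyncio.run. That works in a plain Streamlit script, but asyncio.run raises RuntimeError if an event loop is already running in the current thread. A synchronous probe avoids the bridge entirely; a minimal sketch, assuming the same localhost endpoint (check_ollama_health_sync is a hypothetical name, not in the commit):

import httpx

def check_ollama_health_sync(base_url="http://localhost:11434/v1"):
    try:
        # httpx.get is the synchronous counterpart of AsyncClient.get
        return httpx.get(base_url, timeout=2.0).status_code == 200
    except httpx.RequestError:
        return False

The gate in setup_client would then read `if not check_ollama_health_sync():` with no asyncio.run call at all.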
@@ -119,15 +121,18 @@ if user_input:
         response_container = st.empty()
         async def stream_response():
             current_response = ""
-            async for partial_response in st.session_state.agent.run_async(input_schema):
-                if hasattr(partial_response, "chat_message") and partial_response.chat_message:
-                    if partial_response.chat_message != current_response:
-                        current_response = partial_response.chat_message
-                        response_container.markdown(current_response)
-
-            # After streaming completes, add the final response to conversation and memory
-            st.session_state.conversation.append(("assistant", current_response))
-            st.session_state.memory.add_message("assistant", BaseAgentOutputSchema(chat_message=current_response))
+            try:
+                async for partial_response in st.session_state.agent.run_async(input_schema):
+                    if hasattr(partial_response, "chat_message") and partial_response.chat_message:
+                        if partial_response.chat_message != current_response:
+                            current_response = partial_response.chat_message
+                            response_container.markdown(current_response)
+
+                # After streaming completes, add the final response to conversation and memory
+                st.session_state.conversation.append(("assistant", current_response))
+                st.session_state.memory.add_message("assistant", BaseAgentOutputSchema(chat_message=current_response))
+            except instructor.exceptions.InstructorRetryException as e:
+                response_container.error(f"Failed to connect to the model: {str(e)}. Please try again or select a different provider.")

         # Run the async function
         asyncio.run(stream_response())
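One observation on the new handler: instructor raises InstructorRetryException only after its retry budget is exhausted, and transport failures can also surface as openai.APIConnectionError depending on where the connection drops. Catching both is cheap insurance; a sketch under the same assumptions (agent, input_schema, and response_container as in the app, helper name hypothetical):

import instructor
import openai

async def stream_response_guarded(agent, input_schema, response_container):
    current_response = ""
    try:
        async for partial_response in agent.run_async(input_schema):
            message = getattr(partial_response, "chat_message", None)
            if message and message != current_response:
                current_response = message
                response_container.markdown(current_response)
        return current_response
    except (instructor.exceptions.InstructorRetryException, openai.APIConnectionError) as e:
        response_container.error(f"Failed to reach the model: {e}. Please try again or select a different provider.")
        return None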