|
import os |
|
import streamlit as st |
|
import instructor |
|
from atomic_agents.lib.components.agent_memory import AgentMemory |
|
from atomic_agents.lib.components.system_prompt_generator import SystemPromptGenerator |
|
from atomic_agents.agents.base_agent import BaseAgent, BaseAgentConfig, BaseAgentInputSchema, BaseAgentOutputSchema |
|
from dotenv import load_dotenv |
|
import asyncio |
|
import logging |
|
|
|
|
|
# --- Logging ---------------------------------------------------------------
# Configure the root logger at INFO so this module's logger output is emitted.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- Environment -----------------------------------------------------------
# Load variables from a local .env file so os.getenv() can see settings such
# as OPENAI_API_KEY.
load_dotenv()

# --- Page header -----------------------------------------------------------
st.title("Math Reasoning Chatbot")
st.write("Select a provider and chat with the bot to solve math problems!")
|
|
|
|
|
def setup_client(provider):
    """Build an instructor-wrapped async chat client for ``provider``.

    Args:
        provider: ``"openai"`` or ``"ollama"``. Choosing ``"openai"`` while
            ``OPENAI_API_KEY`` is unset shows a UI warning and falls back to
            Ollama.

    Returns:
        ``(client, model, display_model)`` on success. Returns
        ``(None, None, None)`` when the provider is unsupported or the Ollama
        client cannot be constructed; an error is shown in the UI in both cases.

    Optional environment overrides (defaults in parentheses):
        ``OPENAI_MODEL`` (gpt-4o-mini), ``OLLAMA_BASE_URL``
        (http://localhost:11434/v1), ``OLLAMA_MODEL`` (llama3).
    """
    # Both providers are reached through the OpenAI SDK's client class.
    from openai import AsyncOpenAI

    if provider == "openai":
        api_key = os.getenv("OPENAI_API_KEY")
        if not api_key:
            st.warning("OpenAI unavailable: OPENAI_API_KEY not set. Using Ollama.")
            return setup_client("ollama")
        model = os.getenv("OPENAI_MODEL", "gpt-4o-mini")
        client = instructor.from_openai(AsyncOpenAI(api_key=api_key))
        return client, model, f"OpenAI ({model})"

    if provider == "ollama":
        base_url = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434/v1")
        model = os.getenv("OLLAMA_MODEL", "llama3")
        try:
            # Building the client does not contact the server, so an unreachable
            # Ollama instance only shows up on the first request.
            client = instructor.from_openai(
                # The SDK requires an api_key; Ollama does not use it.
                AsyncOpenAI(base_url=base_url, api_key="ollama"),
                # JSON mode instead of instructor's default tool-calling mode
                # (presumably because local models may lack tool support).
                mode=instructor.Mode.JSON,
            )
        except Exception as e:
            logger.error(f"Failed to initialize Ollama client: {e}")
            st.error(f"Ollama connection failed: {e}")
            return None, None, None
        logger.info("Ollama client initialized successfully")
        return client, model, f"Ollama ({model})"

    st.error(f"Unsupported provider: {provider}")
    return None, None, None
|
|
|
|
|
# System prompt for the math agent: persona, solving procedure and answer format.
# SystemPromptGenerator itself appends "Always respond using the proper JSON
# schema." and "Always use the available additional information and context to
# enhance the response." to output_instructions. Listing them here as well made
# both lines appear twice in the generated prompt.
system_prompt_generator = SystemPromptGenerator(
    background=["You are a math genius."],
    steps=["Think logically step by step and solve a math problem."],
    output_instructions=[
        "Summarise your lengthy thinking processes into experienced problems and solutions with thinking order numbers. Do not speak of all the processes.",
        "Answer in plain English plus formulas.",
    ],
)
|
|
|
|
|
# Provider picker; the first entry is the dropdown's initial selection.
providers_list = ["ollama", "openai"]
selected_provider = st.selectbox(
    label="Choose a provider:",
    options=providers_list,
    key="provider_select",
)

client, model, display_model = setup_client(selected_provider)

# setup_client reports failure as (None, None, None) after showing an error,
# so halt this script run instead of building an agent without a client.
if client is None:
    st.stop()
|
|
|
|
|
st.session_state.display_model = display_model

# Agent memory and the on-screen transcript are created together, once per
# browser session, so they always describe the same history. Switching
# providers neither re-adds the greeting to memory nor clears a transcript that
# memory still holds.
if "memory" not in st.session_state:
    st.session_state.memory = AgentMemory()
    initial_message = BaseAgentOutputSchema(
        chat_message="Hello! I'm here to help with math problems. What can I assist you with today?"
    )
    st.session_state.memory.add_message("assistant", initial_message)
    st.session_state.conversation = [("assistant", initial_message.chat_message)]

# Rebuild the (cheap) agent on every rerun around the client that setup_client()
# just created. Keeping the client from an earlier rerun would reuse an async
# HTTP client across the separate event loops that each asyncio.run() creates,
# which can fail (e.g. "Event loop is closed"). Chat history is unaffected
# because every agent shares the same session memory.
st.session_state.agent = BaseAgent(
    config=BaseAgentConfig(
        client=client,
        model=model,
        system_prompt_generator=system_prompt_generator,
        memory=st.session_state.memory,
        # NOTE(review): "developer" is an OpenAI role name; confirm the Ollama
        # endpoint / llama3 template treats it like "system".
        system_role="developer",
    )
)
# Still recorded for any code that reads the active model from session state.
st.session_state.current_model = model
|
|
|
|
|
# Label for the model actually in use. It can differ from the dropdown when the
# OpenAI choice falls back to Ollama.
st.markdown(f"**Selected Model:** {st.session_state.display_model}")

# Collapsible, read-only view of the generated system prompt.
with st.expander("View System Prompt"):
    system_prompt = system_prompt_generator.generate_prompt()
    st.text(system_prompt)

# Streamlit re-executes the script on every interaction, so the stored
# transcript is replayed into chat bubbles on each run.
for speaker, text in st.session_state.conversation:
    st.chat_message(speaker).markdown(text)

user_input = st.chat_input(placeholder="e.g., x^4 + a^4 = 0 find cf")
|
|
|
|
|
if user_input:
    # Only the on-screen transcript is updated by hand here. BaseAgent.run_async
    # adds the user turn to the agent memory itself and, once streaming
    # finishes, the final assistant reply. Adding them here as well stored every
    # message twice and sent the duplicates back to the model.
    st.session_state.conversation.append(("user", user_input))
    input_schema = BaseAgentInputSchema(chat_message=user_input)

    with st.chat_message("user"):
        st.markdown(user_input)

    with st.chat_message("assistant"):
        response_container = st.empty()

        async def stream_response():
            """Stream the agent's reply into ``response_container``.

            The placeholder is re-rendered whenever the partial ``chat_message``
            changes. On success the final text is appended to the transcript.
            On failure the error is logged and shown in place of the reply, and
            no partial or empty reply is recorded.
            """
            current_response = ""
            try:
                async for partial_response in st.session_state.agent.run_async(input_schema):
                    latest = getattr(partial_response, "chat_message", None)
                    if latest and latest != current_response:
                        current_response = latest
                        response_container.markdown(current_response)
            except Exception as e:
                # UI boundary: report the failure instead of crashing the page.
                logger.exception(f"Error streaming response: {e}")
                response_container.error(f"Error: {e}")
                return

            st.session_state.conversation.append(("assistant", current_response))

        asyncio.run(stream_response())