import os
import asyncio
import subprocess

import httpx
import instructor
import streamlit as st
from dotenv import load_dotenv

from atomic_agents.lib.components.agent_memory import AgentMemory
from atomic_agents.lib.components.system_prompt_generator import SystemPromptGenerator
from atomic_agents.agents.base_agent import (
    BaseAgent,
    BaseAgentConfig,
    BaseAgentInputSchema,
    BaseAgentOutputSchema,
)

# Load environment variables
load_dotenv()

# Initialize the Streamlit app
st.title("Math Reasoning Chatbot")
st.write("Select a provider and chat with the bot to solve math problems!")


# Start the Ollama server if it is not already running
def start_ollama_server():
    try:
        # Check whether Ollama is already running
        response = httpx.get("http://localhost:11434/v1")
        if response.status_code == 200:
            return True
    except httpx.RequestError:
        # Start the Ollama server in the background
        subprocess.Popen(["ollama", "serve"], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    return False


# Check whether the Ollama server is reachable, retrying a few times
async def check_ollama_health(max_retries=5, retry_delay=5):
    for attempt in range(max_retries):
        try:
            async with httpx.AsyncClient() as client:
                response = await client.get("http://localhost:11434/v1")
                if response.status_code == 200:
                    return True
        except httpx.RequestError:
            if attempt < max_retries - 1:
                st.warning(
                    f"Ollama server not yet available (attempt {attempt + 1}/{max_retries}). "
                    f"Retrying in {retry_delay} seconds..."
                )
                await asyncio.sleep(retry_delay)
                continue
    return False


# Set up the client based on the chosen provider
def setup_client(provider):
    if provider == "openai":
        from openai import AsyncOpenAI

        api_key = os.getenv("OPENAI_API_KEY")
        if not api_key:
            st.error("OPENAI_API_KEY not set in environment variables.")
            return None, None, None
        client = instructor.from_openai(AsyncOpenAI(api_key=api_key))
        model = "gpt-4o-mini"
        display_model = "OpenAI (gpt-4o-mini)"
    elif provider == "ollama":
        from openai import AsyncOpenAI as OllamaClient

        # Start the Ollama server if it is not running, then wait for it to become healthy
        if not start_ollama_server():
            if not asyncio.run(check_ollama_health()):
                st.error(
                    "Failed to start the Ollama server or connect to it at "
                    "http://localhost:11434 after multiple attempts."
                )
                return None, None, None
        client = instructor.from_openai(
            OllamaClient(base_url="http://localhost:11434/v1", api_key="ollama"),
            mode=instructor.Mode.JSON,
        )
        model = "llama3.2:1b"
        display_model = "Ollama (llama3.2:1b)"
    else:
        st.error(f"Unsupported provider: {provider}")
        return None, None, None
    return client, model, display_model


# Custom system prompt
system_prompt_generator = SystemPromptGenerator(
    background=["You are a math genius."],
    steps=["Think logically step by step and solve a math problem."],
    output_instructions=[
        "Summarise your lengthy thinking process into the problems you worked through and their solutions, numbered in thinking order. Do not describe every step of the process.",
        "Answer in plain English plus formulas.",
        "Always respond using the proper JSON schema.",
        "Always use the available additional information and context to enhance the response.",
    ],
)

# Provider selection
providers_list = ["openai", "ollama"]
selected_provider = st.selectbox("Choose a provider:", providers_list, key="provider_select")

# Set up the client and agent for the selected provider
client, model, display_model = setup_client(selected_provider)
if client is None:
    st.stop()

# Initialize or update the agent
st.session_state.display_model = display_model
if "agent" not in st.session_state or st.session_state.get("current_model") != model:
    if "memory" not in st.session_state:
        st.session_state.memory = AgentMemory()
        initial_message = BaseAgentOutputSchema(
            chat_message="Hello! I'm here to help with math problems. What can I assist you with today?"
        )
        st.session_state.memory.add_message("assistant", initial_message)
        st.session_state.conversation = [("assistant", initial_message.chat_message)]
    st.session_state.agent = BaseAgent(config=BaseAgentConfig(
        client=client,
        model=model,
        system_prompt_generator=system_prompt_generator,
        memory=st.session_state.memory,
        system_role="developer",
    ))
    st.session_state.current_model = model  # Track the current model to detect changes

# Display the selected model
st.markdown(f"**Selected Model:** {st.session_state.display_model}")

# Display the system prompt in an expander
with st.expander("View System Prompt"):
    system_prompt = system_prompt_generator.generate_prompt()
    st.text(system_prompt)

# Display the conversation history using st.chat_message
for role, message in st.session_state.conversation:
    with st.chat_message(role):
        st.markdown(message)

# User input using st.chat_input
user_input = st.chat_input(placeholder="e.g., x^4 + a^4 = 0 find cf")

# Process the input and stream the response
if user_input:
    # Add the user message to the conversation and memory
    st.session_state.conversation.append(("user", user_input))
    input_schema = BaseAgentInputSchema(chat_message=user_input)
    st.session_state.memory.add_message("user", input_schema)

    # Display the user message immediately
    with st.chat_message("user"):
        st.markdown(user_input)

    # Stream the response
    with st.chat_message("assistant"):
        response_container = st.empty()

        async def stream_response():
            current_response = ""
            try:
                async for partial_response in st.session_state.agent.run_async(input_schema):
                    if hasattr(partial_response, "chat_message") and partial_response.chat_message:
                        if partial_response.chat_message != current_response:
                            current_response = partial_response.chat_message
                            response_container.markdown(current_response)
                # After streaming completes, add the final response to the conversation and memory
                st.session_state.conversation.append(("assistant", current_response))
                st.session_state.memory.add_message(
                    "assistant", BaseAgentOutputSchema(chat_message=current_response)
                )
            except instructor.exceptions.InstructorRetryException as e:
                response_container.error(
                    f"Failed to connect to the model: {str(e)}. "
                    "Please try again or select a different provider."
                )

        # Run the async function
        asyncio.run(stream_response())
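
# A minimal way to try this locally (assumptions: the file is saved as app.py, and the
# imported packages are installed, i.e. streamlit, instructor, atomic-agents, openai,
# httpx, python-dotenv): put OPENAI_API_KEY in a .env file, or install Ollama for the
# local provider, then launch the app with Streamlit's CLI:
#
#     streamlit run app.py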