import os
import streamlit as st
import instructor
from atomic_agents.lib.components.agent_memory import AgentMemory
from atomic_agents.lib.components.system_prompt_generator import SystemPromptGenerator
from atomic_agents.agents.base_agent import BaseAgent, BaseAgentConfig, BaseAgentInputSchema, BaseAgentOutputSchema
from dotenv import load_dotenv
import asyncio
import logging
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Load environment variables (optional; Hugging Face Secrets are injected as environment variables)
load_dotenv()
# Initialize Streamlit app
st.title("Math Reasoning Chatbot")
st.write("Select a provider and chat with the bot to solve math problems!")
# Function to set up the client based on the chosen provider
def setup_client(provider):
if provider == "openai":
from openai import AsyncOpenAI
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
st.warning("OpenAI unavailable: OPENAI_API_KEY not set. Using Ollama.")
return setup_client("ollama")
client = instructor.from_openai(AsyncOpenAI(api_key=api_key))
model = "gpt-4o-mini"
display_model = "OpenAI (gpt-4o-mini)"
elif provider == "ollama":
from openai import AsyncOpenAI as OllamaClient
try:
client = instructor.from_openai(
OllamaClient(base_url="http://localhost:11434/v1", api_key="ollama"), mode=instructor.Mode.JSON
)
model = "llama3"
display_model = "Ollama (llama3)"
logger.info("Ollama client initialized successfully")
except Exception as e:
logger.error(f"Failed to initialize Ollama client: {e}")
st.error(f"Ollama connection failed: {e}")
return None, None, None
else:
st.error(f"Unsupported provider: {provider}")
return None, None, None
return client, model, display_model
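# Note: constructing the OpenAI-compatible client does not open a network
# connection, so the try/except above mainly catches import and configuration
# errors; an unreachable Ollama server typically fails later, at request time.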
# Custom system prompt
system_prompt_generator = SystemPromptGenerator(
background=["You are a math genius."],
steps=["Think logically step by step and solve a math problem."],
output_instructions=[
"Summarise your lengthy thinking processes into experienced problems and solutions with thinking order numbers. Do not speak of all the processes.",
"Answer in plain English plus formulas.",
"Always respond using the proper JSON schema.",
"Always use the available additional information and context to enhance the response.",
],
)
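# SystemPromptGenerator renders background, steps, and output_instructions into
# a single system prompt; the generated text can be inspected in the
# "View System Prompt" expander below.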
# Provider selection
providers_list = ["ollama", "openai"]
selected_provider = st.selectbox("Choose a provider:", providers_list, key="provider_select")
# Set up client and agent based on the selected provider
client, model, display_model = setup_client(selected_provider)
if client is None:
    st.stop()
# Initialize or update the agent
st.session_state.display_model = display_model
if "agent" not in st.session_state or st.session_state.get("current_model") != model:
if "memory" not in st.session_state:
st.session_state.memory = AgentMemory()
initial_message = BaseAgentOutputSchema(chat_message="Hello! I'm here to help with math problems. What can I assist you with today?")
st.session_state.memory.add_message("assistant", initial_message)
st.session_state.conversation = [("assistant", initial_message.chat_message)]
st.session_state.agent = BaseAgent(config=BaseAgentConfig(
client=client,
model=model,
system_prompt_generator=system_prompt_generator,
memory=st.session_state.memory,
system_role="developer",
))
st.session_state.current_model = model
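# Design note: the AgentMemory instance is created once and reused when the
# provider or model changes, so conversation context survives switching backends.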
# Display the selected model
st.markdown(f"**Selected Model:** {st.session_state.display_model}")
# Display the system prompt in an expander
with st.expander("View System Prompt"):
    system_prompt = system_prompt_generator.generate_prompt()
    st.text(system_prompt)
# Display conversation history using st.chat_message
for role, message in st.session_state.conversation:
    with st.chat_message(role):
        st.markdown(message)
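# Streamlit reruns this script top to bottom on every interaction, so the chat
# history is replayed from st.session_state on each run.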
# User input using st.chat_input
user_input = st.chat_input(placeholder="e.g., x^4 + a^4 = 0 find cf")
# Process the input and stream the response
if user_input:
    # Add user message to conversation and memory
    st.session_state.conversation.append(("user", user_input))
    input_schema = BaseAgentInputSchema(chat_message=user_input)
    st.session_state.memory.add_message("user", input_schema)
    # Display user message immediately
    with st.chat_message("user"):
        st.markdown(user_input)
    # Stream the response
    with st.chat_message("assistant"):
        response_container = st.empty()
        async def stream_response():
            current_response = ""
            try:
                async for partial_response in st.session_state.agent.run_async(input_schema):
                    if hasattr(partial_response, "chat_message") and partial_response.chat_message:
                        if partial_response.chat_message != current_response:
                            current_response = partial_response.chat_message
                            response_container.markdown(current_response)
            except Exception as e:
                logger.error(f"Error streaming response: {e}")
                response_container.error(f"Error: {e}")
            # After streaming completes, add the final response to conversation and memory
            st.session_state.conversation.append(("assistant", current_response))
            st.session_state.memory.add_message("assistant", BaseAgentOutputSchema(chat_message=current_response))
        # Run the async function
        asyncio.run(stream_response())
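# Usage sketch (assumes the file is saved as app.py; adjust to your filename):
#   ollama pull llama3            # local Ollama backend, served on :11434
#   export OPENAI_API_KEY=sk-...  # optional, enables the OpenAI backend
#   streamlit run app.py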