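"""Streamlit chat app that solves math problems with an atomic-agents BaseAgent.

Supports OpenAI (gpt-4o-mini) or a local Ollama model (llama3.2:1b) as the
backing provider and streams the assistant's partial responses into the UI.
"""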
import os
import streamlit as st
import instructor
from atomic_agents.lib.components.agent_memory import AgentMemory
from atomic_agents.lib.components.system_prompt_generator import SystemPromptGenerator
from atomic_agents.agents.base_agent import BaseAgent, BaseAgentConfig, BaseAgentInputSchema, BaseAgentOutputSchema
from dotenv import load_dotenv
import asyncio
import httpx
import subprocess

# Load environment variables
load_dotenv()

# Initialize Streamlit app
st.title("Math Reasoning Chatbot")
st.write("Select a provider and chat with the bot to solve math problems!")

# Function to start the Ollama server if it is not already running
def start_ollama_server():
    try:
        # Ollama's root endpoint answers 200 ("Ollama is running") when the server is up
        response = httpx.get("http://localhost:11434")
        if response.status_code == 200:
            return True
    except httpx.RequestError:
        pass
    # Not reachable yet: launch the server in the background and let the caller poll for readiness
    subprocess.Popen(["ollama", "serve"], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    return False

# Function to check if the Ollama server is running with retries
async def check_ollama_health(max_retries=5, retry_delay=5):
    for attempt in range(max_retries):
        try:
            async with httpx.AsyncClient() as client:
                response = await client.get("http://localhost:11434")
                if response.status_code == 200:
                    return True
        except httpx.RequestError:
            if attempt < max_retries - 1:
                st.warning(f"Ollama server not yet available (attempt {attempt + 1}/{max_retries}). Retrying in {retry_delay} seconds...")
                await asyncio.sleep(retry_delay)
            continue
    return False

# Function to set up the client based on the chosen provider
def setup_client(provider):
    if provider == "openai":
        from openai import AsyncOpenAI
        api_key = os.getenv("OPENAI_API_KEY")
        if not api_key:
            st.error("OPENAI_API_KEY not set in environment variables.")
            return None, None, None
        client = instructor.from_openai(AsyncOpenAI(api_key=api_key))
        model = "gpt-4o-mini"
        display_model = "OpenAI (gpt-4o-mini)"
    elif provider == "ollama":
        from openai import AsyncOpenAI as OllamaClient
        # Start Ollama server if not running
        if not start_ollama_server():
            # Wait and check health
            if not asyncio.run(check_ollama_health()):
                st.error("Failed to start Ollama server or connect to it at http://localhost:11434 after multiple attempts.")
                return None, None, None
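        # instructor's JSON mode has the local model emit the schema as plain JSON
        # instead of relying on OpenAI-style tool calling, which Ollama may not support reliably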
        client = instructor.from_openai(
            OllamaClient(base_url="http://localhost:11434/v1", api_key="ollama"), mode=instructor.Mode.JSON
        )
        model = "llama3.2:1b"
        display_model = "Ollama (llama3.2:1b)"
    else:
        st.error(f"Unsupported provider: {provider}")
        return None, None, None
    return client, model, display_model

# Custom system prompt
system_prompt_generator = SystemPromptGenerator(
    background=["You are a math genius."],
    steps=["Think logically step by step and solve a math problem."],
    output_instructions=[
        "Summarise your lengthy thinking processes into experienced problems and solutions with thinking order numbers. Do not speak of all the processes.",
        "Answer in plain English plus formulas.",
        "Always respond using the proper JSON schema.",
        "Always use the available additional information and context to enhance the response.",
    ],
)

# Provider selection
providers_list = ["openai", "ollama"]
selected_provider = st.selectbox("Choose a provider:", providers_list, key="provider_select")

# Set up client and agent based on the selected provider
client, model, display_model = setup_client(selected_provider)
if client is None:
    st.stop()

# Initialize or update the agent
st.session_state.display_model = display_model
if "agent" not in st.session_state or st.session_state.get("current_model") != model:
    if "memory" not in st.session_state:
        st.session_state.memory = AgentMemory()
        initial_message = BaseAgentOutputSchema(chat_message="Hello! I'm here to help with math problems. What can I assist you with today?")
        st.session_state.memory.add_message("assistant", initial_message)
        st.session_state.conversation = [("assistant", initial_message.chat_message)]
    st.session_state.agent = BaseAgent(config=BaseAgentConfig(
        client=client,
        model=model,
        system_prompt_generator=system_prompt_generator,
        memory=st.session_state.memory,
        system_role="developer",
    ))
    st.session_state.current_model = model  # Track the current model to detect changes

# Display the selected model
st.markdown(f"**Selected Model:** {st.session_state.display_model}")

# Display the system prompt in an expander
with st.expander("View System Prompt"):
    system_prompt = system_prompt_generator.generate_prompt()
    st.text(system_prompt)

# Display conversation history using st.chat_message
for role, message in st.session_state.conversation:
    with st.chat_message(role):
        st.markdown(message)

# User input using st.chat_input
user_input = st.chat_input(placeholder="e.g., x^4 + a^4 = 0 find cf")

# Process the input and stream the response
if user_input:
    # Add user message to conversation and memory
    st.session_state.conversation.append(("user", user_input))
    input_schema = BaseAgentInputSchema(chat_message=user_input)
    st.session_state.memory.add_message("user", input_schema)

    # Display user message immediately
    with st.chat_message("user"):
        st.markdown(user_input)

    # Stream the response
    with st.chat_message("assistant"):
        response_container = st.empty()
        async def stream_response():
            current_response = ""
            try:
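                # run_async yields progressively more complete partial responses;
                # re-render the placeholder only when the text actually changes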
                async for partial_response in st.session_state.agent.run_async(input_schema):
                    if hasattr(partial_response, "chat_message") and partial_response.chat_message:
                        if partial_response.chat_message != current_response:
                            current_response = partial_response.chat_message
                            response_container.markdown(current_response)

                # After streaming completes, add the final response to conversation and memory
                st.session_state.conversation.append(("assistant", current_response))
                st.session_state.memory.add_message("assistant", BaseAgentOutputSchema(chat_message=current_response))
            except instructor.exceptions.InstructorRetryException as e:
                response_container.error(f"Failed to connect to the model: {str(e)}. Please try again or select a different provider.")

        # Run the async function
        asyncio.run(stream_response())