Spaces:
Build error
Build error
Add configuration, graph, runner, and tools modules to enhance agent functionality. Introduce a Configuration class for managing parameters, implement an AgentRunner for executing the agent graph, and create tools for general search and mathematical calculations. Update test_agent.py to reflect new import paths and improve overall code organization.
Browse files- services/configuration.py β configuration.py +0 -0
- services/graph.py β graph.py +52 -48
- api/runner.py β runner.py +59 -3
- test_agent.py +1 -1
- services/tools.py β tools.py +30 -0
services/configuration.py β configuration.py
RENAMED
File without changes
|
services/graph.py β graph.py
RENAMED
@@ -6,14 +6,14 @@ from datetime import datetime
|
|
6 |
from typing import Dict, List, Optional, TypedDict, Union
|
7 |
|
8 |
import yaml
|
9 |
-
from services.configuration import Configuration
|
10 |
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
|
11 |
from langchain_core.runnables import RunnableConfig
|
12 |
from langgraph.graph import END, StateGraph
|
13 |
from langgraph.types import interrupt
|
14 |
from smolagents import CodeAgent, LiteLLMModel
|
15 |
|
16 |
-
from
|
|
|
17 |
|
18 |
# Configure logging
|
19 |
logging.basicConfig(level=logging.INFO)
|
@@ -33,7 +33,7 @@ else:
|
|
33 |
litellm.drop_params = True
|
34 |
|
35 |
# Load default prompt templates from local file
|
36 |
-
current_dir = os.path.dirname(os.path.
|
37 |
prompts_dir = os.path.join(current_dir, "prompts")
|
38 |
yaml_path = os.path.join(prompts_dir, "code_agent.yaml")
|
39 |
|
@@ -150,48 +150,50 @@ class AgentNode:
|
|
150 |
class StepCallbackNode:
|
151 |
"""Node that handles step callbacks and user interaction."""
|
152 |
|
153 |
-
def
|
154 |
-
self
|
155 |
-
) -> AgentState:
|
156 |
-
"""Handle step callback and user interaction."""
|
157 |
-
# Get configuration
|
158 |
-
cfg = Configuration.from_runnable_config(config)
|
159 |
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
"answer": state["answer"],
|
166 |
-
}
|
167 |
-
state["step_logs"].append(step_log)
|
168 |
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
)
|
174 |
-
user_input = interrupt_result[0] # Get the actual user input
|
175 |
|
176 |
-
if
|
|
|
177 |
state["is_complete"] = True
|
178 |
return state
|
179 |
-
elif
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
return state
|
184 |
-
elif
|
185 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
186 |
return state
|
187 |
else:
|
188 |
-
|
189 |
-
return state
|
190 |
-
|
191 |
-
except Exception as e:
|
192 |
-
logger.warning(f"Error during interrupt: {str(e)}")
|
193 |
-
# Continue without marking as complete
|
194 |
-
return state
|
195 |
|
196 |
|
197 |
def build_agent_graph(agent: AgentNode) -> StateGraph:
|
@@ -201,7 +203,7 @@ def build_agent_graph(agent: AgentNode) -> StateGraph:
|
|
201 |
|
202 |
# Add nodes
|
203 |
workflow.add_node("agent", agent)
|
204 |
-
workflow.add_node("callback", StepCallbackNode())
|
205 |
|
206 |
# Add edges
|
207 |
workflow.add_edge("agent", "callback")
|
@@ -209,22 +211,24 @@ def build_agent_graph(agent: AgentNode) -> StateGraph:
|
|
209 |
# Add conditional edges for callback
|
210 |
def should_continue(state: AgentState) -> str:
|
211 |
"""Determine the next node based on state."""
|
212 |
-
# If we have no answer, continue
|
213 |
if not state["answer"]:
|
214 |
-
logger.info("No answer found, continuing")
|
215 |
-
return "agent"
|
216 |
-
|
217 |
-
# If we have an answer but it's not complete, continue
|
218 |
-
if not state["is_complete"]:
|
219 |
-
logger.info(f"Found answer but not complete: {state['answer']}")
|
220 |
return "agent"
|
221 |
|
222 |
# If we have an answer and it's complete, we're done
|
223 |
-
|
224 |
-
|
|
|
|
|
|
|
|
|
|
|
225 |
|
226 |
workflow.add_conditional_edges(
|
227 |
-
"callback",
|
|
|
|
|
228 |
)
|
229 |
|
230 |
# Set entry point
|
|
|
6 |
from typing import Dict, List, Optional, TypedDict, Union
|
7 |
|
8 |
import yaml
|
|
|
9 |
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
|
10 |
from langchain_core.runnables import RunnableConfig
|
11 |
from langgraph.graph import END, StateGraph
|
12 |
from langgraph.types import interrupt
|
13 |
from smolagents import CodeAgent, LiteLLMModel
|
14 |
|
15 |
+
from configuration import Configuration
|
16 |
+
from tools import tools
|
17 |
|
18 |
# Configure logging
|
19 |
logging.basicConfig(level=logging.INFO)
|
|
|
33 |
litellm.drop_params = True
|
34 |
|
35 |
# Load default prompt templates from local file
|
36 |
+
current_dir = os.path.dirname(os.path.abspath(__file__))
|
37 |
prompts_dir = os.path.join(current_dir, "prompts")
|
38 |
yaml_path = os.path.join(prompts_dir, "code_agent.yaml")
|
39 |
|
|
|
150 |
class StepCallbackNode:
|
151 |
"""Node that handles step callbacks and user interaction."""
|
152 |
|
153 |
+
def __init__(self, name: str):
|
154 |
+
self.name = name
|
|
|
|
|
|
|
|
|
155 |
|
156 |
+
def __call__(self, state: dict) -> dict:
|
157 |
+
"""Process the state and handle user interaction."""
|
158 |
+
print(f"\nCurrent step: {state.get('step_count', 0)}")
|
159 |
+
print(f"Question: {state.get('question', 'No question')}")
|
160 |
+
print(f"Current answer: {state.get('answer', 'No answer yet')}\n")
|
|
|
|
|
|
|
161 |
|
162 |
+
while True:
|
163 |
+
choice = input(
|
164 |
+
"Enter 'c' to continue, 'q' to quit, 'i' for more info, or 'r' to reject answer: "
|
165 |
+
).lower()
|
|
|
|
|
166 |
|
167 |
+
if choice == "c":
|
168 |
+
# Mark as complete to continue
|
169 |
state["is_complete"] = True
|
170 |
return state
|
171 |
+
elif choice == "q":
|
172 |
+
# Mark as complete and set answer to None to quit
|
173 |
+
state["is_complete"] = True
|
174 |
+
state["answer"] = None
|
175 |
return state
|
176 |
+
elif choice == "i":
|
177 |
+
# Show more information but don't mark as complete
|
178 |
+
print("\nAdditional Information:")
|
179 |
+
print(f"Messages: {state.get('messages', [])}")
|
180 |
+
print(f"Step Logs: {state.get('step_logs', [])}")
|
181 |
+
print(f"Context: {state.get('context', {})}")
|
182 |
+
print(f"Memory Buffer: {state.get('memory_buffer', [])}")
|
183 |
+
print(f"Last Action: {state.get('last_action', None)}")
|
184 |
+
print(f"Action History: {state.get('action_history', [])}")
|
185 |
+
print(f"Error Count: {state.get('error_count', 0)}")
|
186 |
+
print(f"Success Count: {state.get('success_count', 0)}\n")
|
187 |
+
elif choice == "r":
|
188 |
+
# Reject the current answer and continue execution
|
189 |
+
print("\nRejecting current answer and continuing execution...")
|
190 |
+
# Clear the message history to prevent confusion
|
191 |
+
state["messages"] = []
|
192 |
+
state["answer"] = None
|
193 |
+
state["is_complete"] = False
|
194 |
return state
|
195 |
else:
|
196 |
+
print("Invalid choice. Please enter 'c', 'q', 'i', or 'r'.")
|
|
|
|
|
|
|
|
|
|
|
|
|
197 |
|
198 |
|
199 |
def build_agent_graph(agent: AgentNode) -> StateGraph:
|
|
|
203 |
|
204 |
# Add nodes
|
205 |
workflow.add_node("agent", agent)
|
206 |
+
workflow.add_node("callback", StepCallbackNode("callback"))
|
207 |
|
208 |
# Add edges
|
209 |
workflow.add_edge("agent", "callback")
|
|
|
211 |
# Add conditional edges for callback
|
212 |
def should_continue(state: AgentState) -> str:
|
213 |
"""Determine the next node based on state."""
|
214 |
+
# If we have no answer, continue to agent
|
215 |
if not state["answer"]:
|
216 |
+
logger.info("No answer found, continuing to agent")
|
|
|
|
|
|
|
|
|
|
|
217 |
return "agent"
|
218 |
|
219 |
# If we have an answer and it's complete, we're done
|
220 |
+
if state["is_complete"]:
|
221 |
+
logger.info(f"Found complete answer: {state['answer']}")
|
222 |
+
return END
|
223 |
+
|
224 |
+
# Otherwise, go to callback for user input
|
225 |
+
logger.info(f"Waiting for user input for answer: {state['answer']}")
|
226 |
+
return "callback"
|
227 |
|
228 |
workflow.add_conditional_edges(
|
229 |
+
"callback",
|
230 |
+
should_continue,
|
231 |
+
{END: END, "agent": "agent", "callback": "callback"},
|
232 |
)
|
233 |
|
234 |
# Set entry point
|
api/runner.py β runner.py
RENAMED
@@ -1,10 +1,11 @@
|
|
1 |
import logging
|
2 |
import os
|
|
|
3 |
import uuid
|
4 |
|
5 |
from langgraph.types import Command
|
6 |
|
7 |
-
from
|
8 |
|
9 |
# Configure logging
|
10 |
logging.basicConfig(level=logging.INFO) # Default to INFO level
|
@@ -48,6 +49,26 @@ class AgentRunner:
|
|
48 |
if "messages" in state and state["messages"]:
|
49 |
for msg in reversed(state["messages"]):
|
50 |
if hasattr(msg, "content") and msg.content:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
logger.info(f"Found answer in message: {msg.content}")
|
52 |
return msg.content
|
53 |
|
@@ -99,7 +120,9 @@ class AgentRunner:
|
|
99 |
answer = self._extract_answer(chunk)
|
100 |
if answer:
|
101 |
self.last_state = chunk
|
102 |
-
return answer
|
|
|
|
|
103 |
else:
|
104 |
logger.debug(f"Skipping chunk without answer: {chunk}")
|
105 |
else:
|
@@ -111,7 +134,9 @@ class AgentRunner:
|
|
111 |
answer = self._extract_answer(result)
|
112 |
if answer:
|
113 |
self.last_state = result
|
114 |
-
return answer
|
|
|
|
|
115 |
else:
|
116 |
logger.debug(f"Skipping result without answer: {result}")
|
117 |
|
@@ -122,3 +147,34 @@ class AgentRunner:
|
|
122 |
except Exception as e:
|
123 |
logger.error(f"Error processing input: {str(e)}")
|
124 |
raise
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import logging
|
2 |
import os
|
3 |
+
import re
|
4 |
import uuid
|
5 |
|
6 |
from langgraph.types import Command
|
7 |
|
8 |
+
from graph import agent_graph
|
9 |
|
10 |
# Configure logging
|
11 |
logging.basicConfig(level=logging.INFO) # Default to INFO level
|
|
|
49 |
if "messages" in state and state["messages"]:
|
50 |
for msg in reversed(state["messages"]):
|
51 |
if hasattr(msg, "content") and msg.content:
|
52 |
+
# Look for code blocks that might contain the answer
|
53 |
+
if "```" in msg.content:
|
54 |
+
# Extract code between ```py and ``` or ```python and ```
|
55 |
+
code_match = re.search(
|
56 |
+
r"```(?:py|python)?\s*\n(.*?)\n```", msg.content, re.DOTALL
|
57 |
+
)
|
58 |
+
if code_match:
|
59 |
+
code = code_match.group(1)
|
60 |
+
# Look for final_answer call
|
61 |
+
final_answer_match = re.search(
|
62 |
+
r"final_answer\((.*?)\)", code
|
63 |
+
)
|
64 |
+
if final_answer_match:
|
65 |
+
answer = final_answer_match.group(1)
|
66 |
+
logger.info(
|
67 |
+
f"Found answer in final_answer call: {answer}"
|
68 |
+
)
|
69 |
+
return answer
|
70 |
+
|
71 |
+
# If no code block with final_answer, use the content
|
72 |
logger.info(f"Found answer in message: {msg.content}")
|
73 |
return msg.content
|
74 |
|
|
|
120 |
answer = self._extract_answer(chunk)
|
121 |
if answer:
|
122 |
self.last_state = chunk
|
123 |
+
# If the state is complete, return the answer
|
124 |
+
if chunk.get("is_complete", False):
|
125 |
+
return answer
|
126 |
else:
|
127 |
logger.debug(f"Skipping chunk without answer: {chunk}")
|
128 |
else:
|
|
|
134 |
answer = self._extract_answer(result)
|
135 |
if answer:
|
136 |
self.last_state = result
|
137 |
+
# If the state is complete, return the answer
|
138 |
+
if result.get("is_complete", False):
|
139 |
+
return answer
|
140 |
else:
|
141 |
logger.debug(f"Skipping result without answer: {result}")
|
142 |
|
|
|
147 |
except Exception as e:
|
148 |
logger.error(f"Error processing input: {str(e)}")
|
149 |
raise
|
150 |
+
|
151 |
+
|
152 |
+
if __name__ == "__main__":
|
153 |
+
import argparse
|
154 |
+
|
155 |
+
from langgraph.types import Command
|
156 |
+
|
157 |
+
# Set up argument parser
|
158 |
+
parser = argparse.ArgumentParser(description="Run the agent with a question")
|
159 |
+
parser.add_argument("question", type=str, help="The question to ask the agent")
|
160 |
+
parser.add_argument(
|
161 |
+
"--resume",
|
162 |
+
type=str,
|
163 |
+
help="Value to resume with after an interrupt",
|
164 |
+
default=None,
|
165 |
+
)
|
166 |
+
args = parser.parse_args()
|
167 |
+
|
168 |
+
# Create agent runner
|
169 |
+
runner = AgentRunner()
|
170 |
+
|
171 |
+
if args.resume:
|
172 |
+
# Resume from interrupt with provided value
|
173 |
+
print(f"\nResuming with value: {args.resume}")
|
174 |
+
response = runner(Command(resume=args.resume))
|
175 |
+
else:
|
176 |
+
# Initial run with question
|
177 |
+
print(f"\nAsking question: {args.question}")
|
178 |
+
response = runner(args.question)
|
179 |
+
|
180 |
+
print(f"\nFinal response: {response}")
|
test_agent.py
CHANGED
@@ -2,7 +2,7 @@ import logging
|
|
2 |
|
3 |
import pytest
|
4 |
|
5 |
-
from
|
6 |
|
7 |
# Configure test logger
|
8 |
test_logger = logging.getLogger("test_agent")
|
|
|
2 |
|
3 |
import pytest
|
4 |
|
5 |
+
from runner import AgentRunner
|
6 |
|
7 |
# Configure test logger
|
8 |
test_logger = logging.getLogger("test_agent")
|
services/tools.py β tools.py
RENAMED
@@ -47,9 +47,39 @@ class GeneralSearchTool(Tool):
|
|
47 |
return "\n\n---\n\n".join(output)
|
48 |
|
49 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
# Export all tools
|
51 |
tools = [
|
52 |
# DuckDuckGoSearchTool(),
|
53 |
GeneralSearchTool(),
|
|
|
54 |
# WikipediaSearchTool(),
|
55 |
]
|
|
|
47 |
return "\n\n---\n\n".join(output)
|
48 |
|
49 |
|
50 |
+
class MathTool(Tool):
|
51 |
+
name = "math"
|
52 |
+
description = """Performs mathematical calculations and returns the result."""
|
53 |
+
inputs = {
|
54 |
+
"expression": {
|
55 |
+
"type": "string",
|
56 |
+
"description": "The mathematical expression to evaluate.",
|
57 |
+
}
|
58 |
+
}
|
59 |
+
output_type = "string"
|
60 |
+
|
61 |
+
def forward(self, expression: str) -> str:
|
62 |
+
try:
|
63 |
+
# Use eval with a restricted set of builtins for safety
|
64 |
+
safe_dict = {
|
65 |
+
"__builtins__": {
|
66 |
+
"abs": abs,
|
67 |
+
"round": round,
|
68 |
+
"min": min,
|
69 |
+
"max": max,
|
70 |
+
"sum": sum,
|
71 |
+
}
|
72 |
+
}
|
73 |
+
result = eval(expression, safe_dict)
|
74 |
+
return str(result)
|
75 |
+
except Exception as e:
|
76 |
+
raise Exception(f"Error evaluating expression: {str(e)}")
|
77 |
+
|
78 |
+
|
79 |
# Export all tools
|
80 |
tools = [
|
81 |
# DuckDuckGoSearchTool(),
|
82 |
GeneralSearchTool(),
|
83 |
+
MathTool(),
|
84 |
# WikipediaSearchTool(),
|
85 |
]
|