mjschock commited on
Commit
9bd791c
·
unverified ·
1 Parent(s): f622879

Enhance AgentRunner and graph functionality by introducing memory management and improved state handling. Update __call__ method to support both question input and resuming from interrupts, while adding new memory-related fields to track context, actions, and success/error counts. Refactor step callback logic for better user interaction and state management.

Browse files
Files changed (3) hide show
  1. agent.py +42 -20
  2. graph.py +62 -18
  3. test_agent.py +1 -0
agent.py CHANGED
@@ -1,6 +1,7 @@
1
  import logging
2
  import os
3
  import uuid
 
4
 
5
  from graph import agent_graph
6
 
@@ -27,34 +28,55 @@ class AgentRunner:
27
  logger.info("Initializing AgentRunner")
28
  self.graph = agent_graph
29
  self.last_state = None # Store the last state for testing/debugging
 
30
 
31
- def __call__(self, question: str) -> str:
32
  """Process a question through the agent graph and return the answer.
33
 
34
  Args:
35
- question: The question to process
36
 
37
  Returns:
38
  str: The agent's response
39
  """
40
  try:
41
- logger.info(f"Processing question: {question}")
42
- initial_state = {
43
- "question": question,
44
- "messages": [],
45
- "answer": None,
46
- "step_logs": [],
47
- "is_complete": False, # Initialize is_complete
48
- "step_count": 0, # Initialize step_count
49
- }
50
-
51
- # Generate a unique thread_id for this interaction
52
- thread_id = str(uuid.uuid4())
53
- config = {"configurable": {"thread_id": thread_id}}
54
-
55
- final_state = self.graph.invoke(initial_state, config)
56
- self.last_state = final_state # Store the final state
57
- return final_state.get("answer", "No answer generated")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  except Exception as e:
59
- logger.error(f"Error processing question: {str(e)}")
60
  raise
 
1
  import logging
2
  import os
3
  import uuid
4
+ from langgraph.types import Command
5
 
6
  from graph import agent_graph
7
 
 
28
  logger.info("Initializing AgentRunner")
29
  self.graph = agent_graph
30
  self.last_state = None # Store the last state for testing/debugging
31
+ self.thread_id = str(uuid.uuid4()) # Generate a unique thread_id for this runner
32
 
33
+ def __call__(self, input_data) -> str:
34
  """Process a question through the agent graph and return the answer.
35
 
36
  Args:
37
+ input_data: Either a question string or a Command object for resuming
38
 
39
  Returns:
40
  str: The agent's response
41
  """
42
  try:
43
+ config = {"configurable": {"thread_id": self.thread_id}}
44
+
45
+ if isinstance(input_data, str):
46
+ # Initial question
47
+ logger.info(f"Processing question: {input_data}")
48
+ initial_state = {
49
+ "question": input_data,
50
+ "messages": [],
51
+ "answer": None,
52
+ "step_logs": [],
53
+ "is_complete": False,
54
+ "step_count": 0,
55
+ # Initialize new memory fields
56
+ "context": {},
57
+ "memory_buffer": [],
58
+ "last_action": None,
59
+ "action_history": [],
60
+ "error_count": 0,
61
+ "success_count": 0,
62
+ }
63
+
64
+ # Use stream to get interrupt information
65
+ for chunk in self.graph.stream(initial_state, config):
66
+ if isinstance(chunk, tuple) and len(chunk) > 0 and hasattr(chunk[0], '__interrupt__'):
67
+ # If we hit an interrupt, resume with 'c'
68
+ for result in self.graph.stream(Command(resume="c"), config):
69
+ self.last_state = result
70
+ return result.get("answer", "No answer generated")
71
+ self.last_state = chunk
72
+ return chunk.get("answer", "No answer generated")
73
+ else:
74
+ # Resuming from interrupt
75
+ logger.info("Resuming from interrupt")
76
+ for result in self.graph.stream(input_data, config):
77
+ self.last_state = result
78
+ return result.get("answer", "No answer generated")
79
+
80
  except Exception as e:
81
+ logger.error(f"Error processing input: {str(e)}")
82
  raise
graph.py CHANGED
@@ -2,6 +2,7 @@
2
 
3
  import logging
4
  import os
 
5
  from typing import Dict, List, Optional, TypedDict, Union
6
 
7
  import yaml
@@ -66,6 +67,13 @@ class AgentState(TypedDict):
66
  step_logs: List[Dict]
67
  is_complete: bool
68
  step_count: int
 
 
 
 
 
 
 
69
 
70
 
71
  class AgentNode:
@@ -92,18 +100,43 @@ class AgentNode:
92
  # Log execution start
93
  logger.info("Starting agent execution")
94
 
95
- # Run the agent
96
- result = self.agent.run(state["question"])
97
-
98
- # Log result
99
- logger.info(f"Agent execution result type: {type(result)}")
100
- logger.info(f"Agent execution result value: {result}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
 
102
- # Update state
103
- new_state = state.copy()
104
- new_state["messages"].append(AIMessage(content=result))
105
- new_state["answer"] = result
106
- new_state["step_count"] += 1
 
 
 
107
 
108
  # Log updated state
109
  logger.info("Updated state after processing:")
@@ -134,10 +167,11 @@ class StepCallbackNode:
134
  state["step_logs"].append(step_log)
135
 
136
  try:
137
- # Use interrupt for user input
138
- user_input = interrupt(
139
  "Press 'c' to continue, 'q' to quit, or 'i' for more info: "
140
  )
 
141
 
142
  if user_input.lower() == "q":
143
  state["is_complete"] = True
@@ -146,12 +180,12 @@ class StepCallbackNode:
146
  logger.info(f"Current step: {state['step_count']}")
147
  logger.info(f"Question: {state['question']}")
148
  logger.info(f"Current answer: {state['answer']}")
149
- return self(state, config) # Recursively call for new input
150
  elif user_input.lower() == "c":
151
  return state
152
  else:
153
  logger.warning("Invalid input. Please use 'c', 'q', or 'i'.")
154
- return self(state, config) # Recursively call for new input
155
 
156
  except Exception as e:
157
  logger.warning(f"Error during interrupt: {str(e)}")
@@ -169,10 +203,20 @@ def build_agent_graph(agent: AgentNode) -> StateGraph:
169
 
170
  # Add edges
171
  workflow.add_edge("agent", "callback")
 
 
 
 
 
 
 
 
 
 
 
 
172
  workflow.add_conditional_edges(
173
- "callback",
174
- lambda x: END if x["is_complete"] else "agent",
175
- {True: END, False: "agent"},
176
  )
177
 
178
  # Set entry point
 
2
 
3
  import logging
4
  import os
5
+ from datetime import datetime
6
  from typing import Dict, List, Optional, TypedDict, Union
7
 
8
  import yaml
 
67
  step_logs: List[Dict]
68
  is_complete: bool
69
  step_count: int
70
+ # Add memory-related fields
71
+ context: Dict[str, any] # For storing contextual information
72
+ memory_buffer: List[Dict] # For storing important information across steps
73
+ last_action: Optional[str] # Track the last action taken
74
+ action_history: List[Dict] # History of actions taken
75
+ error_count: int # Track error frequency
76
+ success_count: int # Track successful operations
77
 
78
 
79
  class AgentNode:
 
100
  # Log execution start
101
  logger.info("Starting agent execution")
102
 
103
+ try:
104
+ # Run the agent
105
+ result = self.agent.run(state["question"])
106
+
107
+ # Update memory-related fields
108
+ new_state = state.copy()
109
+ new_state["messages"].append(AIMessage(content=result))
110
+ new_state["answer"] = result
111
+ new_state["step_count"] += 1
112
+ new_state["last_action"] = "agent_response"
113
+ new_state["action_history"].append(
114
+ {
115
+ "step": state["step_count"],
116
+ "action": "agent_response",
117
+ "result": result,
118
+ }
119
+ )
120
+ new_state["success_count"] += 1
121
+
122
+ # Store important information in memory buffer
123
+ if result:
124
+ new_state["memory_buffer"].append(
125
+ {
126
+ "step": state["step_count"],
127
+ "content": result,
128
+ "timestamp": datetime.now().isoformat(),
129
+ }
130
+ )
131
 
132
+ except Exception as e:
133
+ logger.error(f"Error during agent execution: {str(e)}")
134
+ new_state = state.copy()
135
+ new_state["error_count"] += 1
136
+ new_state["action_history"].append(
137
+ {"step": state["step_count"], "action": "error", "error": str(e)}
138
+ )
139
+ raise
140
 
141
  # Log updated state
142
  logger.info("Updated state after processing:")
 
167
  state["step_logs"].append(step_log)
168
 
169
  try:
170
+ # Use interrupt for user input and unpack the tuple
171
+ interrupt_result = interrupt(
172
  "Press 'c' to continue, 'q' to quit, or 'i' for more info: "
173
  )
174
+ user_input = interrupt_result[0] # Get the actual user input
175
 
176
  if user_input.lower() == "q":
177
  state["is_complete"] = True
 
180
  logger.info(f"Current step: {state['step_count']}")
181
  logger.info(f"Question: {state['question']}")
182
  logger.info(f"Current answer: {state['answer']}")
183
+ return state
184
  elif user_input.lower() == "c":
185
  return state
186
  else:
187
  logger.warning("Invalid input. Please use 'c', 'q', or 'i'.")
188
+ return state
189
 
190
  except Exception as e:
191
  logger.warning(f"Error during interrupt: {str(e)}")
 
203
 
204
  # Add edges
205
  workflow.add_edge("agent", "callback")
206
+
207
+ # Add conditional edges for callback
208
+ def should_continue(state: AgentState) -> str:
209
+ """Determine the next node based on state."""
210
+ if state["is_complete"]:
211
+ return END
212
+ # If we have an answer and no errors, we're done
213
+ if state["answer"] and state["error_count"] == 0:
214
+ return END
215
+ # Otherwise continue to agent
216
+ return "agent"
217
+
218
  workflow.add_conditional_edges(
219
+ "callback", should_continue, {END: END, "agent": "agent"}
 
 
220
  )
221
 
222
  # Set entry point
test_agent.py CHANGED
@@ -1,6 +1,7 @@
1
  import logging
2
 
3
  import pytest
 
4
 
5
  from agent import AgentRunner
6
 
 
1
  import logging
2
 
3
  import pytest
4
+ from langgraph.types import Command
5
 
6
  from agent import AgentRunner
7