wt002 commited on
Commit
ef60401
·
verified ·
1 Parent(s): 3363a47

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -19
app.py CHANGED
@@ -126,18 +126,23 @@ def init_state(question: str) -> AgentState:
126
 
127
 
128
  def should_continue(state: AgentState) -> str:
129
- last_message = state["history"][-1]
130
 
131
- # If the last message is FINAL ANSWER, stop
 
 
 
 
 
132
  if isinstance(last_message, AIMessage) and "FINAL ANSWER:" in last_message.content:
133
  return "end"
134
 
135
- # If an action_request was just added, continue to tool
136
- for msg in reversed(state["history"]):
137
  if isinstance(msg, dict) and msg.get("role") == "action_request":
138
  return "continue"
139
 
140
- # Otherwise, reason again
141
  return "reason"
142
 
143
 
@@ -153,23 +158,28 @@ def reasoning_node(state: AgentState) -> AgentState:
153
  if not GOOGLE_API_KEY:
154
  raise ValueError("GOOGLE_API_KEY not set in environment variables.")
155
 
156
- # Ensure history ends with a HumanMessage
157
- if not state.get("history") or not isinstance(state["history"][-1], HumanMessage):
158
- state["history"] = state.get("history", [])
 
159
  state["history"].append(HumanMessage(content="Continue."))
160
 
161
  # Ensure context is a dictionary
162
  if not isinstance(state.get("context"), dict):
163
  state["context"] = {}
164
 
165
- # Initialize the Gemini model (via LangChain wrapper)
 
 
 
 
166
  llm = ChatGoogleGenerativeAI(
167
  model="gemini-1.5-flash",
168
  temperature=0.1,
169
  google_api_key=GOOGLE_API_KEY
170
  )
171
 
172
- # Create prompt from messages
173
  prompt = ChatPromptTemplate.from_messages([
174
  ("system", (
175
  "You're an expert problem solver. Analyze the question, select the best tool, "
@@ -183,11 +193,11 @@ def reasoning_node(state: AgentState) -> AgentState:
183
  *state["history"]
184
  ])
185
 
186
- # Build and invoke the chain
187
  chain = prompt | llm
188
  response = chain.invoke({
189
  "context": state["context"],
190
- "reasoning": state.get("reasoning", ""),
191
  "question": state["question"]
192
  })
193
 
@@ -196,10 +206,10 @@ def reasoning_node(state: AgentState) -> AgentState:
196
 
197
  # Update state
198
  state["history"].append(AIMessage(content=content))
199
- state["reasoning"] += f"\nStep {state['iterations']+1}: {reasoning}"
200
  state["iterations"] += 1
201
 
202
- # Decide next step based on action
203
  if "final answer" in action.lower():
204
  state["history"].append(AIMessage(content=f"FINAL ANSWER: {action_input}"))
205
  else:
@@ -211,12 +221,17 @@ def reasoning_node(state: AgentState) -> AgentState:
211
  return state
212
 
213
 
 
214
 
215
 
216
  def tool_node(state: AgentState) -> AgentState:
217
  from langchain.schema import AIMessage
218
 
219
- # Get the last tool action request
 
 
 
 
220
  tool_call = None
221
  for msg in reversed(state["history"]):
222
  if isinstance(msg, dict) and msg.get("role") == "action_request":
@@ -226,11 +241,17 @@ def tool_node(state: AgentState) -> AgentState:
226
  if not tool_call:
227
  raise ValueError("No tool call found in history")
228
 
229
- tool_name = tool_call["tool"]
230
- tool_input = tool_call["input"]
 
 
 
 
231
 
232
  # Look up and invoke the tool
233
- tool_fn = next((t for t in BasicAgent().tools if t.__name__ == tool_name), None)
 
 
234
  if tool_fn is None:
235
  raise ValueError(f"Tool '{tool_name}' not found")
236
 
@@ -239,12 +260,13 @@ def tool_node(state: AgentState) -> AgentState:
239
  except Exception as e:
240
  tool_output = f"[Tool Error] {str(e)}"
241
 
242
- # Store tool result as AIMessage
243
  state["history"].append(AIMessage(content=f"[{tool_name} output]\n{tool_output}"))
244
 
245
  return state
246
 
247
 
 
248
  def parse_agent_response(response: str) -> tuple:
249
  """Extract reasoning, action, and input from response"""
250
  reasoning = response.split("Reasoning:")[1].split("Action:")[0].strip()
 
126
 
127
 
128
  def should_continue(state: AgentState) -> str:
129
+ history = state.get("history", [])
130
 
131
+ if not history:
132
+ return "reason" # No history yet, reason first
133
+
134
+ last_message = history[-1]
135
+
136
+ # End if agent has produced a final answer
137
  if isinstance(last_message, AIMessage) and "FINAL ANSWER:" in last_message.content:
138
  return "end"
139
 
140
+ # If an action_request exists, trigger tool use
141
+ for msg in reversed(history):
142
  if isinstance(msg, dict) and msg.get("role") == "action_request":
143
  return "continue"
144
 
145
+ # Otherwise, go back to reasoning
146
  return "reason"
147
 
148
 
 
158
  if not GOOGLE_API_KEY:
159
  raise ValueError("GOOGLE_API_KEY not set in environment variables.")
160
 
161
+ # Ensure history is initialized and ends with a HumanMessage
162
+ if "history" not in state or not isinstance(state["history"], list):
163
+ state["history"] = []
164
+ if not state["history"] or not isinstance(state["history"][-1], HumanMessage):
165
  state["history"].append(HumanMessage(content="Continue."))
166
 
167
  # Ensure context is a dictionary
168
  if not isinstance(state.get("context"), dict):
169
  state["context"] = {}
170
 
171
+ # Ensure reasoning and iterations keys are present
172
+ state.setdefault("reasoning", "")
173
+ state.setdefault("iterations", 0)
174
+
175
+ # Initialize Gemini model via LangChain
176
  llm = ChatGoogleGenerativeAI(
177
  model="gemini-1.5-flash",
178
  temperature=0.1,
179
  google_api_key=GOOGLE_API_KEY
180
  )
181
 
182
+ # Create prompt
183
  prompt = ChatPromptTemplate.from_messages([
184
  ("system", (
185
  "You're an expert problem solver. Analyze the question, select the best tool, "
 
193
  *state["history"]
194
  ])
195
 
196
+ # Invoke model
197
  chain = prompt | llm
198
  response = chain.invoke({
199
  "context": state["context"],
200
+ "reasoning": state["reasoning"],
201
  "question": state["question"]
202
  })
203
 
 
206
 
207
  # Update state
208
  state["history"].append(AIMessage(content=content))
209
+ state["reasoning"] += f"\nStep {state['iterations'] + 1}: {reasoning}"
210
  state["iterations"] += 1
211
 
212
+ # Store either final answer or tool to call
213
  if "final answer" in action.lower():
214
  state["history"].append(AIMessage(content=f"FINAL ANSWER: {action_input}"))
215
  else:
 
221
  return state
222
 
223
 
224
+
225
 
226
 
227
  def tool_node(state: AgentState) -> AgentState:
228
  from langchain.schema import AIMessage
229
 
230
+ # Ensure history exists
231
+ if "history" not in state or not isinstance(state["history"], list):
232
+ raise ValueError("Invalid or missing history in state")
233
+
234
+ # Find the most recent action request in history
235
  tool_call = None
236
  for msg in reversed(state["history"]):
237
  if isinstance(msg, dict) and msg.get("role") == "action_request":
 
241
  if not tool_call:
242
  raise ValueError("No tool call found in history")
243
 
244
+ tool_name = tool_call.get("tool")
245
+ tool_input = tool_call.get("input")
246
+
247
+ # Defensive check for missing tool or input
248
+ if not tool_name or tool_input is None:
249
+ raise ValueError("Tool name or input missing from action request")
250
 
251
  # Look up and invoke the tool
252
+ agent = BasicAgent() # Create agent to access tools
253
+ tool_fn = next((t for t in agent.tools if t.__name__ == tool_name), None)
254
+
255
  if tool_fn is None:
256
  raise ValueError(f"Tool '{tool_name}' not found")
257
 
 
260
  except Exception as e:
261
  tool_output = f"[Tool Error] {str(e)}"
262
 
263
+ # Add output to history as an AIMessage
264
  state["history"].append(AIMessage(content=f"[{tool_name} output]\n{tool_output}"))
265
 
266
  return state
267
 
268
 
269
+
270
  def parse_agent_response(response: str) -> tuple:
271
  """Extract reasoning, action, and input from response"""
272
  reasoning = response.split("Reasoning:")[1].split("Action:")[0].strip()