naman1102 committed on
Commit
2a2fc01
·
1 Parent(s): 59f78f5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +84 -36
app.py CHANGED
@@ -126,29 +126,30 @@ def tool_node(state: AgentState) -> AgentState:
126
 
127
  # ─── 4) merge_tool_output ───
128
  def merge_tool_output(state: AgentState) -> AgentState:
129
- """
130
- Combine previous state and tool output into one:
131
- """
132
- prev = state.get("prev_state", {})
133
- merged = {**prev, **state}
134
  merged.pop("prev_state", None)
 
 
 
 
 
 
 
 
 
 
 
 
 
135
  return merged
136
 
137
 
138
  # ─── 5) inspect_node ───
139
  def inspect_node(state: AgentState) -> AgentState:
140
- """
141
- After running a tool, show GPT:
142
- - ORIGINAL user question
143
- - Any tool results (web_search_result, ocr_result, excel_result, transcript, wiki_result)
144
- - The INTERIM_ANSWER (what plan_node initially provided under 'final_answer')
145
- Then ask GPT to either:
146
- • Return {"final_answer": "<final>"} if done, OR
147
- • Return exactly one tool key to run next (wiki_query / web_search_query / ocr_path / excel_path & excel_sheet_name / audio_path).
148
- """
149
  messages_for_llm = []
150
 
151
- # 1) Re‐insert original user question
152
  question = ""
153
  for msg in reversed(state.get("messages", [])):
154
  if isinstance(msg, HumanMessage):
@@ -156,7 +157,7 @@ def inspect_node(state: AgentState) -> AgentState:
156
  break
157
  messages_for_llm.append(SystemMessage(content=f"USER_QUESTION: {question}"))
158
 
159
- # 2) Add any tool results
160
  if sr := state.get("web_search_result"):
161
  messages_for_llm.append(SystemMessage(content=f"WEB_SEARCH_RESULT: {sr}"))
162
  if orc := state.get("ocr_result"):
@@ -168,23 +169,57 @@ def inspect_node(state: AgentState) -> AgentState:
168
  if wr := state.get("wiki_result"):
169
  messages_for_llm.append(SystemMessage(content=f"WIKIPEDIA_RESULT: {wr}"))
170
 
171
- # 3) Add the interim answer under INTERIM_ANSWER
172
  if ia := state.get("final_answer"):
173
  messages_for_llm.append(SystemMessage(content=f"INTERIM_ANSWER: {ia}"))
174
 
175
- # 4) Prompt GPT to decide final or another tool
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176
  prompt = (
177
  "You have a current draft answer (INTERIM_ANSWER) and possibly some tool results above.\n"
178
- "If you are confident it’s correct, return exactly:\n"
179
  " {\"final_answer\":\"<your final answer>\"}\n"
180
  "and nothing else.\n"
181
- "Otherwise, return exactly one of these JSON literals to fetch another tool:\n"
 
182
  " {\"wiki_query\":\"<query for Wikipedia>\"}\n"
183
  " {\"web_search_query\":\"<search terms>\"}\n"
184
  " {\"ocr_path\":\"<image path or task_id>\"}\n"
185
  " {\"excel_path\":\"<xls path>\", \"excel_sheet_name\":\"<sheet name>\"}\n"
186
  " {\"audio_path\":\"<audio path or task_id>\"}\n"
187
  "Do NOT wrap in markdown—return only the JSON object.\n"
 
188
  )
189
  messages_for_llm.append(SystemMessage(content=prompt))
190
  llm_response = llm(messages_for_llm)
@@ -194,26 +229,34 @@ def inspect_node(state: AgentState) -> AgentState:
194
  try:
195
  parsed = json.loads(raw)
196
  if isinstance(parsed, dict):
197
- partial: AgentState = {"messages": new_msgs}
198
- allowed = {
199
- "final_answer",
200
- "wiki_query",
201
- "web_search_query",
202
- "ocr_path",
203
- "excel_path",
204
- "excel_sheet_name",
205
- "audio_path"
 
 
 
 
 
 
 
206
  }
207
  for k, v in parsed.items():
208
- if k in allowed:
209
- partial[k] = v
210
  return partial
211
  except json.JSONDecodeError:
212
  pass
213
 
214
  return {
215
  "messages": new_msgs,
216
- "final_answer": "ERROR: could not parse inspect decision."
 
 
217
  }
218
 
219
 
@@ -284,9 +327,10 @@ compiled_graph = graph.compile()
284
 
285
 
286
  # ─── 8) respond_to_input ───
287
- def respond_to_input(user_input: str, task_id) -> str:
288
  """
289
  Seed state['messages'] with a SystemMessage + HumanMessage(user_input),
 
290
  then invoke the cyclic graph. Return the final_answer from the resulting state.
291
  """
292
  system_msg = SystemMessage(
@@ -298,19 +342,23 @@ def respond_to_input(user_input: str, task_id) -> str:
298
  " • OCR: set {\"ocr_path\":\"<image path or task_id>\"}\n"
299
  " • Excel: set {\"excel_path\":\"<xlsx path>\", \"excel_sheet_name\":\"<sheet>\"}\n"
300
  " • Audio transcription: set {\"audio_path\":\"<audio path or task_id>\"}\n"
301
- "If you can answer immediately, set {\"final_answer\":\"<answer>\"}. "
302
  "Respond with only one JSON object and no extra formatting."
303
  )
304
  )
305
  human_msg = HumanMessage(content=user_input)
306
 
307
- initial_state: AgentState = {"messages": [system_msg, human_msg], "task_id": task_id}
 
 
 
 
 
308
  final_state = compiled_graph.invoke(initial_state)
309
  return final_state.get("final_answer", "Error: No final answer generated.")
310
 
311
 
312
 
313
-
314
  class BasicAgent:
315
  def __init__(self):
316
  print("BasicAgent initialized.")
 
126
 
127
  # ─── 4) merge_tool_output ───
128
def merge_tool_output(state: AgentState) -> AgentState:
    """Fold the pre-tool snapshot back into the post-tool state.

    Combines ``prev_state`` (the state captured before the tool ran) with the
    current state (tool output wins on key collisions), removes the snapshot
    key, and updates the tool-usage bookkeeping: ``tool_calls`` is incremented
    and the requested tool key is recorded in ``used_tools`` — the snapshot
    holds at most one of the known tool-request keys.
    """
    snapshot = state.get("prev_state", {})
    combined = dict(snapshot)
    combined.update(state)
    combined.pop("prev_state", None)

    # The snapshot carries exactly one tool-request key (or none); find it.
    tool_keys = ("wiki_query", "web_search_query", "ocr_path", "excel_path", "audio_path")
    requested = next((key for key in tool_keys if snapshot.get(key) is not None), None)

    if requested is not None:
        # One more tool invocation has completed.
        combined["tool_calls"] = combined.get("tool_calls", 0) + 1
        # Track the tool so inspect_node can avoid re-requesting it.
        seen = list(combined.get("used_tools", []))
        if requested not in seen:
            seen.append(requested)
        combined["used_tools"] = seen

    return combined
146
 
147
 
148
  # ─── 5) inspect_node ───
149
  def inspect_node(state: AgentState) -> AgentState:
 
 
 
 
 
 
 
 
 
150
  messages_for_llm = []
151
 
152
+ # 1) Original question
153
  question = ""
154
  for msg in reversed(state.get("messages", [])):
155
  if isinstance(msg, HumanMessage):
 
157
  break
158
  messages_for_llm.append(SystemMessage(content=f"USER_QUESTION: {question}"))
159
 
160
+ # 2) Any tool results so far
161
  if sr := state.get("web_search_result"):
162
  messages_for_llm.append(SystemMessage(content=f"WEB_SEARCH_RESULT: {sr}"))
163
  if orc := state.get("ocr_result"):
 
169
  if wr := state.get("wiki_result"):
170
  messages_for_llm.append(SystemMessage(content=f"WIKIPEDIA_RESULT: {wr}"))
171
 
172
+ # 3) Interim answer
173
  if ia := state.get("final_answer"):
174
  messages_for_llm.append(SystemMessage(content=f"INTERIM_ANSWER: {ia}"))
175
 
176
+ # 4) How many times have we called a tool?
177
+ used_tools = state.get("used_tools", [])
178
+ tool_calls = state.get("tool_calls", 0)
179
+
180
+ # If we've already tried all five tools once, or exceeded a small limit (e.g. 5),
181
+ # force a final answer now. We append a dummy instruction so the LLM knows:
182
+ if tool_calls >= 5 or len(used_tools) >= 5:
183
+ # The user’s interim answer and all tool results exist;
184
+ # we instruct GPT to treat it as final.
185
+ prompt = (
186
+ "We have already used every available tool or reached the maximum number of attempts.\n"
187
+ "Therefore, return exactly {\"final_answer\":\"<your best final answer>\"} and nothing else.\n"
188
+ )
189
+ messages_for_llm.append(SystemMessage(content=prompt))
190
+ llm_response = llm(messages_for_llm)
191
+ raw = llm_response.content.strip()
192
+ new_msgs = state["messages"] + [AIMessage(content=raw)]
193
+ try:
194
+ parsed = json.loads(raw)
195
+ if isinstance(parsed, dict) and "final_answer" in parsed:
196
+ return {"messages": new_msgs, "final_answer": parsed["final_answer"],
197
+ "used_tools": used_tools, "tool_calls": tool_calls}
198
+ except json.JSONDecodeError:
199
+ pass
200
+ # Fallback
201
+ return {
202
+ "messages": new_msgs,
203
+ "final_answer": "ERROR: inspect forced final but parsing failed.",
204
+ "used_tools": used_tools,
205
+ "tool_calls": tool_calls
206
+ }
207
+
208
+ # 5) Otherwise, ask GPT if it wants another tool
209
  prompt = (
210
  "You have a current draft answer (INTERIM_ANSWER) and possibly some tool results above.\n"
211
+ "If you are confident it’s now correct, return exactly:\n"
212
  " {\"final_answer\":\"<your final answer>\"}\n"
213
  "and nothing else.\n"
214
+ "Otherwise, return exactly one of these JSON literals to fetch another tool, "
215
+ "but DO NOT return a tool you have already used:\n"
216
  " {\"wiki_query\":\"<query for Wikipedia>\"}\n"
217
  " {\"web_search_query\":\"<search terms>\"}\n"
218
  " {\"ocr_path\":\"<image path or task_id>\"}\n"
219
  " {\"excel_path\":\"<xls path>\", \"excel_sheet_name\":\"<sheet name>\"}\n"
220
  " {\"audio_path\":\"<audio path or task_id>\"}\n"
221
  "Do NOT wrap in markdown—return only the JSON object.\n"
222
+ f"Already used tools: {used_tools}\n"
223
  )
224
  messages_for_llm.append(SystemMessage(content=prompt))
225
  llm_response = llm(messages_for_llm)
 
229
  try:
230
  parsed = json.loads(raw)
231
  if isinstance(parsed, dict):
232
+ # If GPT asks for a tool that’s already in used_tools, override and force finalize.
233
+ for key in parsed:
234
+ if key in ("wiki_query", "web_search_query", "ocr_path", "excel_path", "audio_path"):
235
+ if key in used_tools:
236
+ # GPT tried to reuse a tool → force a final answer instead
237
+ return {
238
+ "messages": new_msgs,
239
+ "final_answer": state.get("final_answer", ""),
240
+ "used_tools": used_tools,
241
+ "tool_calls": tool_calls
242
+ }
243
+ # Otherwise, it’s either final_answer or a brand‐new tool request
244
+ partial: AgentState = {
245
+ "messages": new_msgs,
246
+ "used_tools": used_tools,
247
+ "tool_calls": tool_calls
248
  }
249
  for k, v in parsed.items():
250
+ partial[k] = v
 
251
  return partial
252
  except json.JSONDecodeError:
253
  pass
254
 
255
  return {
256
  "messages": new_msgs,
257
+ "final_answer": "ERROR: could not parse inspect decision.",
258
+ "used_tools": used_tools,
259
+ "tool_calls": tool_calls
260
  }
261
 
262
 
 
327
 
328
 
329
  # ─── 8) respond_to_input ───
330
+ def respond_to_input(user_input: str, task_id: str) -> str:
331
  """
332
  Seed state['messages'] with a SystemMessage + HumanMessage(user_input),
333
+ include the current task_id so that OCR/Audio tools can fetch files,
334
  then invoke the cyclic graph. Return the final_answer from the resulting state.
335
  """
336
  system_msg = SystemMessage(
 
342
  " • OCR: set {\"ocr_path\":\"<image path or task_id>\"}\n"
343
  " • Excel: set {\"excel_path\":\"<xlsx path>\", \"excel_sheet_name\":\"<sheet>\"}\n"
344
  " • Audio transcription: set {\"audio_path\":\"<audio path or task_id>\"}\n"
345
+ "If you can answer immediately, set {\"final_answer\":\"<answer>\"}.\n"
346
  "Respond with only one JSON object and no extra formatting."
347
  )
348
  )
349
  human_msg = HumanMessage(content=user_input)
350
 
351
+ initial_state: AgentState = {
352
+ "messages": [system_msg, human_msg],
353
+ "task_id": task_id,
354
+ "used_tools": [], # track which tools have been requested
355
+ "tool_calls": 0 # count of how many tool invocations so far
356
+ }
357
  final_state = compiled_graph.invoke(initial_state)
358
  return final_state.get("final_answer", "Error: No final answer generated.")
359
 
360
 
361
 
 
362
  class BasicAgent:
363
  def __init__(self):
364
  print("BasicAgent initialized.")