Update app.py

app.py CHANGED
@@ -126,29 +126,30 @@ def tool_node(state: AgentState) -> AgentState:
 
 # ─── 4) merge_tool_output ───
 def merge_tool_output(state: AgentState) -> AgentState:
-    """
-
-    """
-    prev = state.get("prev_state", {})
-    merged = {**prev, **state}
+    prev_state = state.get("prev_state", {})
+    merged = {**prev_state, **state}
     merged.pop("prev_state", None)
+
+    # Detect which tool key was used in prev_state (it’s exactly one of these)
+    for tool_key in ("wiki_query", "web_search_query", "ocr_path", "excel_path", "audio_path"):
+        if prev_state.get(tool_key) is not None:
+            # Increment the count of tool calls
+            merged["tool_calls"] = merged.get("tool_calls", 0) + 1
+            # Record that we have used this tool_key
+            used = merged.get("used_tools", []).copy()
+            if tool_key not in used:
+                used.append(tool_key)
+            merged["used_tools"] = used
+            break
+
     return merged
 
 
 # ─── 5) inspect_node ───
 def inspect_node(state: AgentState) -> AgentState:
-    """
-    After running a tool, show GPT:
-      - ORIGINAL user question
-      - Any tool results (web_search_result, ocr_result, excel_result, transcript, wiki_result)
-      - The INTERIM_ANSWER (what plan_node initially provided under 'final_answer')
-    Then ask GPT to either:
-      • Return {"final_answer": "<final>"} if done, OR
-      • Return exactly one tool key to run next (wiki_query / web_search_query / ocr_path / excel_path & excel_sheet_name / audio_path).
-    """
     messages_for_llm = []
 
-    # 1)
+    # 1) Original question
     question = ""
     for msg in reversed(state.get("messages", [])):
         if isinstance(msg, HumanMessage):
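A note on the merge semantics: in `{**prev_state, **state}` the right-hand operand wins, so keys freshly written by the tool node always override stale values carried in `prev_state`. A minimal sketch of the new bookkeeping in isolation — the real `AgentState` TypedDict is defined outside this diff, so the shape below is an assumption:

    # Hypothetical stand-in for the AgentState fields this diff touches.
    from typing import TypedDict

    class AgentState(TypedDict, total=False):
        prev_state: dict
        tool_calls: int
        used_tools: list

    state: AgentState = {
        "prev_state": {"wiki_query": "some query", "tool_calls": 0, "used_tools": []},
    }
    prev_state = state.get("prev_state", {})
    merged = {**prev_state, **state}   # later mapping (state) wins on key clashes
    merged.pop("prev_state", None)
    merged["tool_calls"] = merged.get("tool_calls", 0) + 1
    print(merged["tool_calls"])        # 1: one completed tool round recorded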
@@ -156,7 +157,7 @@ def inspect_node(state: AgentState) -> AgentState:
             break
     messages_for_llm.append(SystemMessage(content=f"USER_QUESTION: {question}"))
 
-    # 2)
+    # 2) Any tool results so far
     if sr := state.get("web_search_result"):
         messages_for_llm.append(SystemMessage(content=f"WEB_SEARCH_RESULT: {sr}"))
     if orc := state.get("ocr_result"):
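One caveat that applies to all of these guards, before and after this change: `if sr := state.get("web_search_result"):` tests truthiness, so a present-but-empty result (e.g. OCR returning "") is silently omitted from the prompt. If empty results should still reach the model, the explicit comparison is the safer spelling (illustrative only, not part of this commit):

    if (sr := state.get("web_search_result")) is not None:
        messages_for_llm.append(SystemMessage(content=f"WEB_SEARCH_RESULT: {sr}"))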
@@ -168,23 +169,57 @@ def inspect_node(state: AgentState) -> AgentState:
     if wr := state.get("wiki_result"):
         messages_for_llm.append(SystemMessage(content=f"WIKIPEDIA_RESULT: {wr}"))
 
-    # 3)
+    # 3) Interim answer
     if ia := state.get("final_answer"):
         messages_for_llm.append(SystemMessage(content=f"INTERIM_ANSWER: {ia}"))
 
-    # 4)
+    # 4) How many times have we called a tool?
+    used_tools = state.get("used_tools", [])
+    tool_calls = state.get("tool_calls", 0)
+
+    # If we've already tried all five tools once, or exceeded a small limit (e.g. 5),
+    # force a final answer now. We append a dummy instruction so the LLM knows:
+    if tool_calls >= 5 or len(used_tools) >= 5:
+        # The user’s interim answer and all tool results exist;
+        # we instruct GPT to treat it as final.
+        prompt = (
+            "We have already used every available tool or reached the maximum number of attempts.\n"
+            "Therefore, return exactly {\"final_answer\":\"<your best final answer>\"} and nothing else.\n"
+        )
+        messages_for_llm.append(SystemMessage(content=prompt))
+        llm_response = llm(messages_for_llm)
+        raw = llm_response.content.strip()
+        new_msgs = state["messages"] + [AIMessage(content=raw)]
+        try:
+            parsed = json.loads(raw)
+            if isinstance(parsed, dict) and "final_answer" in parsed:
+                return {"messages": new_msgs, "final_answer": parsed["final_answer"],
+                        "used_tools": used_tools, "tool_calls": tool_calls}
+        except json.JSONDecodeError:
+            pass
+        # Fallback
+        return {
+            "messages": new_msgs,
+            "final_answer": "ERROR: inspect forced final but parsing failed.",
+            "used_tools": used_tools,
+            "tool_calls": tool_calls
+        }
+
+    # 5) Otherwise, ask GPT if it wants another tool
     prompt = (
         "You have a current draft answer (INTERIM_ANSWER) and possibly some tool results above.\n"
-        "If you are confident it’s correct, return exactly:\n"
+        "If you are confident it’s now correct, return exactly:\n"
         " {\"final_answer\":\"<your final answer>\"}\n"
        "and nothing else.\n"
-        "Otherwise, return exactly one of these JSON literals to fetch another tool:\n"
+        "Otherwise, return exactly one of these JSON literals to fetch another tool, "
+        "but DO NOT return a tool you have already used:\n"
         " {\"wiki_query\":\"<query for Wikipedia>\"}\n"
         " {\"web_search_query\":\"<search terms>\"}\n"
         " {\"ocr_path\":\"<image path or task_id>\"}\n"
         " {\"excel_path\":\"<xls path>\", \"excel_sheet_name\":\"<sheet name>\"}\n"
         " {\"audio_path\":\"<audio path or task_id>\"}\n"
         "Do NOT wrap in markdown—return only the JSON object.\n"
+        f"Already used tools: {used_tools}\n"
     )
     messages_for_llm.append(SystemMessage(content=prompt))
     llm_response = llm(messages_for_llm)
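The point of the two new counters is a hard termination bound: merge_tool_output increments tool_calls once per completed tool round, so the `tool_calls >= 5 or len(used_tools) >= 5` gate fires after at most five rounds no matter what the model keeps requesting. A hand-driven sketch of that control flow (the request list is invented for illustration):

    used_tools = []
    tool_calls = 0
    requests = ["wiki_query", "web_search_query", "ocr_path",
                "excel_path", "audio_path", "wiki_query"]

    for req in requests:
        if tool_calls >= 5 or len(used_tools) >= 5:
            print("limit reached -> force final answer")
            break
        if req in used_tools:
            print(f"model tried to reuse {req} -> force final answer")
            break
        tool_calls += 1
        used_tools.append(req)
    # After five distinct tools, the sixth request never runs.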
@@ -194,26 +229,34 @@ def inspect_node(state: AgentState) -> AgentState:
     try:
         parsed = json.loads(raw)
         if isinstance(parsed, dict):
-            partial: AgentState = {
-                "messages": new_msgs,
+            # If GPT asks for a tool that’s already in used_tools, override and force finalize.
+            for key in parsed:
+                if key in ("wiki_query", "web_search_query", "ocr_path", "excel_path", "audio_path"):
+                    if key in used_tools:
+                        # GPT tried to reuse a tool → force a final answer instead
+                        return {
+                            "messages": new_msgs,
+                            "final_answer": state.get("final_answer", ""),
+                            "used_tools": used_tools,
+                            "tool_calls": tool_calls
+                        }
+            # Otherwise, it’s either final_answer or a brand‐new tool request
+            partial: AgentState = {
+                "messages": new_msgs,
+                "used_tools": used_tools,
+                "tool_calls": tool_calls
             }
             for k, v in parsed.items():
-                partial[k] = v
+                partial[k] = v
             return partial
     except json.JSONDecodeError:
         pass
 
     return {
         "messages": new_msgs,
-        "final_answer": "ERROR: could not parse inspect decision."
+        "final_answer": "ERROR: could not parse inspect decision.",
+        "used_tools": used_tools,
+        "tool_calls": tool_calls
     }
 
 
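Both parse sites assume the model really returns bare JSON. Models sometimes wrap replies in markdown fences despite the "Do NOT wrap in markdown" instruction, so a cheap pre-clean before json.loads would make the error fallback rarer. A sketch under that assumption (the helper name is made up):

    import json
    import re

    def parse_llm_json(raw: str):
        """Best effort: strip ```json fences, then parse; None if still invalid."""
        cleaned = re.sub(r"^```(?:json)?\s*|\s*```$", "", raw.strip())
        try:
            parsed = json.loads(cleaned)
        except json.JSONDecodeError:
            return None
        return parsed if isinstance(parsed, dict) else None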
@@ -284,9 +327,10 @@ compiled_graph = graph.compile()
 
 
 # ─── 8) respond_to_input ───
-def respond_to_input(user_input: str, task_id) -> str:
+def respond_to_input(user_input: str, task_id: str) -> str:
     """
     Seed state['messages'] with a SystemMessage + HumanMessage(user_input),
+    include the current task_id so that OCR/Audio tools can fetch files,
     then invoke the cyclic graph. Return the final_answer from the resulting state.
     """
     system_msg = SystemMessage(
@@ -298,19 +342,23 @@ def respond_to_input(user_input: str, task_id) -> str:
             " • OCR: set {\"ocr_path\":\"<image path or task_id>\"}\n"
             " • Excel: set {\"excel_path\":\"<xlsx path>\", \"excel_sheet_name\":\"<sheet>\"}\n"
             " • Audio transcription: set {\"audio_path\":\"<audio path or task_id>\"}\n"
-            "If you can answer immediately, set {\"final_answer\":\"<answer>\"}\n"
+            "If you can answer immediately, set {\"final_answer\":\"<answer>\"}.\n"
             "Respond with only one JSON object and no extra formatting."
         )
     )
     human_msg = HumanMessage(content=user_input)
 
-    initial_state: AgentState = {
+    initial_state: AgentState = {
+        "messages": [system_msg, human_msg],
+        "task_id": task_id,
+        "used_tools": [],   # track which tools have been requested
+        "tool_calls": 0     # count of how many tool invocations so far
+    }
     final_state = compiled_graph.invoke(initial_state)
     return final_state.get("final_answer", "Error: No final answer generated.")
 
 
 
-
 class BasicAgent:
     def __init__(self):
         print("BasicAgent initialized.")
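With the seeded counters in place, a call now looks like this (question text and id are placeholder values, not from this repo):

    answer = respond_to_input(
        "What is the capital of Australia?",
        task_id="demo-task-123",   # placeholder; real ids come from the caller
    )
    print(answer)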