naman1102 committed on
Commit
e168d85
·
1 Parent(s): d849921

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +79 -40
app.py CHANGED
@@ -22,79 +22,118 @@ DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
22
  # ----- THIS IS WHERE YOU CAN BUILD WHAT YOU WANT ------
23
 
24
  class AgentState(TypedDict):
 
25
  messages: Annotated[list[str], add_messages]
26
- tool: str # will store the name of the requested tool (if any)
27
- agent_out: str # raw output from the LLM
 
 
28
 
29
- # 2) Instantiate the raw LLM and wrap it in a function
30
  llm = ChatOpenAI(model_name="gpt-4.1-mini")
31
 
32
  def agent_node(state: AgentState, user_input: str) -> AgentState:
33
- prev_msgs = state.get("messages", [])
34
- messages = prev_msgs + [f"USER: {user_input}"]
35
- # Ask the LLM for a response
36
- llm_response = llm(messages).content # returns a string or maybe JSON string
37
- # If you expect JSON with {"tool": "...", ...}, parse it:
38
- tool_requested = None
 
 
 
 
 
 
 
 
39
  try:
40
- parsed = eval(llm_response) # (use json.loads if the LLM returns valid JSON)
41
  if isinstance(parsed, dict) and parsed.get("tool"):
42
- tool_requested = parsed.get("tool")
43
- except:
44
- pass
45
 
 
46
  return {
47
- "messages": messages + [f"ASSISTANT: {llm_response}"],
48
- "agent_out": llm_response,
49
- "tool": tool_requested or ""
50
  }
51
 
52
- # 3) Instantiate a real ToolNode for your three tools
53
- t_node = ToolNode([ocr_image, parse_excel, web_search])
 
 
 
 
 
 
 
 
 
 
54
 
55
- def run_tool_node(state: AgentState, agent_output) -> AgentState:
56
- # `agent_output` is the dict that the LLM returned, e.g. {"tool":"ocr_image", "path": "file.png"}
57
- tool_result: str = t_node.run(agent_output)
58
  return {
59
- "messages": [f"TOOL RESULT: {tool_result}"],
60
- "tool": "", # once a tool has run, clear this so we don’t loop forever
61
- "agent_out": tool_result
62
  }
63
 
64
- # 4) Build the StateGraph with the corrected node names
65
  graph = StateGraph(AgentState)
66
  graph.add_node("agent", agent_node)
67
- graph.add_node("tools", run_tool_node)
68
 
69
- # 5) START → "agent"
70
  graph.add_edge(START, "agent")
71
 
72
- # 6) "tools""agent"
73
  graph.add_edge("tools", "agent")
74
 
75
- # 7) Conditional edges out of "agent"
76
- def route_agent(state: AgentState, agent_output):
77
- # If LLM asked for a tool, we go to "tools"; else we terminate
78
- if isinstance(agent_output, dict) and agent_output.get("tool") in {"ocr_image", "parse_excel", "web_search"}:
 
 
 
 
 
 
 
 
 
79
  return "tools"
80
  return "final"
81
 
82
  graph.add_conditional_edges(
83
- "agent",
84
- route_agent,
85
  {
86
- "tools": "tools",
87
- "final": END
88
  }
89
  )
90
 
91
- # 8) Compile the graph and use run(), not invoke(…)
92
  compiled_graph = graph.compile()
93
 
 
94
  def respond_to_input(user_input: str) -> str:
95
- initial_state: AgentState = {"messages": [], "tool": "", "agent_out": ""}
96
- # Use .run() in v0.3.x; if you see an AttributeError, switch to .invoke()
97
- return compiled_graph.invoke(initial_state, user_input)
 
 
 
 
 
 
 
 
98
 
99
 
100
  class BasicAgent:
 
22
  # ----- THIS IS WHERE YOU CAN BUILD WHAT YOU WANT ------
23
 
24
class AgentState(TypedDict):
    """Shared state threaded through every node of the LangGraph graph."""

    # Full chat history as plain strings ("USER: ...", "ASSISTANT: ...").
    # add_messages is the reducer applied on merge — presumably it appends
    # new entries rather than overwriting (defined elsewhere; confirm).
    messages: Annotated[list[str], add_messages]
    # Set by the agent node when the LLM asks for a tool, e.g.
    # {"tool": "ocr_image", "path": "file.png"}; None when no tool is requested.
    tool_request: dict | None
    # Raw output of the last tool run; None until a tool has executed.
    tool_result: str | None
 
32
+ # 2) Wrap ChatOpenAI in a function whose signature is (state, user_input) → new_state
33
  llm = ChatOpenAI(model_name="gpt-4.1-mini")
34
 
35
def agent_node(state: "AgentState", user_input: str) -> "AgentState":
    """Run the LLM on the chat history plus the new user turn.

    Parameters
    ----------
    state : AgentState
        Current graph state; only ``messages`` is read here.
    user_input : str
        The new user message, appended to the history before the model call.

    Returns
    -------
    AgentState
        New state with the assistant reply appended, ``tool_request`` set to
        the parsed dict when the model asked for a tool (else None), and
        ``tool_result`` cleared.
    """
    import ast  # local import: safe literal parsing without eval()

    # Empty history on the very first turn.
    prior_msgs = state.get("messages", [])
    chat_history = prior_msgs + [f"USER: {user_input}"]

    # .invoke() is the supported entry point; calling the model object
    # directly (llm(...)) is deprecated in recent langchain releases.
    llm_output = llm.invoke(chat_history).content

    # If the model emitted a dict literal such as
    # {"tool": "ocr_image", "path": "file.png"}, treat it as a tool call.
    # ast.literal_eval only accepts Python literals, so untrusted model
    # output cannot execute arbitrary code the way eval() could.
    tool_req = None
    try:
        parsed = ast.literal_eval(llm_output)
        if isinstance(parsed, dict) and parsed.get("tool"):
            tool_req = parsed
    except (ValueError, SyntaxError, TypeError, MemoryError):
        # Plain prose (or anything that is not a literal) means "no tool".
        tool_req = None

    return {
        "messages": chat_history + [f"ASSISTANT: {llm_output}"],
        "tool_request": tool_req,
        "tool_result": None  # will be filled by the tool node if invoked
    }
63
 
64
# 3) One ToolNode shared by all three tools, wrapped in a function whose
#    signature is (state, tool_request) -> new_state, like agent_node.
underlying_tool_node = ToolNode([ocr_image, parse_excel, web_search])

def tool_node(state: "AgentState", tool_request: dict) -> "AgentState":
    """Execute the tool named in ``tool_request`` and record its output.

    The graph only routes here when the agent produced a dict like
    {"tool": "...", "path": "...", ...}.

    Returns a state update with the tool output noted in the chat history,
    ``tool_request`` cleared (so we don't loop forever), and the raw text
    stored in ``tool_result``.
    """
    # ToolNode is a Runnable: it is executed with .invoke(), not .run()
    # (ToolNode has no run() method in langgraph).
    result_text = underlying_tool_node.invoke(tool_request)

    return {
        "messages": [f"TOOL ({tool_request['tool']}): {result_text}"],
        "tool_request": None,
        "tool_result": result_text
    }
84
 
85
# 4) Build the state graph and register both nodes.
graph = StateGraph(AgentState)
graph.add_node("agent", agent_node)
graph.add_node("tools", tool_node)

# 5) START -> "agent": every run begins with the LLM node.
graph.add_edge(START, "agent")

# 6) "tools" -> "agent": after any tool runs, control returns to the LLM
#    so it can use the tool result.
graph.add_edge("tools", "agent")
95
 
96
# 7) Conditional branching out of "agent".
def route_agent(state: "AgentState", agent_out=None):
    """Decide where the graph goes after the agent node has run.

    LangGraph invokes a routing function with the current state only, so
    the decision is read from ``state``. The original two-argument form is
    kept backward-compatible: callers may still pass the node's returned
    state explicitly as ``agent_out``, and it then takes precedence.

    Returns
    -------
    str
        "tools" when the LLM requested a tool, "final" otherwise.
    """
    source = agent_out if agent_out is not None else state
    if source.get("tool_request") is not None:
        return "tools"
    return "final"
111
 
112
graph.add_conditional_edges(
    "agent",       # source node
    route_agent,   # routing function: maps state to the key "tools" or "final"
    {
        "tools": "tools",  # LLM requested a tool -> run it
        "final": END       # no tool request -> terminate the graph
    }
)

# 8) Compile the graph into a runnable object.
compiled_graph = graph.compile()
123
 
124
# 9) Entry point called by Gradio / the Hugging Face submission harness.
def respond_to_input(user_input: str) -> str:
    """Run the compiled graph on a single user message and return the reply.

    Parameters
    ----------
    user_input : str
        The user's message for this turn.

    Returns
    -------
    str
        The last message of the final state, with its leading
        "ASSISTANT: " marker stripped.
    """
    # Fresh state on every call: no history carried across turns.
    initial_state: AgentState = {
        "messages": [],
        "tool_request": None,
        "tool_result": None
    }
    # CompiledStateGraph exposes .invoke(), not .run() — .run() raises
    # AttributeError on current langgraph releases.
    # NOTE(review): the second positional argument relies on this graph's
    # custom (state, user_input) node signatures — confirm against how the
    # nodes are registered.
    final_state = compiled_graph.invoke(initial_state, user_input)
    last_message = final_state["messages"][-1]
    # removeprefix() strips only a leading marker; the previous
    # .replace("ASSISTANT: ", "") would also delete the marker anywhere
    # it happened to appear inside the reply text.
    return last_message.removeprefix("ASSISTANT: ")
137
 
138
 
139
  class BasicAgent: