naman1102 commited on
Commit
0e29657
·
1 Parent(s): 4eea303
Files changed (2) hide show
  1. app.py +147 -101
  2. tools.py +57 -64
app.py CHANGED
@@ -16,10 +16,21 @@ from langchain.schema import HumanMessage, AIMessage, SystemMessage
16
  # Create a ToolNode that knows about your web_search function
17
  import json
18
 
19
- # (Keep Constan
20
- #
21
- #
22
- # ts as is)
 
 
 
 
 
 
 
 
 
 
 
23
  # --- Constants ---
24
  DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
25
 
@@ -29,108 +40,143 @@ tool_node = ToolNode([ocr_image, parse_excel, web_search])
29
  agent = create_react_agent(model=llm, tools=tool_node)
30
 
31
# 2) Wire the simplest possible graph: START -> agent -> END, no branching.
graph = StateGraph(dict)
graph.add_node("agent", agent)
graph.add_edge(START, "agent")
graph.add_edge("agent", END)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  compiled_graph = graph.compile()
37
 
 
38
def respond_to_input(user_input: str) -> str:
    """Drive one question through the agent graph and return its answer.

    Flow: (1) seed the graph with a strict system prompt, (2) run a first
    pass, (3) if the model replied with a bare-JSON tool call, execute the
    tool and run a second pass with the tool output appended, (4) return
    the last AI message as plain text.  Any invoke error yields "".
    """
    # System prompt that forces tool requests to be emitted as bare JSON.
    system_msg = SystemMessage(
        content=(
            "You are an assistant with access to exactly these tools:\n"
            " 1) web_search(query:str)\n"
            " 2) parse_excel(path:str,sheet_name:str)\n"
            " 3) ocr_image(path:str)\n\n"
            "⚠️ **MANDATORY** ⚠️: If (and only if) you need to call a tool, your entire response MUST be exactly ONE JSON OBJECT and NOTHING ELSE. \n"
            "For example, if you want to call web_search, you must respond with exactly:\n"
            "```json\n"
            '{"tool":"web_search","query":"Mercedes Sosa studio albums 2000-2009"}\n'
            "```\n"
            "That JSON string must start at the very first character of your response and end at the very last character—"
            "no surrounding quotes, no markdown fences, no explanatory text. \n\n"
            "If you do NOT need to call any tool, then you must respond with your final answer as plain text (no JSON)."
        )
    )

    state = {"messages": [system_msg, HumanMessage(content=user_input)]}

    # First pass through the compiled graph.
    try:
        first_pass = compiled_graph.invoke(state)
    except Exception as e:
        print("‼️ ERROR during first invoke:", repr(e))
        return ""

    print("===== AGENT MESSAGES (First Pass) =====")
    for idx, msg in enumerate(first_pass["messages"]):
        if isinstance(msg, AIMessage):
            print(f"[AIMessage #{idx}]: {repr(msg.content)}")
    print("=========================================")

    # Content of the most recent AIMessage, or None if there is none.
    last_msg = next(
        (m.content for m in reversed(first_pass["messages"]) if isinstance(m, AIMessage)),
        None,
    )

    # Try to interpret the reply as a JSON tool request, tolerating one
    # layer of surrounding quotes that some models add around the object.
    candidate = (last_msg or "").strip()
    if (candidate.startswith('"') and candidate.endswith('"')) or (
        candidate.startswith("'") and candidate.endswith("'")
    ):
        candidate = candidate[1:-1]
    tool_dict = None
    try:
        parsed = json.loads(candidate)
        if isinstance(parsed, dict) and "tool" in parsed:
            tool_dict = parsed
    except Exception:
        tool_dict = None

    if not tool_dict:
        # Plain-text reply: that already is the final answer.
        return last_msg or ""

    # Execute the requested tool.
    print(">> Parsed tool call:", tool_dict)
    tool_result = tool_node.run(tool_dict)
    print(f">> Tool '{tool_dict['tool']}' returned: {repr(tool_result)}")

    # Second pass: feed the tool output back in as an AIMessage, with no
    # new human input.
    continuation = {
        "messages": [*first_pass["messages"], AIMessage(content=tool_result)]
    }
    try:
        second_pass = compiled_graph.invoke(continuation)
    except Exception as e2:
        print("‼️ ERROR during second invoke:", repr(e2))
        return ""

    print("===== AGENT MESSAGES (Second Pass) =====")
    for idx, msg in enumerate(second_pass["messages"]):
        if isinstance(msg, AIMessage):
            print(f"[AIMessage2 #{idx}]: {repr(msg.content)}")
    print("=========================================")

    # Return the final AIMessage from the second pass.
    for msg in reversed(second_pass["messages"]):
        if isinstance(msg, AIMessage):
            return msg.content or ""
    return ""
134
 
135
  class BasicAgent:
136
  def __init__(self):
 
16
  # Create a ToolNode that knows about your web_search function
17
  import json
18
 
19
from typing import TypedDict, Annotated

class AgentState(TypedDict, total=False):
    """Mutable state shared by every node in the agent graph.

    total=False: every field is optional; each node sets only the keys it owns.
    """
    # Chat history; merged across node returns via the add_messages reducer.
    # NOTE(review): add_messages is not imported in the visible code — it
    # presumably comes from langgraph.graph.message; confirm upstream.
    messages: Annotated[list, add_messages]
    # Request slots: the planner sets exactly one of these to ask for a tool.
    web_search_query: str
    ocr_path: str
    excel_path: str
    excel_sheet_name: str
    # Result slots: filled by the corresponding tool wrapper.
    web_search_result: str
    ocr_result: str
    excel_result: str
    # Plain-text answer produced by the finalize node.
    final_answer: str
34
  # --- Constants ---
35
  DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
36
 
 
40
  agent = create_react_agent(model=llm, tools=tool_node)
41
 
42
  # 2) Build a two‐edge graph:
43
def plan_node(state: AgentState, user_input: str) -> AgentState:
    """Planner node: decide whether a tool is needed or answer directly.

    Reads state['messages'] + user_input and returns a new state in which
    exactly one of these keys is set:
      * web_search_query — request a web search,
      * ocr_path — request OCR on an image,
      * excel_path (and optionally excel_sheet_name) — request Excel parsing,
      * final_answer — no tool needed, plain-text answer.
    The new user input is appended to the message history so the LLM sees it.

    NOTE(review): LangGraph invokes node functions with the state only; the
    extra user_input parameter will not be supplied by the graph runtime —
    confirm how this node is actually wired before relying on it.
    """
    import ast

    # Append the new user turn to the prior history.
    chat_history = state.get("messages", []) + [f"USER: {user_input}"]

    # Ask the LLM to reply with a single one-key dict literal.
    prompt = chat_history + [
        "ASSISTANT: You can set one of the following keys:\n"
        " • web_search_query: <string> \n"
        " • ocr_path: <path> \n"
        " • excel_path: <path> \n"
        " • excel_sheet_name: <sheet> \n"
        "Or, if no tool is needed, set final_answer: <your answer>.\n"
        "Respond with a Python‐dict literal that contains exactly one of those keys.\n"
        "Example: {'web_search_query':'Mercedes Sosa discography'}\n"
        "No additional text!"
    ]
    llm_out = llm(prompt).content.strip()

    # Parse with ast.literal_eval, which accepts only Python literals.
    # The previous eval() would *execute* arbitrary LLM output — a
    # code-injection hole even with empty globals/locals.
    try:
        parsed = ast.literal_eval(llm_out)
        if isinstance(parsed, dict):
            # Keep only recognized keys; silently drop anything else.
            new_state: AgentState = {"messages": chat_history}
            allowed = {
                "web_search_query",
                "ocr_path",
                "excel_path",
                "excel_sheet_name",
                "final_answer",
            }
            for key, value in parsed.items():
                if key in allowed:
                    new_state[key] = value
            return new_state
    except Exception:
        pass

    # Unparseable (or non-dict) reply: surface a fixed fallback answer.
    return {
        "messages": chat_history,
        "final_answer": "Sorry, I could not parse your intent."
    }
95
+
96
# ─── 5) Define "finalize" node: compose the final answer using any tool results ───
def finalize_node(state: AgentState) -> AgentState:
    """Compose the final answer from the chat history plus any tool results.

    By this point either final_answer is already set (no tool was needed) or
    one of web_search_result / ocr_result / excel_result holds tool output.
    Returns {"final_answer": <LLM text>}.
    """
    # Work on a copy: the original appended tool output directly onto
    # state['messages'], mutating the shared history list in place.
    parts = list(state.get("messages", []))
    if "web_search_result" in state and state["web_search_result"] is not None:
        parts.append(f"WEB_SEARCH_RESULT: {state['web_search_result']}")
    if "ocr_result" in state and state["ocr_result"] is not None:
        parts.append(f"OCR_RESULT: {state['ocr_result']}")
    if "excel_result" in state and state["excel_result"] is not None:
        parts.append(f"EXCEL_RESULT: {state['excel_result']}")

    parts.append("ASSISTANT: Please provide the final answer now.")
    llm_out = llm(parts).content.strip()

    return {"final_answer": llm_out}
118
+
119
+
120
+
121
+
122
+
123
+
124
+
125
+
126
+
127
+
128
+
129
graph = StateGraph(AgentState)

# 6.a) Register nodes in order.
graph.add_node("plan", plan_node)
graph.add_node("tools", tool_node)
graph.add_node("finalize", finalize_node)

# 6.b) START → "plan"
graph.add_edge(START, "plan")

# 6.c) If the planner set a tool-request key, go to "tools"; else "finalize".
def route_plan(state: AgentState) -> str:
    """Return the next node name based on which keys the planner set.

    LangGraph calls conditional-edge routers with the (already merged)
    state only — the original two-argument (state, plan_out) signature
    could never be satisfied at runtime.
    """
    if state.get("web_search_query") or state.get("ocr_path") or state.get("excel_path"):
        return "tools"
    return "finalize"

graph.add_conditional_edges(
    "plan",
    route_plan,
    {"tools": "tools", "finalize": "finalize"}
)

# 6.d) Tool output is merged back into the state by the graph's reducer, so
# a plain edge suffices. StateGraph.add_edge takes exactly two node names —
# the original passed a merge callback as a third argument, which is not
# part of the LangGraph API and would raise at build time.
graph.add_edge("tools", "finalize")

# 6.e) "finalize" → END
graph.add_edge("finalize", END)
168
+
169
  compiled_graph = graph.compile()
170
 
171
# ─── 7) Define respond_to_input that drives the graph ───
def respond_to_input(user_input: str) -> str:
    """Run one user turn through the compiled graph; return the final answer.

    The question travels inside the state: CompiledGraph.invoke() takes a
    single state argument (its second parameter is a RunnableConfig, not
    user input), so the original invoke(initial_state, user_input) would
    never deliver the question to the planner.
    """
    initial_state: AgentState = {"messages": [f"USER: {user_input}"]}
    final_state = compiled_graph.invoke(initial_state)
    # final_state should carry 'final_answer' set by the finalize node.
    return final_state.get("final_answer", "Error: No final answer generated.")
 
 
 
 
 
 
 
 
 
 
 
 
178
 
 
 
 
 
 
 
 
179
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
180
 
181
  class BasicAgent:
182
  def __init__(self):
tools.py CHANGED
@@ -1,78 +1,71 @@
1
- from langchain_core.tools import tool
2
- from langchain_community.tools import DuckDuckGoSearchRun
3
- import pandas as pd
4
@tool
def web_search(query: str) -> str:
    """
    Search the web for information.
    Args:
        query: The query to search the web for.
    Returns:
        The search results.
    """
    # Trace the call, then delegate straight to DuckDuckGo.
    print(f"Reached: web_search: {query}")
    searcher = DuckDuckGoSearchRun()
    return searcher.run(query)
16
-
17
-
18
@tool
def parse_excel(path: str, sheet_name: str = None) -> str:
    """
    Read in an Excel file at `path`, optionally select a sheet by name (or default to the first sheet),
    then convert the DataFrame to a JSON-like string. Return that text so the LLM can reason over it.

    Example return value (collapsed):
        "[{'Name': 'Alice', 'Score': 95}, {'Name': 'Bob', 'Score': 88}, ...]"
    """
    print(f"Reached: parse_excel: {path} {sheet_name}")

    # Open the workbook up front so a missing file yields a readable error.
    try:
        workbook = pd.ExcelFile(path)
    except FileNotFoundError:
        return f"Error: could not find file at {path}."

    # Use the named sheet when it exists; otherwise fall back to the first.
    target = sheet_name if sheet_name and sheet_name in workbook.sheet_names else workbook.sheet_names[0]
    frame = pd.read_excel(workbook, sheet_name=target)

    # Render rows as a list-of-dicts string for the LLM to reason over.
    return str(frame.to_dict(orient="records"))
45
-
46
-
47
-
48
  # tools.py
49
 
 
 
50
  from pathlib import Path
51
  from PIL import Image
52
  import pytesseract
53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
@tool
def ocr_image(path: str) -> str:
    """
    Run OCR on the image at `path` and return the extracted text.
    - Expects that Tesseract is installed on the host machine.
    - If the file is missing or unreadable, returns an error string.
    """
    print(f"Reached: ocr_image: {path}")
    image_file = Path(path)
    if not image_file.exists():
        return f"Error: could not find image at {path}"

    # Load the image, reporting failures as strings rather than raising.
    try:
        image = Image.open(image_file)
    except Exception as e:
        return f"Error: could not open image: {e}"

    # Run Tesseract; OCR failures are also reported as strings.
    try:
        extracted = pytesseract.image_to_string(image)
    except Exception as e:
        return f"Error: OCR failed: {e}"

    return extracted.strip() or "(no visible text detected)"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # tools.py
2
 
3
+ import pandas as pd
4
+ from langchain_community.tools import DuckDuckGoSearchRun
5
  from pathlib import Path
6
  from PIL import Image
7
  import pytesseract
8
 
9
def web_search_tool(state: "AgentState") -> "AgentState":
    """Run a web search when the planner requested one.

    Expects: state["web_search_query"] is a non-empty string.
    Returns: {"web_search_query": None, "web_search_result": <string>},
    clearing web_search_query so the graph does not loop forever.
    Returns {} when no query is pending.

    The AgentState annotations are quoted (forward references): tools.py
    never imports AgentState, so unquoted annotations raised NameError the
    moment this module was imported.
    """
    query = state.get("web_search_query", "")
    if not query:
        return {}  # nothing to do

    # Run DuckDuckGo
    ddg = DuckDuckGoSearchRun()
    result_text = ddg.run(query)
    return {
        "web_search_query": None,
        "web_search_result": result_text
    }
26
 
27
def ocr_image_tool(state: "AgentState") -> "AgentState":
    """OCR the image the planner pointed at.

    Expects: state["ocr_path"] is a path to an image file.
    Returns: {"ocr_path": None, "ocr_result": <string>}; {} when no path
    is pending. OCR failures are reported inside ocr_result, not raised.

    The AgentState annotations are quoted (forward references): tools.py
    never imports AgentState, so unquoted annotations raised NameError at
    module import time.
    """
    path = state.get("ocr_path", "")
    if not path:
        return {}
    try:
        img = Image.open(path)
        text = pytesseract.image_to_string(img)
        text = text.strip() or "(no visible text)"
    except Exception as e:
        text = f"Error during OCR: {e}"
    return {
        "ocr_path": None,
        "ocr_result": text
    }
45
+
46
def parse_excel_tool(state: "AgentState") -> "AgentState":
    """Parse the spreadsheet the planner pointed at.

    Expects: state["excel_path"] is a path to an .xlsx file, and
             state["excel_sheet_name"] optionally names a sheet.
    Returns: {"excel_path": None, "excel_sheet_name": None,
              "excel_result": <string>}; {} when no path is pending.
    Read errors are reported inside excel_result, not raised.

    The AgentState annotations are quoted (forward references): tools.py
    never imports AgentState, so unquoted annotations raised NameError at
    module import time.
    """
    path = state.get("excel_path", "")
    sheet = state.get("excel_sheet_name", "")
    if not path:
        return {}

    try:
        xls = pd.ExcelFile(path)
        # Use the named sheet when present; otherwise the first sheet.
        if sheet and sheet in xls.sheet_names:
            df = pd.read_excel(xls, sheet_name=sheet)
        else:
            df = pd.read_excel(xls, sheet_name=xls.sheet_names[0])
        records = df.to_dict(orient="records")
        text = str(records)
    except Exception as e:
        text = f"Error reading Excel: {e}"
    return {
        "excel_path": None,
        "excel_sheet_name": None,
        "excel_result": text
    }