naman1102 commited on
Commit
c7a6db7
Β·
1 Parent(s): 1f5cba5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +78 -11
app.py CHANGED
@@ -13,11 +13,7 @@ from langgraph.graph import StateGraph, START, END
13
  from langgraph.graph.message import add_messages
14
 
15
  # Create a ToolNode that knows about your web_search function
16
- search_node = ToolNode([web_search])
17
 
18
- excel_tool_node = ToolNode([parse_excel])
19
-
20
- image_tool_node = ToolNode([ocr_image])
21
  # (Keep Constants as is)
22
  # --- Constants ---
23
  DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
@@ -26,9 +22,79 @@ DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
26
  # ----- THIS IS WERE YOU CAN BUILD WHAT YOU WANT ------
27
 
28
 
29
-
30
-
31
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
 
34
 
@@ -36,10 +102,11 @@ class BasicAgent:
36
  def __init__(self):
37
  print("BasicAgent initialized.")
38
  def __call__(self, question: str) -> str:
39
- print(f"Agent received question (first 50 chars): {question[:50]}...")
40
- fixed_answer = "This is a default answer."
41
- print(f"Agent returning fixed answer: {fixed_answer}")
42
- return fixed_answer
 
43
 
44
 
45
 
 
13
  from langgraph.graph.message import add_messages
14
 
15
  # Create a ToolNode that knows about your web_search function
 
16
 
 
 
 
17
  # (Keep Constants as is)
18
  # --- Constants ---
19
  DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
 
22
  # ----- THIS IS WERE YOU CAN BUILD WHAT YOU WANT ------
23
 
24
 
25
+ class AgentState(TypedDict):
26
+ messages: Annotated[list, add_messages]
27
+ question: str
28
+ answer: str
29
+
30
+ llm = ChatOpenAI(model_name="gpt-4.1-mini")
31
+ llm_node = LLMNode(llm)
32
+
33
+ # 4) Wrap the tools in a single ToolNode
34
+ # You can list as many @tool functions here as you like.
35
+ # search_node = ToolNode([web_search])
36
+
37
+ # excel_tool_node = ToolNode([parse_excel])
38
+
39
+ # image_tool_node = ToolNode([ocr_image])
40
+ tool_node = ToolNode([ocr_image, parse_excel, web_search])
41
+
42
+ # 5) Build the StateGraph
43
+ graph = StateGraph[AgentState]()
44
+
45
+ # ────────────────────────
46
+ # Edge 1: START β†’ LLM
47
+ # Wrap the user_input into state.messages
48
+ graph.add_edge(
49
+ START,
50
+ llm_node,
51
+ lambda state, user_input: {"messages": [user_input]},
52
+ name="start_to_llm",
53
+ )
54
+
55
+ # ────────────────────────
56
+ # Edge 2: LLM β†’ ToolNode
57
+ # Only fire when the LLM returns a dict with exactly "tool":"ocr_image" or "tool":"parse_excel"
58
+ # The lambda must return that dict so the ToolNode can extract its arguments.
59
+ def route_to_tool(state: AgentState, llm_out):
60
+ # Expecting llm_out to be a dict like:
61
+ # {"tool": "ocr_image", "path": "invoice.png"}
62
+ # or {"tool": "parse_excel", "path":"sales.xlsx", "sheet_name":"Sheet1"}
63
+ if isinstance(llm_out, dict) and llm_out.get("tool") in {"ocr_image", "parse_excel", "web_search"}:
64
+ return llm_out
65
+ return None # β†’ do not invoke the tool if it's not matching
66
+
67
+ graph.add_edge(
68
+ llm_node,
69
+ tool_node,
70
+ route_to_tool,
71
+ name="llm_to_tool",
72
+ )
73
+
74
+ # ────────────────────────
75
+ # Edge 3: ToolNode β†’ LLM
76
+ # Whatever the tool returns (a string), feed that straight back into the LLM as the next turn.
77
+ graph.add_edge(
78
+ tool_node,
79
+ llm_node,
80
+ lambda state, tool_out: tool_out,
81
+ name="tool_to_llm",
82
+ )
83
+
84
+ # ────────────────────────
85
+ # Edge 4: LLM β†’ END
86
+ # Once the LLM is done reasoning (without calling any more tools), return its output string.
87
+ graph.add_edge(
88
+ llm_node,
89
+ END,
90
+ lambda state, final_str: final_str,
91
+ name="llm_to_end",
92
+ )
93
+
94
+
95
+ def respond_to_input(user_input: str) -> str:
96
+ initial_state: AgentState = {"messages": []}
97
+ return graph.run(initial_state, user_input)
98
 
99
 
100
 
 
102
  def __init__(self):
103
  print("BasicAgent initialized.")
104
  def __call__(self, question: str) -> str:
105
+ # print(f"Agent received question (first 50 chars): {question[:50]}...")
106
+ # fixed_answer = "This is a default answer."
107
+ # print(f"Agent returning fixed answer: {fixed_answer}")
108
+ return respond_to_input(question)
109
+ # return fixed_answer
110
 
111
 
112