naman1102 commited on
Commit
5166389
Β·
1 Parent(s): e339dd2
Files changed (3) hide show
  1. app.py +6 -6
  2. state.py +2 -1
  3. tools.py +6 -6
app.py CHANGED
@@ -53,7 +53,7 @@ def plan_node(state: AgentState) -> AgentState:
53
  human_msg = HumanMessage(content=user_input)
54
 
55
  # (2) Call the LLM
56
- llm_response = llm([system_msg, human_msg])
57
  llm_out = llm_response.content.strip()
58
 
59
  # ── DEBUG: print raw LLM output ──
@@ -254,7 +254,7 @@ compiled_graph = graph.compile()
254
 
255
 
256
  # ─── 6) respond_to_input ───
257
- def respond_to_input(user_input: str) -> str:
258
  """
259
  Seed state['messages'] with a SystemMessage (tools description) + HumanMessage(user_input).
260
  Then invoke the graph; return the final_answer from the resulting state.
@@ -276,7 +276,7 @@ def respond_to_input(user_input: str) -> str:
276
  )
277
  human_msg = HumanMessage(content=user_input)
278
 
279
- initial_state: AgentState = {"messages": [system_msg, human_msg]}
280
  final_state = compiled_graph.invoke(initial_state)
281
  return final_state.get("final_answer", "Error: No final answer generated.")
282
 
@@ -286,7 +286,7 @@ def respond_to_input(user_input: str) -> str:
286
  class BasicAgent:
287
  def __init__(self):
288
  print("BasicAgent initialized.")
289
- def __call__(self, question: str) -> str:
290
  # print(f"Agent received question (first 50 chars): {question[:50]}...")
291
  # fixed_answer = "This is a default answer."
292
  # print(f"Agent returning fixed answer: {fixed_answer}")
@@ -298,7 +298,7 @@ class BasicAgent:
298
 
299
  print(f"Agent received question: {question}")
300
  print()
301
- return respond_to_input(question)
302
  # return fixed_answer
303
 
304
 
@@ -368,7 +368,7 @@ def run_and_submit_all( profile: gr.OAuthProfile | None):
368
  print(f"Skipping item with missing task_id or question: {item}")
369
  continue
370
  try:
371
- submitted_answer = agent(question_text)
372
  answers_payload.append({"task_id": task_id, "submitted_answer": submitted_answer})
373
  results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": submitted_answer})
374
  except Exception as e:
 
53
  human_msg = HumanMessage(content=user_input)
54
 
55
  # (2) Call the LLM
56
+ llm_response = llm.invoke([system_msg, human_msg])
57
  llm_out = llm_response.content.strip()
58
 
59
  # ── DEBUG: print raw LLM output ──
 
254
 
255
 
256
  # ─── 6) respond_to_input ───
257
+ def respond_to_input(user_input: str, task_id) -> str:
258
  """
259
  Seed state['messages'] with a SystemMessage (tools description) + HumanMessage(user_input).
260
  Then invoke the graph; return the final_answer from the resulting state.
 
276
  )
277
  human_msg = HumanMessage(content=user_input)
278
 
279
+ initial_state: AgentState = {"messages": [system_msg, human_msg], "task_id": task_id}
280
  final_state = compiled_graph.invoke(initial_state)
281
  return final_state.get("final_answer", "Error: No final answer generated.")
282
 
 
286
  class BasicAgent:
287
  def __init__(self):
288
  print("BasicAgent initialized.")
289
+ def __call__(self, question: str, task_id) -> str:
290
  # print(f"Agent received question (first 50 chars): {question[:50]}...")
291
  # fixed_answer = "This is a default answer."
292
  # print(f"Agent returning fixed answer: {fixed_answer}")
 
298
 
299
  print(f"Agent received question: {question}")
300
  print()
301
+ return respond_to_input(question, task_id)
302
  # return fixed_answer
303
 
304
 
 
368
  print(f"Skipping item with missing task_id or question: {item}")
369
  continue
370
  try:
371
+ submitted_answer = agent(question_text, task_id)
372
  answers_payload.append({"task_id": task_id, "submitted_answer": submitted_answer})
373
  results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": submitted_answer})
374
  except Exception as e:
state.py CHANGED
@@ -17,4 +17,5 @@ class AgentState(TypedDict, total=False):
17
  transcript: str
18
  audio_transcript: str
19
  wiki_query: str
20
- wiki_result: str
 
 
17
  transcript: str
18
  audio_transcript: str
19
  wiki_query: str
20
+ wiki_result: str
21
+ task_id: str
tools.py CHANGED
@@ -98,14 +98,14 @@ def ocr_image_tool(state: AgentState) -> AgentState:
98
  Always attempts to download the file for the given path or task ID.
99
  """
100
  print("reached ocr_image_tool")
101
- path_or_id = state.get("ocr_path", "")
102
- if not path_or_id:
103
- return {}
104
 
105
  # Always attempt to download the file, regardless of local existence
106
  local_img = ""
107
  for ext in ("png", "jpg", "jpeg"):
108
- candidate = _download_file_for_task(path_or_id, ext)
109
  if candidate:
110
  local_img = candidate
111
  break
@@ -149,7 +149,7 @@ def parse_excel_tool(state: AgentState) -> AgentState:
149
  return {}
150
 
151
  # Always attempt to download the file, regardless of local existence
152
- local_xlsx = _download_file_for_task(path_or_id, "xlsx")
153
 
154
  # If we finally have a real file, read it
155
  if local_xlsx and os.path.exists(local_xlsx):
@@ -239,7 +239,7 @@ def audio_transcriber_tool(state: AgentState) -> AgentState:
239
  # Always attempt to download the file, regardless of local existence
240
  local_audio = ""
241
  for ext in ("mp3", "wav", "m4a"):
242
- candidate = _download_file_for_task(path_or_id, ext)
243
  if candidate:
244
  local_audio = candidate
245
  break
 
98
  Always attempts to download the file for the given path or task ID.
99
  """
100
  print("reached ocr_image_tool")
101
+ # path_or_id = state.get("ocr_path", "")
102
+ # if not path_or_id:
103
+ # return {}
104
 
105
  # Always attempt to download the file, regardless of local existence
106
  local_img = ""
107
  for ext in ("png", "jpg", "jpeg"):
108
+ candidate = _download_file_for_task(state.get("task_id"), ext)
109
  if candidate:
110
  local_img = candidate
111
  break
 
149
  return {}
150
 
151
  # Always attempt to download the file, regardless of local existence
152
+ local_xlsx = _download_file_for_task(state.get("task_id"), "xlsx")
153
 
154
  # If we finally have a real file, read it
155
  if local_xlsx and os.path.exists(local_xlsx):
 
239
  # Always attempt to download the file, regardless of local existence
240
  local_audio = ""
241
  for ext in ("mp3", "wav", "m4a"):
242
+ candidate = _download_file_for_task(state.get("task_id"), ext)
243
  if candidate:
244
  local_audio = candidate
245
  break