EtienneB committed on
Commit
fc5e0c3
·
1 Parent(s): 916fd5c

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +26 -15
agent.py CHANGED
@@ -3,7 +3,8 @@ import os
3
  import re
4
 
5
  from dotenv import load_dotenv
6
- from langchain_core.messages import HumanMessage, SystemMessage, ToolMessage
 
7
  from langchain_huggingface import (ChatHuggingFace, HuggingFaceEmbeddings,
8
  HuggingFaceEndpoint)
9
  from langgraph.graph import START, MessagesState, StateGraph
@@ -71,18 +72,25 @@ def build_graph():
71
  llm_with_tools = llm.bind_tools(tools)
72
 
73
  # --- Nodes ---
 
 
 
 
 
 
 
 
 
 
 
 
74
  def assistant(state: MessagesState):
75
- """Assistant node"""
76
  messages_with_system_prompt = [sys_msg] + state["messages"]
77
  llm_response = llm_with_tools.invoke(messages_with_system_prompt)
78
- # Extract the answer text (strip any "FINAL ANSWER:" if present)
79
- answer_text = llm_response.content
80
- if answer_text.strip().lower().startswith("final answer:"):
81
- answer_text = answer_text.split(":", 1)[1].strip()
82
- # Get task_id from state or set a placeholder
83
- task_id = state.get("task_id", "1") # Replace with actual logic if needed
84
  formatted = [{"task_id": task_id, "submitted_answer": answer_text}]
85
- return {"messages": [formatted]}
86
 
87
  # --- Graph Definition ---
88
  builder = StateGraph(MessagesState)
@@ -154,9 +162,12 @@ if __name__ == "__main__":
154
  print(message.content)
155
  print("-----------------------")
156
  else:
157
- output = str(message)
158
- print("Agent Output:", output)
159
- if is_valid_agent_output(output):
160
- print("✅ Output is in the correct format!")
161
- else:
162
- print("❌ Output is NOT in the correct format!")
 
 
 
 
import json
import re

from dotenv import load_dotenv
from langchain_core.messages import (AIMessage, HumanMessage, SystemMessage,
                                     ToolMessage)
from langchain_huggingface import (ChatHuggingFace, HuggingFaceEmbeddings,
                                   HuggingFaceEndpoint)
from langgraph.graph import START, MessagesState, StateGraph
 
72
  llm_with_tools = llm.bind_tools(tools)
73
 
74
  # --- Nodes ---
75
def extract_answer(llm_output):
    """Extract the final answer string from raw LLM output.

    If the first line of *llm_output* parses as a JSON list whose first
    element is a dict containing a ``"submitted_answer"`` key, return that
    value. Otherwise fall back to returning the first line verbatim
    (stripped), discarding any trailing explanation lines.

    Args:
        llm_output: The raw text content produced by the LLM.

    Returns:
        The extracted answer as a string (or whatever value was stored
        under ``"submitted_answer"`` in the parsed JSON).
    """
    # Only the first non-blank-padded line is considered; hoisted so the
    # JSON path and the fallback path agree on the same text.
    first_line = llm_output.strip().split("\n")[0]
    try:
        parsed = json.loads(first_line)
        # Explicit emptiness/shape guards instead of relying on a broad
        # except to swallow IndexError/TypeError.
        if (
            isinstance(parsed, list)
            and parsed
            and isinstance(parsed[0], dict)
            and "submitted_answer" in parsed[0]
        ):
            return parsed[0]["submitted_answer"]
    except ValueError:
        # json.JSONDecodeError is a ValueError: first line was not JSON.
        pass
    return first_line
87
def assistant(state: MessagesState):
    """Assistant node: invoke the tool-bound LLM and wrap its answer.

    Prepends the system prompt to the running conversation, calls the
    tool-bound LLM, extracts the answer text via ``extract_answer``, and
    returns it as a JSON-encoded submission payload of the form
    ``[{"task_id": ..., "submitted_answer": ...}]`` inside an ``AIMessage``.

    Args:
        state: The LangGraph ``MessagesState`` for this run.

    Returns:
        A state update dict with a single ``AIMessage`` in ``"messages"``.
    """
    messages_with_system_prompt = [sys_msg] + state["messages"]
    llm_response = llm_with_tools.invoke(messages_with_system_prompt)
    answer_text = extract_answer(llm_response.content)
    # MessagesState does not declare a task_id field; "1" is a placeholder
    # default. NOTE(review): confirm the real task_id is injected into
    # state upstream, otherwise every submission carries task_id "1".
    task_id = str(state.get("task_id", "1"))  # ensure task_id is a string
    formatted = [{"task_id": task_id, "submitted_answer": answer_text}]
    return {"messages": [AIMessage(content=json.dumps(formatted, ensure_ascii=False))]}
94
 
95
  # --- Graph Definition ---
96
  builder = StateGraph(MessagesState)
 
162
  print(message.content)
163
  print("-----------------------")
164
  else:
165
+ output = message.content # This is a string
166
+ try:
167
+ parsed = json.loads(output)
168
+ if isinstance(parsed, list) and "task_id" in parsed[0] and "submitted_answer" in parsed[0]:
169
+ print("✅ Output is in the correct format!")
170
+ else:
171
+ print("❌ Output is NOT in the correct format!")
172
+ except Exception as e:
173
+ print("❌ Output is NOT in the correct format!", e)