Spaces:
Sleeping
Sleeping
EtienneB
commited on
Commit
·
fc5e0c3
1
Parent(s):
916fd5c
Update agent.py
Browse files
agent.py
CHANGED
@@ -3,7 +3,8 @@ import os
|
|
3 |
import re
|
4 |
|
5 |
from dotenv import load_dotenv
|
6 |
-
from langchain_core.messages import HumanMessage, SystemMessage,
|
|
|
7 |
from langchain_huggingface import (ChatHuggingFace, HuggingFaceEmbeddings,
|
8 |
HuggingFaceEndpoint)
|
9 |
from langgraph.graph import START, MessagesState, StateGraph
|
@@ -71,18 +72,25 @@ def build_graph():
|
|
71 |
llm_with_tools = llm.bind_tools(tools)
|
72 |
|
73 |
# --- Nodes ---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
74 |
def assistant(state: MessagesState):
|
75 |
-
"""Assistant node"""
|
76 |
messages_with_system_prompt = [sys_msg] + state["messages"]
|
77 |
llm_response = llm_with_tools.invoke(messages_with_system_prompt)
|
78 |
-
|
79 |
-
|
80 |
-
if answer_text.strip().lower().startswith("final answer:"):
|
81 |
-
answer_text = answer_text.split(":", 1)[1].strip()
|
82 |
-
# Get task_id from state or set a placeholder
|
83 |
-
task_id = state.get("task_id", "1") # Replace with actual logic if needed
|
84 |
formatted = [{"task_id": task_id, "submitted_answer": answer_text}]
|
85 |
-
return {"messages": [formatted]}
|
86 |
|
87 |
# --- Graph Definition ---
|
88 |
builder = StateGraph(MessagesState)
|
@@ -154,9 +162,12 @@ if __name__ == "__main__":
|
|
154 |
print(message.content)
|
155 |
print("-----------------------")
|
156 |
else:
|
157 |
-
output =
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
|
|
|
|
|
|
|
3 |
import re
|
4 |
|
5 |
from dotenv import load_dotenv
|
6 |
+
from langchain_core.messages import (AIMessage, HumanMessage, SystemMessage,
|
7 |
+
ToolMessage)
|
8 |
from langchain_huggingface import (ChatHuggingFace, HuggingFaceEmbeddings,
|
9 |
HuggingFaceEndpoint)
|
10 |
from langgraph.graph import START, MessagesState, StateGraph
|
|
|
72 |
llm_with_tools = llm.bind_tools(tools)
|
73 |
|
74 |
# --- Nodes ---
|
75 |
+
def extract_answer(llm_output):
|
76 |
+
# Try to parse as JSON if possible
|
77 |
+
try:
|
78 |
+
# If the LLM output is a JSON list, extract the answer
|
79 |
+
parsed = json.loads(llm_output.strip().split('\n')[0])
|
80 |
+
if isinstance(parsed, list) and isinstance(parsed[0], dict) and "submitted_answer" in parsed[0]:
|
81 |
+
return parsed[0]["submitted_answer"]
|
82 |
+
except Exception:
|
83 |
+
pass
|
84 |
+
# Otherwise, just return the first line (before any explanation)
|
85 |
+
return llm_output.strip().split('\n')[0]
|
86 |
+
|
87 |
def assistant(state: MessagesState):
|
|
|
88 |
messages_with_system_prompt = [sys_msg] + state["messages"]
|
89 |
llm_response = llm_with_tools.invoke(messages_with_system_prompt)
|
90 |
+
answer_text = extract_answer(llm_response.content)
|
91 |
+
task_id = str(state.get("task_id", "1")) # Ensure task_id is a string
|
|
|
|
|
|
|
|
|
92 |
formatted = [{"task_id": task_id, "submitted_answer": answer_text}]
|
93 |
+
return {"messages": [AIMessage(content=json.dumps(formatted, ensure_ascii=False))]}
|
94 |
|
95 |
# --- Graph Definition ---
|
96 |
builder = StateGraph(MessagesState)
|
|
|
162 |
print(message.content)
|
163 |
print("-----------------------")
|
164 |
else:
|
165 |
+
output = message.content # This is a string
|
166 |
+
try:
|
167 |
+
parsed = json.loads(output)
|
168 |
+
if isinstance(parsed, list) and "task_id" in parsed[0] and "submitted_answer" in parsed[0]:
|
169 |
+
print("✅ Output is in the correct format!")
|
170 |
+
else:
|
171 |
+
print("❌ Output is NOT in the correct format!")
|
172 |
+
except Exception as e:
|
173 |
+
print("❌ Output is NOT in the correct format!", e)
|