# assignment_agent/functions.py
# Author: Arbnor Tefiki (first commit 94b3868)
import os
import re

from huggingface_hub import InferenceClient
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langgraph.graph import END, START, MessagesState, StateGraph
from langgraph.prebuilt import ToolNode

from custom_tools import TOOLS
HF_TOKEN = os.getenv("HUGGINGFACE_API_TOKEN")
client = InferenceClient(token=HF_TOKEN)
# System prompt for the routing/planning step. The model must reply with a
# single trigger line ("I need to ...") that `tools_condition` matches on.
# Fixes vs. original: "pronounciation" -> "pronunciation", "questions asks" ->
# "question asks", and missing spaces after the "-" bullets, so the
# instructions read consistently.
planner_prompt = SystemMessage(content="""
You are a planning assistant. Your job is to decide how to answer a question.
- If the answer is easy and factual, answer it directly.
- If you are not 100% certain or the answer requires looking up real-world information, say:
I need to search this.
- If the question contains math or expressions like +, -, /, ^, say:
I need to calculate this.
- If a word should be explained, say:
I need to define this.
- If the question asks about a person, historical event, or specific topic, say:
I need to look up wikipedia.
- If the question asks for backwards pronunciation or reversing text, say:
I need to reverse text.
Only respond with one line explaining what you will do.
Do not try to answer yet.
e.g:
Q: How many studio albums did Mercedes Sosa release between 2000 and 2009?
A: I need to search this.
Q: What does the word 'ephemeral' mean?
A: I need to define this.
Q: What is 23 * 6 + 3?
A: I need to calculate this.
Q: Reverse this: 'tfel drow eht'
A: I need to reverse text.
Q: What bird species are seen in this video?
A: UNKNOWN
""")
def planner_node(state: MessagesState):
    """Ask the LLM to classify the question and emit a one-line plan.

    The plan text (e.g. "I need to search this.") is returned as a
    SystemMessage appended to the conversation, so `tools_condition`
    can route on it.

    Fix: the original raised ValueError for any AIMessage already present
    in the state; those are now mapped to the "assistant" role.
    """
    hf_messages = [planner_prompt] + state["messages"]
    # Map LangChain message objects to the plain-dict chat format the
    # InferenceClient API expects.
    messages_dict = []
    for msg in hf_messages:
        if isinstance(msg, SystemMessage):
            role = "system"
        elif isinstance(msg, HumanMessage):
            role = "user"
        elif isinstance(msg, AIMessage):
            # Earlier assistant turns held in the same state must not crash
            # the planner.
            role = "assistant"
        else:
            raise ValueError(f"Unsupported message type: {type(msg)}")
        messages_dict.append({"role": role, "content": msg.content})
    response = client.chat.completions.create(
        model="mistralai/Mistral-7B-Instruct-v0.2",
        messages=messages_dict,
    )
    text = response.choices[0].message.content.strip()
    print("Planner output:\n", text)
    return {"messages": [SystemMessage(content=text)]}
# System prompt for the final-answer step: combines the original question with
# the tool output injected by PatchedToolNode.
# Fix: the original contained the mojibake sequence "β€”" (a mis-decoded em
# dash) in the last line, now repaired.
answer_prompt = SystemMessage(content="""
You are now given the result of a tool (like a search, calculator, or text reversal).
Use the tool result and the original question to give the final answer.
If the tool result is unhelpful or unclear, respond with 'UNKNOWN'.
Respond with only the answer — no explanations.
""")
def assistant_node(state: MessagesState):
    """Produce the final answer from the question plus any tool result.

    Mirrors `planner_node` but uses the answering prompt, and emits an
    AIMessage so the graph's final state carries the answer as an
    assistant turn.

    Fix: the original raised ValueError for any AIMessage already present
    in the state; those are now mapped to the "assistant" role (consistent
    with the planner_node fix, but this edit stands on its own).
    """
    hf_messages = [answer_prompt] + state["messages"]
    # Map LangChain message objects to the plain-dict chat format the
    # InferenceClient API expects.
    messages_dict = []
    for msg in hf_messages:
        if isinstance(msg, SystemMessage):
            role = "system"
        elif isinstance(msg, HumanMessage):
            role = "user"
        elif isinstance(msg, AIMessage):
            # Tolerate assistant turns already in the state instead of
            # crashing on reruns.
            role = "assistant"
        else:
            raise ValueError(f"Unsupported message type: {type(msg)}")
        messages_dict.append({"role": role, "content": msg.content})
    response = client.chat.completions.create(
        model="mistralai/Mistral-7B-Instruct-v0.2",
        messages=messages_dict,
    )
    text = response.choices[0].message.content.strip()
    print("Final answer output:\n", text)
    return {"messages": [AIMessage(content=text)]}
def tools_condition(state: MessagesState) -> str:
    """Route to the tool node when the planner announced a tool intent.

    Looks at the most recent message: if it contains any of the known
    "I need to ..." trigger phrases (case-insensitive), return "tools";
    otherwise return "end".
    """
    triggers = (
        "i need to search",
        "i need to calculate",
        "i need to define",
        "i need to reverse text",
        "i need to look up wikipedia",
    )
    plan = state["messages"][-1].content.lower()
    for phrase in triggers:
        if phrase in plan:
            return "tools"
    return "end"
class PatchedToolNode(ToolNode):
    """ToolNode wrapper that re-injects the tool output as a HumanMessage.

    The stock ToolNode returns only the tool messages; here the original
    conversation is preserved and the tool result is appended as plain text,
    so the downstream assistant node (which only maps system/user/assistant
    roles) can consume it.
    """

    def invoke(self, state: MessagesState, config=None) -> dict:
        # Fix: the original required `config` but silently dropped it; it is
        # now optional and forwarded to the base class, which accepts it.
        result = super().invoke(state, config)
        messages = result.get("messages", [])
        # NOTE(review): only the first tool message is surfaced; multiple
        # parallel tool calls would be truncated here — confirm single-tool
        # usage is intended.
        tool_output = messages[0].content if messages else "UNKNOWN"
        # Append the tool result as a HumanMessage so the assistant sees it.
        new_messages = state["messages"] + [
            HumanMessage(content=f"Tool result:\n{tool_output}")
        ]
        return {"messages": new_messages}
def build_graph():
    """Wire planner -> (tools -> assistant) | END into a compiled graph.

    Fix: `tools_condition` returns the sentinel string "end", which is not a
    node name; without an explicit path map LangGraph would try to route to a
    nonexistent "end" node. The path map now terminates the graph via END.
    (Requires `END` from langgraph.graph — see the updated import block.)
    """
    builder = StateGraph(MessagesState)
    builder.add_node("planner", planner_node)
    builder.add_node("assistant", assistant_node)
    builder.add_node("tools", PatchedToolNode(TOOLS))
    builder.add_edge(START, "planner")
    # Map the condition's string results onto concrete destinations.
    builder.add_conditional_edges(
        "planner",
        tools_condition,
        {"tools": "tools", "end": END},
    )
    builder.add_edge("tools", "assistant")
    return builder.compile()