Spaces:
Sleeping
Sleeping
import os
import re

from huggingface_hub import InferenceClient
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langgraph.graph import END, START, MessagesState, StateGraph
from langgraph.prebuilt import ToolNode

from custom_tools import TOOLS
# Hugging Face API token read from the environment; None when the variable is unset.
HF_TOKEN = os.environ.get("HUGGINGFACE_API_TOKEN")
# Single inference client shared by the LLM-backed graph nodes below.
client = InferenceClient(token=HF_TOKEN)
# System prompt for the planner node. The model must reply with one routing line.
# The phrases "I need to search this." etc. are substring-matched (after
# lower-casing) by tools_condition, so keep them in sync with its trigger list.
# Any other reply (a direct answer or UNKNOWN) ends the run.
planner_prompt = SystemMessage(content="""
You are a planning assistant. Your job is to decide how to answer a question.
- If the answer is easy and factual, answer it directly.
- If you are not 100% certain or the answer requires looking up real-world information, say:
I need to search this.
- If the question contains math or expressions like +, -, *, /, ^, say:
I need to calculate this.
- If a word should be explained, say:
I need to define this.
- If the question asks about a person, historical event, or specific topic, say:
I need to look up wikipedia.
- If the question asks for backwards pronunciation or reversing text, say:
I need to reverse text.
- If the question depends on content you cannot access (such as a video, image, or audio file), say:
UNKNOWN
Only respond with one line explaining what you will do.
Unless you are answering an easy factual question directly, do not try to answer yet.
Examples:
Q: How many studio albums did Mercedes Sosa release between 2000 and 2009?
A: I need to search this.
Q: What does the word 'ephemeral' mean?
A: I need to define this.
Q: What is 23 * 6 + 3?
A: I need to calculate this.
Q: Reverse this: 'tfel drow eht'
A: I need to reverse text.
Q: What bird species are seen in this video?
A: UNKNOWN
""")
def planner_node(state: MessagesState):
    """Ask the LLM how the current question should be handled.

    Sends ``planner_prompt`` followed by the conversation so far to the
    Mistral-7B-Instruct chat model. The model's one-line reply is wrapped in a
    SystemMessage, which ``tools_condition`` inspects to pick the next node.

    Args:
        state: Graph state; ``state["messages"]`` holds the conversation.

    Returns:
        ``{"messages": [SystemMessage]}`` with the stripped model reply. The
        reply is an empty string if the model returned no text content.

    Raises:
        ValueError: If the history contains a message that is not a
            System, Human or AI message.
    """
    messages_dict = []
    for msg in [planner_prompt] + state["messages"]:
        # Map LangChain message classes onto chat-completion roles. AIMessage
        # is handled so a history that already contains model turns (e.g. a
        # re-invoked or multi-turn graph) does not raise.
        if isinstance(msg, SystemMessage):
            role = "system"
        elif isinstance(msg, HumanMessage):
            role = "user"
        elif isinstance(msg, AIMessage):
            role = "assistant"
        else:
            raise ValueError(f"Unsupported message type: {type(msg)}")
        messages_dict.append({"role": role, "content": msg.content})
    response = client.chat.completions.create(
        model="mistralai/Mistral-7B-Instruct-v0.2",
        messages=messages_dict,
    )
    # The completion's message content may be None; treat that as an empty
    # reply instead of crashing on .strip().
    text = (response.choices[0].message.content or "").strip()
    print("Planner output:\n", text)
    return {"messages": [SystemMessage(content=text)]}
# System prompt for the answering node. It combines the tool result with the
# original question into a bare final answer, or 'UNKNOWN' when the tool
# result is not usable. The garbled "β" separator is restored to an em dash.
answer_prompt = SystemMessage(content="""
You are now given the result of a tool (like a search, calculator, or text reversal).
Use the tool result and the original question to give the final answer.
If the tool result is unhelpful or unclear, respond with 'UNKNOWN'.
Respond with only the answer — no explanations.
""")
def assistant_node(state: MessagesState):
    """Produce the final answer from the question and the tool result.

    Sends ``answer_prompt`` followed by the conversation so far to the
    Mistral-7B-Instruct chat model. The conversation includes the question,
    the planner's SystemMessage and the "Tool result:" HumanMessage. The
    reply is returned as an AIMessage.

    Args:
        state: Graph state; ``state["messages"]`` holds the conversation.

    Returns:
        ``{"messages": [AIMessage]}`` with the stripped model reply. The
        reply is an empty string if the model returned no text content.

    Raises:
        ValueError: If the history contains a message that is not a
            System, Human or AI message.
    """
    # NOTE(review): the request interleaves system and user turns
    # (system, user, system, user). Confirm the endpoint's chat template for
    # this model accepts that role layout.
    messages_dict = []
    for msg in [answer_prompt] + state["messages"]:
        # Map LangChain message classes onto chat-completion roles. AIMessage
        # is handled so prior model turns in the history do not raise.
        if isinstance(msg, SystemMessage):
            role = "system"
        elif isinstance(msg, HumanMessage):
            role = "user"
        elif isinstance(msg, AIMessage):
            role = "assistant"
        else:
            raise ValueError(f"Unsupported message type: {type(msg)}")
        messages_dict.append({"role": role, "content": msg.content})
    response = client.chat.completions.create(
        model="mistralai/Mistral-7B-Instruct-v0.2",
        messages=messages_dict,
    )
    # The completion's message content may be None; treat that as an empty
    # reply instead of crashing on .strip().
    text = (response.choices[0].message.content or "").strip()
    print("Final answer output:\n", text)
    return {"messages": [AIMessage(content=text)]}
def tools_condition(state: "MessagesState") -> str:
    """Route after the planner: "tools" if it asked for a tool, else "end".

    The content of the last message is lower-cased and searched for any of
    the fixed trigger phrases that ``planner_prompt`` asks the model to emit.

    Args:
        state: Graph state; only ``state["messages"]`` and the last
            message's ``content`` are read.

    Returns:
        "tools" when a trigger phrase is present, otherwise "end". An empty
        message list yields "end" instead of raising IndexError.
    """
    messages = state["messages"]
    if not messages:
        return "end"
    content = messages[-1].content
    # Message content is not always a plain str (LangChain allows a list of
    # content blocks); stringify it so the substring checks still work.
    text = (content if isinstance(content, str) else str(content)).lower()
    triggers = (
        "i need to search",
        "i need to calculate",
        "i need to define",
        "i need to reverse text",
        "i need to look up wikipedia",
    )
    return "tools" if any(trigger in text for trigger in triggers) else "end"
class PatchedToolNode(ToolNode):
    """ToolNode that reports tool output back as a single HumanMessage.

    The messages produced by the stock ``ToolNode`` are collapsed into one
    ``HumanMessage`` whose text starts with ``"Tool result:"``. The
    downstream LLM node therefore receives the tool output as a user turn.
    """

    def invoke(self, state: MessagesState, config=None, **kwargs) -> dict:
        """Run the tools via ``ToolNode.invoke`` and wrap their output.

        Args:
            state: Graph state, passed through to ``ToolNode.invoke``.
            config: Runnable config supplied by the graph. It is forwarded to
                the parent instead of being dropped, so the caller's config
                and callbacks also apply to tool execution. Optional, to
                match the ``Runnable.invoke`` signature.
            **kwargs: Extra keyword arguments forwarded to ``ToolNode.invoke``.

        Returns:
            ``{"messages": [HumanMessage]}``. The content is ``"Tool result:"``,
            a newline, and every tool output joined by newlines. If the tools
            produced no messages, the output part is ``"UNKNOWN"``.
        """
        # NOTE(review): ToolNode executes the ``tool_calls`` of the last
        # AIMessage in state. planner_node emits a plain SystemMessage with no
        # tool calls, so confirm this path can actually run a tool.
        result = super().invoke(state, config, **kwargs)
        # ToolNode mirrors its input format: a dict in gives {"messages": [...]}
        # out. Accept a bare list as well rather than failing on .get().
        produced = result.get("messages") if isinstance(result, dict) else result
        outputs = [str(m.content) for m in produced or [] if hasattr(m, "content")]
        # Keep every tool's output, not only the first one.
        tool_output = "\n".join(outputs) if outputs else "UNKNOWN"
        # MessagesState's add_messages reducer appends this to the existing
        # history, so only the new message needs to be returned.
        return {"messages": [HumanMessage(content=f"Tool result:\n{tool_output}")]}
def build_graph():
    """Build and compile the planner -> tools -> assistant graph.

    Flow:
        START -> planner
        planner -> tools   when tools_condition returns "tools"
        planner -> END     when tools_condition returns "end"
        tools -> assistant -> END

    Returns:
        The graph produced by ``StateGraph.compile()``.
    """
    builder = StateGraph(MessagesState)
    builder.add_node("planner", planner_node)
    builder.add_node("assistant", assistant_node)
    builder.add_node("tools", PatchedToolNode(TOOLS))
    builder.add_edge(START, "planner")
    # tools_condition returns the plain string "end". That is neither a
    # registered node nor the END sentinel, so map it explicitly; the path_map
    # also lets LangGraph validate the branch targets at compile time.
    builder.add_conditional_edges(
        "planner",
        tools_condition,
        {"tools": "tools", "end": END},
    )
    builder.add_edge("tools", "assistant")
    # Make the terminal node explicit instead of relying on a dead end.
    builder.add_edge("assistant", END)
    return builder.compile()