File size: 4,889 Bytes
94b3868
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import os
import re
from langgraph.graph import START, StateGraph, MessagesState
from langgraph.prebuilt import ToolNode
from langchain_core.messages import HumanMessage, SystemMessage
from huggingface_hub import InferenceClient
from custom_tools import TOOLS
from langchain_core.messages import AIMessage

# Hugging Face access token taken from the environment. When the variable is
# unset this is None and the client below is built without an explicit token.
HF_TOKEN = os.environ.get("HUGGINGFACE_API_TOKEN")
# Shared inference client used by planner_node and assistant_node.
client = InferenceClient(token=HF_TOKEN)

# System prompt for planner_node. It asks the model to reply with a single
# routing line; tools_condition looks for these phrases case-insensitively:
#   "I need to search this."        -> generic lookup
#   "I need to calculate this."     -> math expressions
#   "I need to define this."        -> word definitions
#   "I need to look up wikipedia."  -> people / events / topics
#   "I need to reverse text."       -> reversed text
# Any other reply (e.g. a direct answer or "UNKNOWN") contains no trigger
# phrase and is routed away from the tools.
# NOTE(review): the prompt is internally inconsistent. It allows answering easy
# questions directly yet says "Do not try to answer yet", and its last few-shot
# example answers "UNKNOWN", which is not one of the listed routing lines. It
# also contains typos ("questions asks", "pronounciation"). The text is left
# unchanged because it is sent verbatim to the model.
planner_prompt = SystemMessage(content="""

    You are a planning assistant. Your job is to decide how to answer a question.



    - If the answer is easy and factual, answer it directly.

    - If you are not 100% certain or the answer requires looking up real-world information, say:

        I need to search this.



    - If the question contains math or expressions like +, -, /, ^, say:

        I need to calculate this.



    - If a word should be explained, say:

        I need to define this.

    

    -If the question asks about a person, historical event, or specific topic, say: 

        I need to look up wikipedia.

    

    -If the questions asks for backwards pronounciation or reversing text, say:

        I need to reverse text.



    Only respond with one line explaining what you will do.

    Do not try to answer yet.

                               

    e.g: 

        Q: How many studio albums did Mercedes Sosa release between 2000 and 2009?

        A: I need to search this.



        Q: What does the word 'ephemeral' mean?

        A: I need to define this.



        Q: What is 23 * 6 + 3?

        A: I need to calculate this.



        Q: Reverse this: 'tfel drow eht'

        A: I need to reverse text.



        Q: What bird species are seen in this video?

        A: UNKNOWN

    """)

def planner_node(state: MessagesState):
    """Ask the LLM how the latest question should be handled.

    The planner system prompt is prepended to the conversation. The messages
    are converted to the ``{"role", "content"}`` dicts expected by
    ``client.chat.completions.create``. The model's reply (expected to be a
    single routing line such as "I need to search this.") is returned for
    ``tools_condition`` to inspect.

    Args:
        state: Graph state; ``state["messages"]`` holds the conversation so far.

    Returns:
        A partial state update ``{"messages": [SystemMessage(...)]}`` holding the
        stripped planner reply (an empty string if the model returned no text).

    Raises:
        ValueError: If the conversation contains a message that is not a
            SystemMessage, HumanMessage or AIMessage.
    """
    hf_messages = [planner_prompt] + state["messages"]

    # Map LangChain message objects to chat-completion role/content dicts.
    messages_dict = []
    for msg in hf_messages:
        if isinstance(msg, SystemMessage):
            role = "system"
        elif isinstance(msg, HumanMessage):
            role = "user"
        elif isinstance(msg, AIMessage):
            # Previously rejected; model replies map naturally to "assistant".
            role = "assistant"
        else:
            raise ValueError(f"Unsupported message type: {type(msg)}")
        messages_dict.append({"role": role, "content": msg.content})

    response = client.chat.completions.create(
        model="mistralai/Mistral-7B-Instruct-v0.2",
        messages=messages_dict,
    )

    # The completion message's content may be None; treat that as an empty reply
    # instead of crashing on .strip().
    text = (response.choices[0].message.content or "").strip()
    print("Planner output:\n", text)

    # NOTE(review): the reply is stored as a SystemMessage (not an AIMessage), so
    # it carries no tool calls for a ToolNode to execute; confirm this is intended.
    return {"messages": [SystemMessage(content=text)]}

# System prompt for assistant_node: turn the tool result plus the original
# question into a bare final answer, or 'UNKNOWN' when the tool result is useless.
# The em dash below was previously mis-encoded as "β€”" (UTF-8 bytes decoded
# with a Greek code page), which sent garbage characters to the model.
answer_prompt = SystemMessage(content="""

    You are now given the result of a tool (like a search, calculator, or text reversal).

    Use the tool result and the original question to give the final answer.

    If the tool result is unhelpful or unclear, respond with 'UNKNOWN'.

    Respond with only the answer — no explanations.

    """)

def assistant_node(state: MessagesState):
    """Produce the final answer from the question and the tool result.

    The answer system prompt is prepended to the conversation. The messages
    are converted to the ``{"role", "content"}`` dicts expected by
    ``client.chat.completions.create``. The model's reply is returned as the
    final AIMessage.

    Args:
        state: Graph state; ``state["messages"]`` holds the question, the
            planner's reply and the tool result appended by the tool node.

    Returns:
        A partial state update ``{"messages": [AIMessage(...)]}`` holding the
        stripped answer (an empty string if the model returned no text).

    Raises:
        ValueError: If the conversation contains a message that is not a
            SystemMessage, HumanMessage or AIMessage.
    """
    hf_messages = [answer_prompt] + state["messages"]

    # Map LangChain message objects to chat-completion role/content dicts.
    messages_dict = []
    for msg in hf_messages:
        if isinstance(msg, SystemMessage):
            role = "system"
        elif isinstance(msg, HumanMessage):
            role = "user"
        elif isinstance(msg, AIMessage):
            # Previously rejected; model replies map naturally to "assistant".
            role = "assistant"
        else:
            raise ValueError(f"Unsupported message type: {type(msg)}")
        messages_dict.append({"role": role, "content": msg.content})

    response = client.chat.completions.create(
        model="mistralai/Mistral-7B-Instruct-v0.2",
        messages=messages_dict,
    )

    # The completion message's content may be None; treat that as an empty reply
    # instead of crashing on .strip().
    text = (response.choices[0].message.content or "").strip()
    print("Final answer output:\n", text)

    return {"messages": [AIMessage(content=text)]}

def tools_condition(state: MessagesState) -> str:
    """Decide whether the planner asked for a tool.

    The planner announces tool use with fixed phrases such as
    "I need to search this.". A case-insensitive substring match of any of
    those phrases against the newest message selects the tool branch.

    Returns:
        "tools" when a trigger phrase is present, otherwise "end".
        NOTE(review): "end" is not a registered node name, so the conditional
        edge using this function needs a path_map that sends it to END.
    """
    latest_text = state["messages"][-1].content.lower()
    trigger_phrases = (
        "i need to search",
        "i need to calculate",
        "i need to define",
        "i need to reverse text",
        "i need to look up wikipedia",
    )
    for phrase in trigger_phrases:
        if phrase in latest_text:
            return "tools"
    return "end"

class PatchedToolNode(ToolNode):
    """ToolNode variant that reports tool output as a plain HumanMessage.

    Instead of emitting the ToolMessage(s) produced by ToolNode, the text of all
    tool results is joined and wrapped in a single HumanMessage prefixed with
    "Tool result:", so the next node sees it as ordinary user-provided text.

    NOTE(review): ToolNode runs the tool calls attached to an AIMessage in the
    state. The planner in this graph replies with plain text and no tool calls,
    so ToolNode is expected to reject this input; TODO confirm, and make the
    planner emit an AIMessage with tool_calls if tools should actually run.
    """

    def invoke(self, state: MessagesState, config=None, **kwargs) -> dict:
        """Run the tools and return the combined output as one HumanMessage.

        Args:
            state: Graph state containing ``"messages"``.
            config: Runnable config of the current run. It is forwarded to
                ToolNode (it was previously dropped) and is optional, matching
                the base ``invoke`` signature.
            **kwargs: Passed through to ``ToolNode.invoke``.

        Returns:
            ``{"messages": [HumanMessage(...)]}``. Only the new message is returned;
            MessagesState's reducer appends it to the existing history, so
            re-returning the old messages is unnecessary.
        """
        result = super().invoke(state, config, **kwargs)
        tool_messages = result.get("messages", []) if isinstance(result, dict) else []

        if tool_messages:
            # Keep every tool result, not just the first one.
            tool_output = "\n".join(str(m.content) for m in tool_messages)
        else:
            tool_output = "UNKNOWN"

        return {"messages": [HumanMessage(content=f"Tool result:\n{tool_output}")]}
    
def build_graph():
    """Build and compile the planner -> tools -> assistant agent graph.

    Flow:
        START -> planner.
        planner -> tools (if tools_condition returns "tools") or END (otherwise).
        tools -> assistant -> END.

    Returns:
        The compiled graph returned by ``StateGraph.compile()``.
    """
    # END is not among this file's module-level imports.
    from langgraph.graph import END

    builder = StateGraph(MessagesState)

    builder.add_node("planner", planner_node)
    builder.add_node("assistant", assistant_node)
    builder.add_node("tools", PatchedToolNode(TOOLS))

    builder.add_edge(START, "planner")
    # tools_condition returns the label "end", which is not a node name. Without
    # an explicit path_map the non-tool route has no valid destination. END is
    # also mapped so the edge keeps working if the router returns END directly.
    builder.add_conditional_edges(
        "planner",
        tools_condition,
        {"tools": "tools", "end": END, END: END},
    )
    builder.add_edge("tools", "assistant")
    # Make the terminal step explicit instead of relying on a dead-end node.
    builder.add_edge("assistant", END)

    return builder.compile()