wenzel94 commited on
Commit
a68ce23
·
verified ·
1 Parent(s): 81917a3

build_graph

Browse files
Files changed (1) hide show
  1. graph.py +161 -0
graph.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Build a graph to solve gala problems.
3
+ """
4
+ import os
5
+ from langchain_core.tools import tool
6
+ from langchain_groq import ChatGroq
7
+ from langchain_community.tools.tavily_search import TavilySearchResults
8
+ from langgraph.graph import START, StateGraph, MessagesState, END
9
+ from langgraph.prebuilt import tools_condition, ToolNode
10
+ from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
11
+ from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
12
+ from langchain_core.messages import SystemMessage, HumanMessage
13
+
14
+ @tool
15
+ def multiply(a: int, b: int) -> int:
16
+ """Multiply two numbers.
17
+ Args:
18
+ a: first int
19
+ b: second int
20
+ """
21
+ return a * b
22
+
23
+ @tool
24
+ def add(a: int, b: int) -> int:
25
+ """Add two numbers.
26
+
27
+ Args:
28
+ a: first int
29
+ b: second int
30
+ """
31
+ return a + b
32
+
33
+ @tool
34
+ def subtract(a: int, b: int) -> int:
35
+ """Subtract two numbers.
36
+
37
+ Args:
38
+ a: first int
39
+ b: second int
40
+ """
41
+ return a - b
42
+
43
+ @tool
44
+ def divide(a: int, b: int) -> int:
45
+ """Divide two numbers.
46
+
47
+ Args:
48
+ a: first int
49
+ b: second int
50
+ """
51
+ if b == 0:
52
+ raise ValueError("Cannot divide by zero.")
53
+ return a / b
54
+
55
+ @tool
56
+ def modulus(a: int, b: int) -> int:
57
+ """Get the modulus of two numbers.
58
+
59
+ Args:
60
+ a: first int
61
+ b: second int
62
+ """
63
+ return a % b
64
+
65
+ @tool
66
+ def wiki_search(query: str) -> str:
67
+ """Search Wikipedia for a query and return maximum 2 results.
68
+
69
+ Args:
70
+ query: The search query."""
71
+ search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
72
+ formatted_search_docs = "\n\n---\n\n".join(
73
+ [
74
+ f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
75
+ for doc in search_docs
76
+ ])
77
+ return {"wiki_results": formatted_search_docs}
78
+
79
+ @tool
80
+ def web_search(query: str) -> str:
81
+ """Search Tavily for a query and return maximum 3 results.
82
+
83
+ Args:
84
+ query: The search query."""
85
+ search_docs = TavilySearchResults(max_results=3).invoke(query=query)
86
+ formatted_search_docs = "\n\n---\n\n".join(
87
+ [
88
+ f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
89
+ for doc in search_docs
90
+ ])
91
+ return {"web_results": formatted_search_docs}
92
+
93
+ @tool
94
+ def arvix_search(query: str) -> str:
95
+ """Search Arxiv for a query and return maximum 3 result.
96
+
97
+ Args:
98
+ query: The search query."""
99
+ search_docs = ArxivLoader(query=query, load_max_docs=3).load()
100
+ formatted_search_docs = "\n\n---\n\n".join(
101
+ [
102
+ f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
103
+ for doc in search_docs
104
+ ])
105
+ return {"arvix_results": formatted_search_docs}
106
+
107
+ SYSTEM_PROMPT = """
108
+ You are a helpful assistant that can solve problems using a set of tools.
109
+ Now, I will ask you a question. Report your thoughts, and finish your answer with the following template:
110
+ FINAL ANSWER: [YOUR FINAL ANSWER].
111
+ YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
112
+ Your answer should only start with "FINAL ANSWER: ", then follows with the answer.
113
+ """
114
+ tools = [multiply, add, subtract, divide, modulus, wiki_search, web_search, arvix_search]
115
+ def build_graph():
116
+ """
117
+ Build a graph to solve gala problems.
118
+ """
119
+ model = HuggingFaceEndpoint(
120
+ repo_id="Qwen/QwQ-32B", # 模型ID
121
+ task="text-generation", # 任务类型
122
+ temperature=0.7,
123
+ max_new_tokens=512,
124
+ huggingfacehub_api_token=os.getenv('HUGGINGFACE_API_TOKEN'),
125
+ top_p=0.95,
126
+
127
+ )
128
+ llm = ChatHuggingFace(llm=model)
129
+ # llm = ChatGroq(model="qwen-qwq-32b", temperature=0)
130
+ llm_with_tools = llm.bind_tools(tools)
131
+
132
+ # Node
133
+ def assistant(state: MessagesState):
134
+ """Assistant node"""
135
+ print(state["messages"])
136
+ return {"messages": [llm_with_tools.invoke(state["messages"])]}
137
+
138
+ def end(state: MessagesState):
139
+ """End node"""
140
+ return {"messages": [HumanMessage(content="FINAL ANSWER: " + state["messages"][-1].content)]}
141
+
142
+ builder = StateGraph(MessagesState)
143
+ builder.add_node("assistant", assistant)
144
+ builder.add_node("tools", ToolNode(tools))
145
+ builder.add_edge(START, "assistant")
146
+ builder.add_conditional_edges(
147
+ "assistant",
148
+ tools_condition
149
+ )
150
+ builder.add_edge("tools", "assistant")
151
+ builder.add_edge("assistant", END)
152
+ # Compile graph
153
+ return builder.compile()
154
+
155
+ if __name__ == "__main__":
156
+ from pprint import pprint
157
+ graph = build_graph()
158
+ messages = [SystemMessage(content=SYSTEM_PROMPT), HumanMessage(content="What is the capital of France?")]
159
+ msg = graph.invoke({"messages": messages})
160
+ for m in msg["messages"]:
161
+ m.pretty_print()