Create agent.py

#177
by martinthetechie - opened
Files changed (1) hide show
  1. agent.py +199 -0
agent.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """LangGraph Agent"""
2
+ import os
3
+ from dotenv import load_dotenv
4
+ from langgraph.graph import START, StateGraph, MessagesState
5
+ from langgraph.prebuilt import tools_condition
6
+ from langgraph.prebuilt import ToolNode
7
+ from langchain_google_genai import ChatGoogleGenerativeAI
8
+ from langchain_openai import ChatOpenAI
9
+ from langchain.agents import initialize_agent, Tool
10
+ from langchain_groq import ChatGroq
11
+ from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
12
+ from langchain_community.tools.tavily_search import TavilySearchResults
13
+ from langchain_community.document_loaders import WikipediaLoader
14
+ from langchain_community.document_loaders import ArxivLoader
15
+ from langchain_community.vectorstores import SupabaseVectorStore
16
+ from langchain_core.messages import SystemMessage, HumanMessage
17
+ from langchain_core.tools import tool
18
+ from langchain.tools.retriever import create_retriever_tool
19
+ from supabase.client import Client, create_client
20
+
21
+ load_dotenv()
22
+
23
@tool
def multiply(a: int, b: int) -> int:
    """Multiply two numbers.
    Args:
        a: first int
        b: second int
    """
    # Compute the product, then hand it back to the tool runtime.
    product = a * b
    return product
31
+
32
@tool
def add(a: int, b: int) -> int:
    """Add two numbers.
    Args:
        a: first int
        b: second int
    """
    # Compute the sum, then hand it back to the tool runtime.
    total = a + b
    return total
40
+
41
@tool
def subtract(a: int, b: int) -> int:
    """Subtract two numbers.
    Args:
        a: first int
        b: second int
    """
    # Compute the difference, then hand it back to the tool runtime.
    difference = a - b
    return difference
49
+
50
@tool
def divide(a: int, b: int) -> float:
    """Divide two numbers.
    Args:
        a: first int
        b: second int
    """
    # Guard the only input that can make this fail.
    if b == 0:
        raise ValueError("Cannot divide by zero.")
    # Fix: ``a / b`` is true division and yields a float, so the return
    # annotation is ``float`` (the original incorrectly declared ``int``).
    return a / b
60
+
61
@tool
def modulus(a: int, b: int) -> int:
    """Get the modulus of two numbers.
    Args:
        a: first int
        b: second int
    """
    # Compute the remainder, then hand it back to the tool runtime.
    remainder = a % b
    return remainder
69
+
70
@tool
def wiki_search(query: str) -> dict:
    """Search Wikipedia for a query and return maximum 2 results.
    Args:
        query: The search query."""
    search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
    # Wrap each result in a pseudo-XML <Document> envelope so the LLM can
    # tell the results apart.
    formatted_search_docs = "\n\n---\n\n".join(
        f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n'
        f"{doc.page_content}\n</Document>"
        for doc in search_docs
    )
    # Fix: the function returns a dict, so the annotation is ``dict``
    # (the original incorrectly declared ``str``).
    return {"wiki_results": formatted_search_docs}
82
+
83
@tool
def web_search(query: str) -> dict:
    """Search Tavily for a query and return maximum 3 results.
    Args:
        query: The search query."""
    # Fix: ``invoke`` takes the query as its positional input; the original
    # ``invoke(query=query)`` omitted the required input and raised TypeError.
    search_results = TavilySearchResults(max_results=3).invoke(query)
    # Fix: Tavily returns a list of plain dicts ({"url": ..., "content": ...}),
    # not Document objects, so the original ``doc.metadata``/``doc.page_content``
    # accesses would have raised AttributeError. Index the dicts directly.
    formatted_search_docs = "\n\n---\n\n".join(
        f'<Document source="{res.get("url", "")}" page=""/>\n'
        f"{res.get('content', '')}\n</Document>"
        for res in search_results
    )
    # Fix: the function returns a dict, so the annotation is ``dict``
    # (the original incorrectly declared ``str``).
    return {"web_results": formatted_search_docs}
95
+
96
@tool
def arvix_search(query: str) -> dict:
    """Search Arxiv for a query and return maximum 3 result.
    Args:
        query: The search query."""
    search_docs = ArxivLoader(query=query, load_max_docs=3).load()
    # NOTE(review): ArxivLoader metadata keys are typically 'Published',
    # 'Title', 'Authors', 'Summary' — a bare ``metadata["source"]`` may
    # KeyError; use .get with a Title fallback so a missing key degrades
    # gracefully instead of crashing the tool call. Confirm against the
    # installed langchain_community version.
    formatted_search_docs = "\n\n---\n\n".join(
        f'<Document source="{doc.metadata.get("source", doc.metadata.get("Title", ""))}" '
        f'page="{doc.metadata.get("page", "")}"/>\n'
        f"{doc.page_content[:1000]}\n</Document>"  # truncate: arxiv abstracts/bodies are long
        for doc in search_docs
    )
    # Fix: the function returns a dict, so the annotation is ``dict``
    # (the original incorrectly declared ``str``).
    return {"arvix_results": formatted_search_docs}
108
+
109
+
110
+
111
# Load the system prompt from the file (read once at import time).
with open("system_prompt.txt", "r", encoding="utf-8") as f:
    system_prompt = f.read()

# System message prepended to every conversation by the retriever node.
sys_msg = SystemMessage(content=system_prompt)

# Build a retriever backed by a Supabase vector store.
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")  # dim=768
supabase: Client = create_client(
    os.environ.get("SUPABASE_URL"),
    os.environ.get("SUPABASE_SERVICE_KEY"))
vector_store = SupabaseVectorStore(
    client=supabase,
    embedding=embeddings,
    table_name="documents",
    query_name="match_documents_langchain",
)
# Fix: the original assigned this result to ``create_retriever_tool``,
# shadowing the imported factory function of the same name. Use a distinct
# variable name so the factory remains callable.
retriever_tool = create_retriever_tool(
    retriever=vector_store.as_retriever(),
    # NOTE(review): some LLM providers reject tool names containing spaces —
    # confirm before binding this tool to a model.
    name="Question Search",
    description="A tool to retrieve similar questions from a vector store.",
)

# Tools bound to the LLM in build_graph(). The retriever is used via its own
# graph node (vector_store.similarity_search), not included here.
tools = [
    multiply,
    add,
    subtract,
    divide,
    modulus,
    wiki_search,
    web_search,
    arvix_search,
]
147
+
148
+ # Build graph function
149
def build_graph(provider: str = "groq"):
    """Build and compile the agent graph.

    Args:
        provider: LLM backend — one of "google", "groq", "openai"
            or "huggingface".

    Returns:
        The compiled LangGraph runnable
        (retriever -> assistant <-> tools loop).

    Raises:
        ValueError: if ``provider`` is not a supported name.
    """
    if provider == "google":
        # Google Gemini
        llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
    elif provider == "groq":
        # Groq https://console.groq.com/docs/models
        llm = ChatGroq(model="qwen-qwq-32b", temperature=0)  # optional : qwen-qwq-32b gemma2-9b-it
    elif provider == "openai":
        # OpenAI
        llm = ChatOpenAI(model="gpt-4", temperature=0)
    elif provider == "huggingface":
        # Fix: HuggingFaceEndpoint has no ``url`` field — the accepted
        # keyword is ``endpoint_url`` (or ``repo_id``); the original would
        # fail pydantic validation. NOTE(review): confirm the endpoint URL
        # is still served by the HF inference API.
        llm = ChatHuggingFace(
            llm=HuggingFaceEndpoint(
                endpoint_url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
                temperature=0,
            ),
        )
    else:
        # Fix: the original message omitted the supported 'openai' option.
        raise ValueError("Invalid provider. Choose 'google', 'groq', 'openai' or 'huggingface'.")
    # Bind tools so the model can emit tool calls routed to the ToolNode.
    llm_with_tools = llm.bind_tools(tools)

    def assistant(state: MessagesState):
        """Assistant node: run the tool-calling LLM on the conversation."""
        return {"messages": [llm_with_tools.invoke(state["messages"])]}

    def retriever(state: MessagesState):
        """Retriever node: prepend the system prompt and, when available,
        a similar question/answer example from the vector store."""
        similar_question = vector_store.similarity_search(state["messages"][0].content)
        # Fix: guard against an empty vector store — the original indexed
        # [0] unconditionally and raised IndexError when nothing matched.
        if similar_question:
            example_msg = HumanMessage(
                content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
            )
            return {"messages": [sys_msg] + state["messages"] + [example_msg]}
        return {"messages": [sys_msg] + state["messages"]}

    builder = StateGraph(MessagesState)
    builder.add_node("retriever", retriever)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))
    builder.add_edge(START, "retriever")
    builder.add_edge("retriever", "assistant")
    # Route to "tools" when the LLM emitted tool calls, otherwise finish.
    builder.add_conditional_edges(
        "assistant",
        tools_condition,
    )
    builder.add_edge("tools", "assistant")

    # Compile graph
    return builder.compile()