martinthetechie commited on
Commit
51ee6d0
·
verified ·
1 Parent(s): 8cde32a

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +89 -78
agent.py CHANGED
@@ -1,23 +1,28 @@
1
- """LangGraph Agent"""
2
  import os
3
  from dotenv import load_dotenv
4
  from langgraph.graph import START, StateGraph, MessagesState
5
- from langgraph.prebuilt import tools_condition
6
- from langgraph.prebuilt import ToolNode
7
  from langchain_google_genai import ChatGoogleGenerativeAI
8
  from langchain_groq import ChatGroq
9
  from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
10
  from langchain_community.tools.tavily_search import TavilySearchResults
11
- from langchain_community.document_loaders import WikipediaLoader
12
- from langchain_community.document_loaders import ArxivLoader
13
- from langchain_community.vectorstores import SupabaseVectorStore
14
  from langchain_core.messages import SystemMessage, HumanMessage
15
  from langchain_core.tools import tool
16
  from langchain.tools.retriever import create_retriever_tool
17
- from supabase.client import Client, create_client
 
 
 
18
 
19
  load_dotenv()
20
 
 
 
 
 
21
  @tool
22
  def multiply(a: int, b: int) -> int:
23
  """Multiply two numbers.
@@ -111,80 +116,101 @@ def arvix_search(query: str) -> str:
111
  ])
112
  return {"arvix_results": formatted_search_docs}
113
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
 
115
-
116
- # load the system prompt from the file
117
- with open("system_prompt.txt", "r", encoding="utf-8") as f:
118
- system_prompt = f.read()
 
 
 
 
119
 
120
  # System message
121
  sys_msg = SystemMessage(content=system_prompt)
122
 
123
- # build a retriever
124
- embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2") # dim=768
125
- supabase: Client = create_client(
126
- os.environ.get("SUPABASE_URL"),
127
- os.environ.get("SUPABASE_SERVICE_KEY"))
128
- vector_store = SupabaseVectorStore(
129
- client=supabase,
130
- embedding= embeddings,
131
- table_name="documents",
132
- query_name="match_documents_langchain",
 
 
 
 
 
 
 
 
 
 
 
 
133
  )
134
- create_retriever_tool = create_retriever_tool(
 
 
 
 
 
135
  retriever=vector_store.as_retriever(),
136
  name="Question Search",
137
  description="A tool to retrieve similar questions from a vector store.",
138
  )
139
 
140
-
141
-
142
  tools = [
143
- multiply,
144
- add,
145
- subtract,
146
- divide,
147
- modulus,
148
- wiki_search,
149
- web_search,
150
- arvix_search,
151
  ]
152
 
153
- # Build graph function
154
  def build_graph(provider: str = "groq"):
155
- """Build the graph"""
156
- # Load environment variables from .env file
157
- if provider == "google":
158
- # Google Gemini
159
- llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
160
- elif provider == "groq":
161
- # Groq https://console.groq.com/docs/models
162
- llm = ChatGroq(model="qwen-qwq-32b", temperature=0) # optional : qwen-qwq-32b gemma2-9b-it
163
- elif provider == "huggingface":
164
- # TODO: Add huggingface endpoint
165
- llm = ChatHuggingFace(
166
- llm=HuggingFaceEndpoint(
167
- url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
168
- temperature=0,
169
- ),
170
- )
171
- else:
172
- raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
173
- # Bind tools to LLM
174
  llm_with_tools = llm.bind_tools(tools)
175
 
176
- # Node
177
  def assistant(state: MessagesState):
178
- """Assistant node"""
179
  return {"messages": [llm_with_tools.invoke(state["messages"])]}
180
-
181
  def retriever(state: MessagesState):
182
- """Retriever node"""
183
- similar_question = vector_store.similarity_search(state["messages"][0].content)
184
- example_msg = HumanMessage(
185
- content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
186
- )
187
- return {"messages": [sys_msg] + state["messages"] + [example_msg]}
188
 
189
  builder = StateGraph(MessagesState)
190
  builder.add_node("retriever", retriever)
@@ -192,22 +218,7 @@ def build_graph(provider: str = "groq"):
192
  builder.add_node("tools", ToolNode(tools))
193
  builder.add_edge(START, "retriever")
194
  builder.add_edge("retriever", "assistant")
195
- builder.add_conditional_edges(
196
- "assistant",
197
- tools_condition,
198
- )
199
  builder.add_edge("tools", "assistant")
200
 
201
- # Compile graph
202
- return builder.compile()
203
-
204
- # test
205
- if __name__ == "__main__":
206
- question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
207
- # Build the graph
208
- graph = build_graph(provider="groq")
209
- # Run the graph
210
- messages = [HumanMessage(content=question)]
211
- messages = graph.invoke({"messages": messages})
212
- for m in messages["messages"]:
213
- m.pretty_print()
 
 
1
  import os
2
  from dotenv import load_dotenv
3
  from langgraph.graph import START, StateGraph, MessagesState
4
+ from langgraph.prebuilt import tools_condition, ToolNode
 
5
  from langchain_google_genai import ChatGoogleGenerativeAI
6
  from langchain_groq import ChatGroq
7
  from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
8
  from langchain_community.tools.tavily_search import TavilySearchResults
9
+ from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
10
+ from langchain_community.vectorstores import Chroma
11
+ from langchain_core.documents import Document
12
  from langchain_core.messages import SystemMessage, HumanMessage
13
  from langchain_core.tools import tool
14
  from langchain.tools.retriever import create_retriever_tool
15
+ import json
16
+ from langchain.vectorstores import Chroma
17
+ from langchain.embeddings import HuggingFaceEmbeddings
18
+ from langchain.schema import Document
19
 
20
  load_dotenv()
21
 
22
+ os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
23
+ groq_api_key = os.getenv("GROQ_API_KEY")
24
+
25
+ # Tools
26
  @tool
27
  def multiply(a: int, b: int) -> int:
28
  """Multiply two numbers.
 
116
  ])
117
  return {"arvix_results": formatted_search_docs}
118
 
119
+ @tool
120
+ def similar_question_search(question: str) -> str:
121
+ """Search the vector database for similar questions and return the first results.
122
+
123
+ Args:
124
+ question: the question human provided."""
125
+ matched_docs = vector_store.similarity_search(query, 3)
126
+ formatted_search_docs = "\n\n---\n\n".join(
127
+ [
128
+ f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
129
+ for doc in matched_docs
130
+ ])
131
+ return {"similar_questions": formatted_search_docs}
132
 
133
+ # Load system prompt
134
+ system_prompt = """
135
+ You are a helpful assistant tasked with answering questions using a set of tools.
136
+ Now, I will ask you a question. Report your thoughts, and finish your answer with the following template:
137
+ FINAL ANSWER: [YOUR FINAL ANSWER].
138
+ YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
139
+ Your answer should only start with "FINAL ANSWER: ", then follows with the answer.
140
+ """
141
 
142
  # System message
143
  sys_msg = SystemMessage(content=system_prompt)
144
 
145
+ embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
146
+
147
+ with open('metadata.jsonl', 'r') as jsonl_file:
148
+ json_list = list(jsonl_file)
149
+
150
+ json_QA = []
151
+ for json_str in json_list:
152
+ json_data = json.loads(json_str)
153
+ json_QA.append(json_data)
154
+
155
+ documents = []
156
+ for sample in json_QA:
157
+ content = f"Question : {sample['Question']}\n\nFinal answer : {sample['Final answer']}"
158
+ metadata = {"source": sample["task_id"]}
159
+ documents.append(Document(page_content=content, metadata=metadata))
160
+
161
+ # Initialize vector store and add documents
162
+ vector_store = Chroma.from_documents(
163
+ documents=documents,
164
+ embedding=embeddings,
165
+ persist_directory="./chroma_db",
166
+ collection_name="my_collection"
167
  )
168
+ vector_store.persist()
169
+ print("Documents inserted:", vector_store._collection.count())
170
+
171
+
172
+ # Retriever tool (optional if you want to expose to agent)
173
+ retriever_tool = create_retriever_tool(
174
  retriever=vector_store.as_retriever(),
175
  name="Question Search",
176
  description="A tool to retrieve similar questions from a vector store.",
177
  )
178
 
179
+ # Tool list
 
180
  tools = [
181
+ multiply, add, subtract, divide, modulus,
182
+ wiki_search, web_search, arvix_search,
 
 
 
 
 
 
183
  ]
184
 
185
+ # Build graph
186
  def build_graph(provider: str = "groq"):
187
+ # if provider == "google":
188
+ # llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
189
+ # elif provider == "groq":
190
+ # llm = ChatGroq(model="qwen-qwq-32b", temperature=0)
191
+ # elif provider == "huggingface":
192
+ # llm = ChatHuggingFace(
193
+ # llm=HuggingFaceEndpoint(
194
+ # repo_id="mosaicml/mpt-30b",
195
+ # temperature=0,
196
+ # )
197
+ # )
198
+ # else:
199
+ # raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
200
+
201
+ # llm_with_tools = llm.bind_tools(tools)
202
+ llm = ChatGroq(model="qwen-qwq-32b", temperature=0,api_key=groq_api_key)
 
 
 
203
  llm_with_tools = llm.bind_tools(tools)
204
 
 
205
  def assistant(state: MessagesState):
 
206
  return {"messages": [llm_with_tools.invoke(state["messages"])]}
207
+
208
  def retriever(state: MessagesState):
209
+ similar = vector_store.similarity_search(state["messages"][0].content)
210
+ if similar:
211
+ example_msg = HumanMessage(content=f"Here is a similar question:\n\n{similar[0].page_content}")
212
+ return {"messages": [sys_msg] + state["messages"] + [example_msg]}
213
+ return {"messages": [sys_msg] + state["messages"]}
 
214
 
215
  builder = StateGraph(MessagesState)
216
  builder.add_node("retriever", retriever)
 
218
  builder.add_node("tools", ToolNode(tools))
219
  builder.add_edge(START, "retriever")
220
  builder.add_edge("retriever", "assistant")
221
+ builder.add_conditional_edges("assistant", tools_condition)
 
 
 
222
  builder.add_edge("tools", "assistant")
223
 
224
+ return builder.compile()