wt002 commited on
Commit
dd8df2c
·
verified ·
1 Parent(s): f18e2ca

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +15 -5
agent.py CHANGED
@@ -1,5 +1,3 @@
1
- # agent.py
2
-
3
  import os
4
  from dotenv import load_dotenv
5
  from langgraph.graph import START, StateGraph, MessagesState
@@ -17,7 +15,6 @@ from langchain_core.tools import tool
17
  from langchain.tools.retriever import create_retriever_tool
18
  from supabase.client import Client, create_client
19
 
20
-
21
  load_dotenv()
22
 
23
  @tool
@@ -122,6 +119,7 @@ with open("system_prompt.txt", "r", encoding="utf-8") as f:
122
  # System message
123
  sys_msg = SystemMessage(content=system_prompt)
124
 
 
125
  embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2") # dim=768
126
  supabase: Client = create_client(
127
  os.environ.get("SUPABASE_URL"),
@@ -139,6 +137,7 @@ create_retriever_tool = create_retriever_tool(
139
  )
140
 
141
 
 
142
  tools = [
143
  multiply,
144
  add,
@@ -151,7 +150,7 @@ tools = [
151
  ]
152
 
153
  # Build graph function
154
- def build_graph(provider: str = "google"):
155
  """Build the graph"""
156
  # Load environment variables from .env file
157
  if provider == "google":
@@ -199,4 +198,15 @@ def build_graph(provider: str = "google"):
199
  builder.add_edge("tools", "assistant")
200
 
201
  # Compile graph
202
- return builder.compile()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  from dotenv import load_dotenv
3
  from langgraph.graph import START, StateGraph, MessagesState
 
15
  from langchain.tools.retriever import create_retriever_tool
16
  from supabase.client import Client, create_client
17
 
 
18
  load_dotenv()
19
 
20
  @tool
 
119
  # System message
120
  sys_msg = SystemMessage(content=system_prompt)
121
 
122
+ # build a retriever
123
  embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2") # dim=768
124
  supabase: Client = create_client(
125
  os.environ.get("SUPABASE_URL"),
 
137
  )
138
 
139
 
140
+
141
  tools = [
142
  multiply,
143
  add,
 
150
  ]
151
 
152
  # Build graph function
153
+ def build_graph(provider: str = "groq"):
154
  """Build the graph"""
155
  # Load environment variables from .env file
156
  if provider == "google":
 
198
  builder.add_edge("tools", "assistant")
199
 
200
  # Compile graph
201
+ return builder.compile()
202
+
203
+ # test
204
+ if __name__ == "__main__":
205
+ question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
206
+ # Build the graph
207
+ graph = build_graph(provider="groq")
208
+ # Run the graph
209
+ messages = [HumanMessage(content=question)]
210
+ messages = graph.invoke({"messages": messages})
211
+ for m in messages["messages"]:
212
+ m.pretty_print()