Spaces:
Sleeping
Sleeping
Update agent.py
Browse files
agent.py
CHANGED
@@ -1,5 +1,3 @@
|
|
1 |
-
# agent.py
|
2 |
-
|
3 |
import os
|
4 |
from dotenv import load_dotenv
|
5 |
from langgraph.graph import START, StateGraph, MessagesState
|
@@ -17,7 +15,6 @@ from langchain_core.tools import tool
|
|
17 |
from langchain.tools.retriever import create_retriever_tool
|
18 |
from supabase.client import Client, create_client
|
19 |
|
20 |
-
|
21 |
load_dotenv()
|
22 |
|
23 |
@tool
|
@@ -122,6 +119,7 @@ with open("system_prompt.txt", "r", encoding="utf-8") as f:
|
|
122 |
# System message
|
123 |
sys_msg = SystemMessage(content=system_prompt)
|
124 |
|
|
|
125 |
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2") # dim=768
|
126 |
supabase: Client = create_client(
|
127 |
os.environ.get("SUPABASE_URL"),
|
@@ -139,6 +137,7 @@ create_retriever_tool = create_retriever_tool(
|
|
139 |
)
|
140 |
|
141 |
|
|
|
142 |
tools = [
|
143 |
multiply,
|
144 |
add,
|
@@ -151,7 +150,7 @@ tools = [
|
|
151 |
]
|
152 |
|
153 |
# Build graph function
|
154 |
-
def build_graph(provider: str = "
|
155 |
"""Build the graph"""
|
156 |
# Load environment variables from .env file
|
157 |
if provider == "google":
|
@@ -199,4 +198,15 @@ def build_graph(provider: str = "google"):
|
|
199 |
builder.add_edge("tools", "assistant")
|
200 |
|
201 |
# Compile graph
|
202 |
-
return builder.compile()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import os
|
2 |
from dotenv import load_dotenv
|
3 |
from langgraph.graph import START, StateGraph, MessagesState
|
|
|
15 |
from langchain.tools.retriever import create_retriever_tool
|
16 |
from supabase.client import Client, create_client
|
17 |
|
|
|
18 |
load_dotenv()
|
19 |
|
20 |
@tool
|
|
|
119 |
# System message
|
120 |
sys_msg = SystemMessage(content=system_prompt)
|
121 |
|
122 |
+
# build a retriever
|
123 |
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2") # dim=768
|
124 |
supabase: Client = create_client(
|
125 |
os.environ.get("SUPABASE_URL"),
|
|
|
137 |
)
|
138 |
|
139 |
|
140 |
+
|
141 |
tools = [
|
142 |
multiply,
|
143 |
add,
|
|
|
150 |
]
|
151 |
|
152 |
# Build graph function
|
153 |
+
def build_graph(provider: str = "groq"):
|
154 |
"""Build the graph"""
|
155 |
# Load environment variables from .env file
|
156 |
if provider == "google":
|
|
|
198 |
builder.add_edge("tools", "assistant")
|
199 |
|
200 |
# Compile graph
|
201 |
+
return builder.compile()
|
202 |
+
|
203 |
+
# test
|
204 |
+
if __name__ == "__main__":
|
205 |
+
question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
|
206 |
+
# Build the graph
|
207 |
+
graph = build_graph(provider="groq")
|
208 |
+
# Run the graph
|
209 |
+
messages = [HumanMessage(content=question)]
|
210 |
+
messages = graph.invoke({"messages": messages})
|
211 |
+
for m in messages["messages"]:
|
212 |
+
m.pretty_print()
|