vtony committed on
Commit db7e060 · verified · 1 Parent(s): c782d1f

Upload agent.py

Files changed (1)
  1. agent.py +150 -0
agent.py ADDED
@@ -0,0 +1,150 @@
+ import os
+ from dotenv import load_dotenv
+ from langgraph.graph import START, StateGraph, MessagesState
+ from langgraph.prebuilt import tools_condition
+ from langgraph.prebuilt import ToolNode
+ from langchain_google_genai import ChatGoogleGenerativeAI
+ from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
+ from langchain_community.tools.tavily_search import TavilySearchResults
+ from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
+ from langchain_community.vectorstores import SupabaseVectorStore
+ from langchain_core.messages import SystemMessage, AIMessage, HumanMessage
+ from langchain_core.tools import tool
+ from langchain.tools.retriever import create_retriever_tool
+ from supabase.client import Client, create_client
+
+ # Load environment variables
+ load_dotenv()
+
+ # --- Math Tools ---
+ @tool
+ def multiply(a: int, b: int) -> int:
+     """Multiply two integers."""
+     return a * b
+
+ @tool
+ def add(a: int, b: int) -> int:
+     """Add two integers."""
+     return a + b
+
+ @tool
+ def subtract(a: int, b: int) -> int:
+     """Subtract b from a."""
+     return a - b
+
+ @tool
+ def divide(a: int, b: int) -> float:
+     """Divide a by b, error on zero."""
+     if b == 0:
+         raise ValueError("Cannot divide by zero.")
+     return a / b
+
+ @tool
+ def modulus(a: int, b: int) -> int:
+     """Compute a mod b."""
+     return a % b
+
+ # --- Browser Tools ---
+ @tool
+ def wiki_search(query: str) -> dict:
+     """Search Wikipedia and return up to 2 documents."""
+     docs = WikipediaLoader(query=query, load_max_docs=2).load()
+     results = [f"<Document source=\"{d.metadata.get('source', '')}\" page=\"{d.metadata.get('page', '')}\"/>\n{d.page_content}" for d in docs]
+     return {"wiki_results": "\n---\n".join(results)}
+
+ @tool
+ def web_search(query: str) -> dict:
+     """Search Tavily and return up to 3 results."""
+     # TavilySearchResults returns a list of dicts with "url" and "content" keys
+     docs = TavilySearchResults(max_results=3).invoke(query)
+     results = [f"<Document source=\"{d.get('url', '')}\"/>\n{d.get('content', '')}" for d in docs]
+     return {"web_results": "\n---\n".join(results)}
+
+ @tool
+ def arxiv_search(query: str) -> dict:
+     """Search Arxiv and return up to 3 docs."""
+     docs = ArxivLoader(query=query, load_max_docs=3).load()
+     results = [f"<Document source=\"{d.metadata.get('source', '')}\"/>\n{d.page_content[:1000]}" for d in docs]
+     return {"arxiv_results": "\n---\n".join(results)}
+
+ # --- Load system prompt ---
+ with open("system_prompt.txt", "r", encoding="utf-8") as f:
+     system_prompt = f.read()
+
+ # --- System message ---
+ sys_msg = SystemMessage(content=system_prompt)
+
+ # --- Retriever Tool ---
+ embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
+ supabase = create_client(os.getenv("SUPABASE_URL"), os.getenv("SUPABASE_SERVICE_KEY"))
+
+ vector_store = SupabaseVectorStore(
+     client=supabase,
+     embedding=embeddings,
+     table_name="documents",
+     query_name="match_documents_langchain",
+ )
+
+ retriever_tool = create_retriever_tool(
+     retriever=vector_store.as_retriever(
+         search_type="similarity",
+         search_kwargs={"k": 5},
+     ),
+     name="Question Search",
+     description="A tool to retrieve similar questions from the vector store.",
+ )
+
+ tools = [
+     multiply,
+     add,
+     subtract,
+     divide,
+     modulus,
+     wiki_search,
+     web_search,
+     arxiv_search,
+ ]
+
+ # --- Graph Builder ---
+ def build_graph(provider: str = "huggingface"):
+     llm = ChatHuggingFace(
+         llm=HuggingFaceEndpoint(
+             repo_id="Qwen/Qwen2.5-Coder-32B-Instruct"
+         ),
+     )
+
+     # Bind tools to LLM
+     llm_with_tools = llm.bind_tools(tools)
+
+     # Define nodes
+     def assistant(state: MessagesState):
+         """Assistant node"""
+         return {"messages": [llm_with_tools.invoke(state["messages"])]}
+
+     # Retriever node: prepend the system message and, when available,
+     # a similar question retrieved from the vector store
+     def retriever(state: MessagesState):
+         """Retriever node"""
+         similar_question = vector_store.similarity_search(state["messages"][0].content)
+         print('Similar questions:')
+         print(similar_question)
+         if len(similar_question) > 0:
+             example_msg = HumanMessage(
+                 content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
+             )
+             return {"messages": [sys_msg] + state["messages"] + [example_msg]}
+         return {"messages": [sys_msg] + state["messages"]}
+
+     # Add nodes
+     builder = StateGraph(MessagesState)
+     builder.add_node("retriever", retriever)
+     builder.add_node("assistant", assistant)
+     builder.add_node("tools", ToolNode(tools))
+
+     # Add edges
+     builder.add_edge(START, "retriever")
+     builder.add_edge("retriever", "assistant")
+     builder.add_conditional_edges(
+         "assistant",
+         tools_condition,
+     )
+     builder.add_edge("tools", "assistant")
+
+     # Compile graph
+     return builder.compile()
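
A minimal usage sketch (not part of this commit), assuming the Supabase, Tavily, and Hugging Face credentials are available via the .env file and that system_prompt.txt sits next to agent.py:

from langchain_core.messages import HumanMessage
from agent import build_graph

# Build the compiled LangGraph workflow and ask a single question.
graph = build_graph()
result = graph.invoke({"messages": [HumanMessage(content="What is 25 * 17?")]})
print(result["messages"][-1].content)  # final assistant reply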