Irfshaikh committed
Commit 1a7cf31 · verified · 1 Parent(s): 5f5d00b

Update agent.py

Files changed (1)
  1. agent.py +132 -186
agent.py CHANGED
@@ -1,17 +1,14 @@
  import os
  from dotenv import load_dotenv
- from supabase.client import create_client
  from langgraph.graph import START, StateGraph, MessagesState
- from langgraph.prebuilt import ToolNode, tools_condition
  from langchain_core.tools import tool
  from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
  from langchain_google_genai import ChatGoogleGenerativeAI
  from langchain_groq import ChatGroq
- from langchain_huggingface import (
-     ChatHuggingFace,
-     HuggingFaceEndpoint,
-     HuggingFaceEmbeddings,
- )
  from langchain_community.tools.tavily_search import TavilySearchResults
  from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
  from langchain_community.vectorstores import SupabaseVectorStore
@@ -19,224 +16,173 @@ from langchain.tools.retriever import create_retriever_tool

  load_dotenv()

- def load_system_prompt(path: str = "system_prompt.txt") -> SystemMessage:
-     """
-     Load system prompt from a file, fallback to a default if missing.
-
      Args:
-         path: File path to the system prompt.
-
-     Returns:
-         SystemMessage containing the loaded or default prompt.
      """
-     try:
-         with open(path, encoding="utf-8") as f:
-             content = f.read()
-     except FileNotFoundError:
-         content = "You are a helpful assistant."
-     return SystemMessage(content=content)
-
-
- def math_tool(func):
-     """
-     Wrap a Python function as a LangChain tool.

      Args:
-         func: Callable to wrap.
-
-     Returns:
-         A LangChain tool.
      """
-     return tool(func)
-
-
- @math_tool
- def add(a: int, b: int) -> int:
-     """Return a + b."""
      return a + b

-
- @math_tool
  def subtract(a: int, b: int) -> int:
-     """Return a - b."""
-     return a - b
-
-
- @math_tool
- def multiply(a: int, b: int) -> int:
-     """Return a * b."""
-     return a * b
-
-
- @math_tool
- def divide(a: int, b: int) -> float:
      """
-     Return a / b.

-     Raises:
-         ValueError: If b is zero.
      """
      if b == 0:
          raise ValueError("Cannot divide by zero.")
      return a / b

-
- @math_tool
  def modulus(a: int, b: int) -> int:
-     """Return a % b."""
-     return a % b
-
-
- def format_docs(docs, key: str, max_chars: int = None) -> dict:
-     """
-     Convert document list into labeled XML-style chunks.
-
      Args:
-         docs: Iterable of Document objects.
-         key: Dict key for formatted results.
-         max_chars: Optionally truncate content.
-
-     Returns:
-         {key: formatted_string}
      """
-     entries = []
-     for d in docs:
-         content = d.page_content if max_chars is None else d.page_content[:max_chars]
-         entries.append(
-             f'<Document source="{d.metadata.get("source","")}" page="{d.metadata.get("page","")}">\n'
-             f"{content}\n</Document>"
-         )
-     return {key: "\n\n---\n\n".join(entries)}
-

  @tool
- def wiki_search(query: str) -> dict:
-     """Search Wikipedia (2 docs) and format results."""
      docs = WikipediaLoader(query=query, load_max_docs=2).load()
-     return format_docs(docs, "wiki_results")
-

  @tool
- def web_search(query: str) -> dict:
-     """Search the web via Tavily (3 docs) and format results."""
      docs = TavilySearchResults(max_results=3).invoke(query=query)
-     return format_docs(docs, "web_results")
-

  @tool
- def arxiv_search(query: str) -> dict:
-     """Search ArXiv (3 docs) and format results (truncate to 1k chars)."""
      docs = ArxivLoader(query=query, load_max_docs=3).load()
-     return format_docs(docs, "arxiv_results", max_chars=1000)
-
-
- def build_vector_retriever():
-     """
-     Create and return a Supabase-based vector retriever.
-
-     Returns:
-         Retriever for semantic similarity queries.
-     """
-     embed = HuggingFaceEmbeddings("sentence-transformers/all-mpnet-base-v2")
-     supa = create_client(
-         os.getenv("SUPABASE_URL"), os.getenv("SUPABASE_SERVICE_KEY")
-     )
-     store = SupabaseVectorStore(
-         client=supa,
-         embedding=embed,
-         table_name="documents",
-         query_name="match_documents_langchain",
-     )
-     return store.as_retriever()
-
-
- def get_llm(provider: str = "google"):
-     """
-     Factory to select and return an LLM client.
-
-     Args:
-         provider: One of "google", "groq", "huggingface".
-
-     Returns:
-         Configured LLM client.
-
-     Raises:
-         ValueError: On unsupported provider.
-     """
-     if provider == "google":
-         return ChatGoogleGenerativeAI("gemini-2.0-flash", temperature=0)
-     if provider == "groq":
-         return ChatGroq("qwen-qwq-32b", temperature=0)
-     if provider == "huggingface":
-         return ChatHuggingFace(
-             llm=HuggingFaceEndpoint(
-                 url="https://api-inference.huggingface.co/models/"
-                 "Meta-DeepLearning/llama-2-7b-chat-hf",
-                 temperature=0,
-             )
          )
-     raise ValueError(f"Unsupported provider: {provider}")
-

- def build_graph(provider: str = "google"):
      """
-     Build and compile a StateGraph for retrieval + LLM responses.
-
-     Args:
-         provider: LLM provider key.
-
-     Returns:
-         A compiled StateGraph.
      """
-     sys_msg = load_system_prompt()
-     retriever = build_vector_retriever()
-     question_tool = create_retriever_tool(
-         retriever=retriever,
-         name="Question Search",
-         description="Retrieve similar Q&A from vector store.",
-     )
-
-     tools = [
-         add,
-         subtract,
-         multiply,
-         divide,
-         modulus,
-         wiki_search,
-         web_search,
-         arxiv_search,
-         question_tool,
-     ]
      llm = get_llm(provider).bind_tools(tools)

-     def retriever_node(state: MessagesState) -> dict:
-         """
-         Node: retrieve most relevant doc and extract its answer.
-         """
          query = state["messages"][-1].content
-         doc = retriever.similarity_search(query, k=1)[0]
-         text = doc.page_content
-         ans = text.split("Final answer :")[-1].strip() if "Final answer :" in text else text
-         return {"messages": [AIMessage(content=ans)]}
-
-     def assistant_node(state: MessagesState) -> dict:
-         """
-         Node: call LLM with system prompt + history.
-         """
-         msgs = [sys_msg] + state["messages"]
-         resp = llm.invoke({"messages": msgs})
-         return {"messages": [resp]}

      graph = StateGraph(MessagesState)
-     graph.add_node("retriever", retriever_node)
-     graph.add_node("assistant", assistant_node)
-     graph.add_node("tools", ToolNode(tools))
-     graph.add_edge(START, "retriever")
-     graph.add_edge("retriever", "assistant")
-     graph.add_conditional_edges("assistant", tools_condition)
-     graph.add_edge("tools", "assistant")
      graph.set_entry_point("retriever")
-     graph.set_finish_point("assistant")
-
-     return graph.compile()
 
  import os
+ import functools
  from dotenv import load_dotenv
+ from supabase.client import create_client, Client
  from langgraph.graph import START, StateGraph, MessagesState
+ from langgraph.prebuilt import tools_condition, ToolNode
  from langchain_core.tools import tool
  from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
  from langchain_google_genai import ChatGoogleGenerativeAI
  from langchain_groq import ChatGroq
+ from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
  from langchain_community.tools.tavily_search import TavilySearchResults
  from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
  from langchain_community.vectorstores import SupabaseVectorStore
  from langchain.tools.retriever import create_retriever_tool

  load_dotenv()
 
+ def _format_search_results(docs, label: str, truncate: int = None) -> dict:
+     """Helper to format document search results."""
+     entries = []
+     for d in docs:
+         content = d.page_content if truncate is None else d.page_content[:truncate]
+         entries.append(
+             f'<Document source="{d.metadata.get("source","")}" '
+             f'page="{d.metadata.get("page","")}">\n{content}\n</Document>'
+         )
+     return {label: "\n\n---\n\n".join(entries)}
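For clarity, an editor's sketch (not part of the commit) of what the helper produces; the two toy documents and the printed shape below are illustrative assumptions:

from langchain_core.documents import Document

sample = [
    Document(page_content="Alpha", metadata={"source": "a.txt", "page": 1}),
    Document(page_content="Beta", metadata={"source": "b.txt"}),
]
out = _format_search_results(sample, "wiki_results")
# out["wiki_results"] ==
# '<Document source="a.txt" page="1">\nAlpha\n</Document>\n\n---\n\n'
# '<Document source="b.txt" page="">\nBeta\n</Document>'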

+ @tool
+ def multiply(a: int, b: int) -> int:
+     """Multiply two numbers.
+
      Args:
+         a: first int
+         b: second int
      """
+     return a * b

+ @tool
+ def add(a: int, b: int) -> int:
+     """Add two numbers.
+
      Args:
+         a: first int
+         b: second int
      """
      return a + b

+ @tool
  def subtract(a: int, b: int) -> int:
+     """Subtract two numbers.
+
+     Args:
+         a: first int
+         b: second int
      """
+     return a - b

+ @tool
+ def divide(a: int, b: int) -> float:
+     """Divide two numbers.
+
+     Args:
+         a: first int
+         b: second int
      """
      if b == 0:
          raise ValueError("Cannot divide by zero.")
      return a / b

+ @tool
  def modulus(a: int, b: int) -> int:
+     """Get the modulus of two numbers.
+
      Args:
+         a: first int
+         b: second int
      """
+     return a % b
 
  @tool
+ def wiki_search(query: str) -> dict:
+     """Search Wikipedia for a query and return maximum 2 results.
+
+     Args:
+         query: The search query."""
      docs = WikipediaLoader(query=query, load_max_docs=2).load()
+     return _format_search_results(docs, "wiki_results")

  @tool
+ def web_search(query: str) -> dict:
+     """Search Tavily for a query and return maximum 3 results.
+
+     Args:
+         query: The search query."""
      docs = TavilySearchResults(max_results=3).invoke(query=query)
+     return _format_search_results(docs, "web_results")

  @tool
+ def arvix_search(query: str) -> dict:
+     """Search Arxiv for a query and return maximum 3 results.
+
+     Args:
+         query: The search query."""
      docs = ArxivLoader(query=query, load_max_docs=3).load()
+     return _format_search_results(docs, "arvix_results", truncate=1000)
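A quick way to exercise these tools (editor's sketch, not part of the commit; assumes the packages above are installed and network access is available). Functions decorated with @tool become LangChain tools that are invoked with a dict of their arguments:

if __name__ == "__main__":
    print(multiply.invoke({"a": 6, "b": 7}))         # 42
    hits = wiki_search.invoke({"query": "Alan Turing"})
    print(hits["wiki_results"][:300])                # start of first formatted document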
 
+ # load the system prompt from the file
+ with open("system_prompt.txt", "r", encoding="utf-8") as f:
+     system_prompt = f.read()
+
+ # System message
+ sys_msg = SystemMessage(content=system_prompt)
 
+ # build a retriever once
+ _embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
+ _supabase: Client = create_client(
+     os.environ["SUPABASE_URL"], os.environ["SUPABASE_SERVICE_KEY"]
+ )
+ _vector_store = SupabaseVectorStore(
+     client=_supabase,
+     embedding=_embeddings,
+     table_name="documents",
+     query_name="match_documents_langchain",
+ )
+ _retriever = _vector_store.as_retriever()
+ _question_search_tool = create_retriever_tool(
+     retriever=_retriever,
+     name="Question Search",
+     description="A tool to retrieve similar questions from a vector store.",
+ )
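Since this block runs at import time, the environment must be configured before the module loads. A minimal pre-flight check (editor's sketch; the two Supabase variable names come from the code above, while the API-key names in the comment are assumptions based on what these provider clients conventionally read):

import os

required = ("SUPABASE_URL", "SUPABASE_SERVICE_KEY")
# Assumed per-provider keys: GOOGLE_API_KEY, GROQ_API_KEY, TAVILY_API_KEY
missing = [name for name in required if not os.environ.get(name)]
if missing:
    raise RuntimeError(f"Missing environment variables: {missing}")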
 
+ tools = [
+     multiply,
+     add,
+     subtract,
+     divide,
+     modulus,
+     wiki_search,
+     web_search,
+     arvix_search,
+     _question_search_tool,
+ ]
+
+ _LLM_PROVIDERS = {
+     "google": lambda: ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0),
+     "groq": lambda: ChatGroq(model="qwen-qwq-32b", temperature=0),
+     "huggingface": lambda: ChatHuggingFace(
+         llm=HuggingFaceEndpoint(
+             url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
+             temperature=0,
          )
+     ),
+ }
 

+ @functools.lru_cache(maxsize=None)
+ def get_llm(provider: str):
      """
+     Retrieve and cache the LLM client for the given provider.
      """
+     try:
+         return _LLM_PROVIDERS[provider]()
+     except KeyError:
+         raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
+
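The functools.lru_cache wrapper means each provider's client is constructed once per process and then reused (editor's sketch; assumes the relevant API key is configured):

first = get_llm("google")
second = get_llm("google")
assert first is second        # same cached client, not a new instance
# get_llm("openai")           # would raise ValueError: Invalid provider...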
166
+ def build_graph(provider: str = "google"):
167
+ """Build the graph"""
 
 
 
 
 
 
 
 
 
 
 
 
168
  llm = get_llm(provider).bind_tools(tools)
169
 
170
+ def assistant(state: MessagesState):
171
+ """Assistant node"""
172
+ return {"messages": [llm.invoke(state["messages"])]}
173
+
174
+ def retriever(state: MessagesState):
175
  query = state["messages"][-1].content
176
+ doc = _retriever.similarity_search(query, k=1)[0]
177
+ content = doc.page_content
178
+ if "Final answer :" in content:
179
+ answer = content.split("Final answer :")[-1].strip()
180
+ else:
181
+ answer = content.strip()
182
+ return {"messages": [AIMessage(content=answer)]}
 
 
 
 
 
183
 
184
  graph = StateGraph(MessagesState)
185
+ graph.add_node("retriever", retriever)
 
 
 
 
 
 
186
  graph.set_entry_point("retriever")
187
+ graph.set_finish_point("retriever")
188
+ return graph.compile()
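How the compiled graph is meant to be driven, as an editor's sketch (not part of the commit; assumes system_prompt.txt exists, the Supabase documents table is populated, and provider credentials are set). Note that with both the entry and finish point on "retriever", the assistant node and bound tools are never scheduled, so invoking the graph returns the closest stored answer:

if __name__ == "__main__":
    agent = build_graph(provider="google")
    final_state = agent.invoke(
        {"messages": [HumanMessage(content="What is the capital of France?")]}
    )
    print(final_state["messages"][-1].content)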