Irfshaikh committed on
Commit
5f5d00b
·
verified ·
1 Parent(s): 803089a

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +163 -59
agent.py CHANGED
@@ -1,132 +1,236 @@
1
  import os
2
  from dotenv import load_dotenv
 
3
  from langgraph.graph import START, StateGraph, MessagesState
4
  from langgraph.prebuilt import ToolNode, tools_condition
5
  from langchain_core.tools import tool
6
  from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
7
  from langchain_google_genai import ChatGoogleGenerativeAI
8
  from langchain_groq import ChatGroq
9
- from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
 
 
 
 
10
  from langchain_community.tools.tavily_search import TavilySearchResults
11
  from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
12
  from langchain_community.vectorstores import SupabaseVectorStore
13
  from langchain.tools.retriever import create_retriever_tool
14
- from supabase.client import create_client
15
 
16
  load_dotenv()
17
 
18
- # --- System Prompt Loader ---
19
- def load_system_prompt(path="system_prompt.txt") -> SystemMessage:
 
 
 
 
 
 
 
 
 
20
  try:
21
  with open(path, encoding="utf-8") as f:
22
- return SystemMessage(content=f.read())
23
  except FileNotFoundError:
24
- return SystemMessage(content="You are a helpful assistant.")
 
25
 
26
- sys_msg = load_system_prompt()
27
 
28
- # --- Math Tools Factory ---
29
- def math_tool(fn):
30
- return tool(fn)
 
 
 
 
 
 
 
 
 
31
 
32
  @math_tool
33
- def add(a: int, b: int) -> int: return a + b
 
 
 
 
34
  @math_tool
35
- def subtract(a: int, b: int) -> int: return a - b
 
 
 
 
36
  @math_tool
37
- def multiply(a: int, b: int) -> int: return a * b
 
 
 
 
38
  @math_tool
39
  def divide(a: int, b: int) -> float:
40
- if b == 0: raise ValueError("Cannot divide by zero.")
 
 
 
 
 
 
 
41
  return a / b
42
 
 
43
  @math_tool
44
- def modulus(a: int, b: int) -> int: return a % b
 
 
 
45
 
46
- # --- Document Formatting Helper ---
47
  def format_docs(docs, key: str, max_chars: int = None) -> dict:
48
- content = "\n\n---\n\n".join(
49
- f'<Document source="{d.metadata.get("source","")}" page="{d.metadata.get("page","")}" />\n'
50
- f'{d.page_content[:max_chars] if max_chars else d.page_content}\n</Document>'
51
- for d in docs
52
- )
53
- return {key: content}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
- # --- Info Tools ---
56
  @tool
57
  def wiki_search(query: str) -> dict:
 
58
  docs = WikipediaLoader(query=query, load_max_docs=2).load()
59
  return format_docs(docs, "wiki_results")
60
 
 
61
  @tool
62
  def web_search(query: str) -> dict:
 
63
  docs = TavilySearchResults(max_results=3).invoke(query=query)
64
  return format_docs(docs, "web_results")
65
 
 
66
  @tool
67
- def arvix_search(query: str) -> dict:
 
68
  docs = ArxivLoader(query=query, load_max_docs=3).load()
69
- return format_docs(docs, "arvix_results", max_chars=1000)
 
70
 
71
- # --- Vector Retriever Setup ---
72
  def build_vector_retriever():
73
- embed_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
74
- supa = create_client(os.getenv("SUPABASE_URL"), os.getenv("SUPABASE_SERVICE_KEY"))
75
- vs = SupabaseVectorStore(
 
 
 
 
 
 
 
 
76
  client=supa,
77
- embedding=embed_model,
78
  table_name="documents",
79
- query_name="match_documents_langchain"
80
  )
81
- return vs.as_retriever()
 
 
 
 
 
 
 
 
82
 
83
- # --- LLM Factory ---
84
- def get_llm(provider: str):
 
 
 
 
85
  if provider == "google":
86
- return ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
87
  if provider == "groq":
88
- return ChatGroq(model="qwen-qwq-32b", temperature=0)
89
  if provider == "huggingface":
90
- return ChatHuggingFace(llm=HuggingFaceEndpoint(
91
- url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
92
- temperature=0))
 
 
 
 
93
  raise ValueError(f"Unsupported provider: {provider}")
94
 
95
- # --- Build Graph ---
96
  def build_graph(provider: str = "google"):
97
- # tools list
 
 
 
 
 
 
 
 
 
98
  retriever = build_vector_retriever()
99
  question_tool = create_retriever_tool(
100
  retriever=retriever,
101
  name="Question Search",
102
- description="Retrieve similar Q&A from vector store"
103
  )
 
104
  tools = [
105
- add, subtract, multiply, divide, modulus,
106
- wiki_search, web_search, arvix_search,
107
- question_tool
 
 
 
 
 
 
108
  ]
109
-
110
- # LLM w/ tools
111
  llm = get_llm(provider).bind_tools(tools)
112
 
113
- # Nodes
114
- def assistant(state: MessagesState):
115
- msgs = [sys_msg] + state["messages"]
116
- resp = llm.invoke({"messages": msgs})
117
- return {"messages": [resp]}
118
-
119
- def retriever_node(state: MessagesState):
120
  query = state["messages"][-1].content
121
  doc = retriever.similarity_search(query, k=1)[0]
122
  text = doc.page_content
123
- answer = text.split("Final answer :")[-1].strip() if "Final answer :" in text else text
124
- return {"messages": [AIMessage(content=answer)]}
 
 
 
 
 
 
 
 
125
 
126
- # Graph assembly
127
  graph = StateGraph(MessagesState)
128
  graph.add_node("retriever", retriever_node)
129
- graph.add_node("assistant", assistant)
130
  graph.add_node("tools", ToolNode(tools))
131
  graph.add_edge(START, "retriever")
132
  graph.add_edge("retriever", "assistant")
@@ -135,4 +239,4 @@ def build_graph(provider: str = "google"):
135
  graph.set_entry_point("retriever")
136
  graph.set_finish_point("assistant")
137
 
138
- return graph.compile()
 
1
  import os
2
  from dotenv import load_dotenv
3
+ from supabase.client import create_client
4
  from langgraph.graph import START, StateGraph, MessagesState
5
  from langgraph.prebuilt import ToolNode, tools_condition
6
  from langchain_core.tools import tool
7
  from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
8
  from langchain_google_genai import ChatGoogleGenerativeAI
9
  from langchain_groq import ChatGroq
10
+ from langchain_huggingface import (
11
+ ChatHuggingFace,
12
+ HuggingFaceEndpoint,
13
+ HuggingFaceEmbeddings,
14
+ )
15
  from langchain_community.tools.tavily_search import TavilySearchResults
16
  from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
17
  from langchain_community.vectorstores import SupabaseVectorStore
18
  from langchain.tools.retriever import create_retriever_tool
 
19
 
20
  load_dotenv()
21
 
22
+
23
def load_system_prompt(path: str = "system_prompt.txt") -> SystemMessage:
    """
    Read the system prompt from *path*, with a safe fallback.

    Args:
        path: Location of the prompt file.

    Returns:
        A SystemMessage wrapping the file contents, or a generic
        assistant prompt when the file does not exist.
    """
    fallback = "You are a helpful assistant."
    try:
        with open(path, encoding="utf-8") as handle:
            prompt_text = handle.read()
    except FileNotFoundError:
        prompt_text = fallback
    return SystemMessage(content=prompt_text)
39
 
 
40
 
41
def math_tool(func):
    """
    Decorator: expose a plain callable as a LangChain tool.

    Args:
        func: The arithmetic function to expose.

    Returns:
        The function wrapped as a LangChain tool.
    """
    wrapped = tool(func)
    return wrapped
52
+
53
 
54
@math_tool
def add(a: int, b: int) -> int:
    """Add two integers and return their sum."""
    total = a + b
    return total
58
+
59
+
60
@math_tool
def subtract(a: int, b: int) -> int:
    """Subtract b from a and return the difference."""
    difference = a - b
    return difference
64
+
65
+
66
@math_tool
def multiply(a: int, b: int) -> int:
    """Multiply two integers and return the product."""
    product = a * b
    return product
70
+
71
+
72
@math_tool
def divide(a: int, b: int) -> float:
    """
    Divide a by b, producing a float quotient.

    Args:
        a: Dividend.
        b: Divisor.

    Raises:
        ValueError: If b is zero.
    """
    if not b:
        raise ValueError("Cannot divide by zero.")
    return a / b
83
 
84
+
85
@math_tool
def modulus(a: int, b: int) -> int:
    """Return the remainder of a divided by b."""
    remainder = a % b
    return remainder
89
+
90
 
 
91
  def format_docs(docs, key: str, max_chars: int = None) -> dict:
92
+ """
93
+ Convert document list into labeled XML-style chunks.
94
+
95
+ Args:
96
+ docs: Iterable of Document objects.
97
+ key: Dict key for formatted results.
98
+ max_chars: Optionally truncate content.
99
+
100
+ Returns:
101
+ {key: formatted_string}
102
+ """
103
+ entries = []
104
+ for d in docs:
105
+ content = d.page_content if max_chars is None else d.page_content[:max_chars]
106
+ entries.append(
107
+ f'<Document source="{d.metadata.get("source","")}" page="{d.metadata.get("page","")}">\n'
108
+ f"{content}\n</Document>"
109
+ )
110
+ return {key: "\n\n---\n\n".join(entries)}
111
+
112
 
 
113
@tool
def wiki_search(query: str) -> dict:
    """Look up the query on Wikipedia (max 2 docs) and return formatted text."""
    loader = WikipediaLoader(query=query, load_max_docs=2)
    return format_docs(loader.load(), "wiki_results")
118
 
119
+
120
@tool
def web_search(query: str) -> dict:
    """
    Search the web via Tavily (top 3 results) and format them.

    Args:
        query: Search query string.

    Returns:
        {"web_results": formatted_string}
    """
    # BaseTool.invoke takes its input positionally; the previous
    # invoke(query=query) raised TypeError (missing required `input`).
    results = TavilySearchResults(max_results=3).invoke(query)
    # Tavily returns plain dicts (url/content), not Document objects,
    # so format them directly rather than via format_docs, which would
    # fail on the missing .metadata/.page_content attributes.
    entries = [
        f'<Document source="{item.get("url", "")}">\n'
        f'{item.get("content", "")}\n</Document>'
        for item in results
    ]
    return {"web_results": "\n\n---\n\n".join(entries)}
125
 
126
+
127
@tool
def arxiv_search(query: str) -> dict:
    """Query ArXiv (max 3 docs); format results, truncated to 1000 chars each."""
    loader = ArxivLoader(query=query, load_max_docs=3)
    return format_docs(loader.load(), "arxiv_results", max_chars=1000)
132
+
133
 
 
134
def build_vector_retriever():
    """
    Create a retriever backed by a Supabase vector store.

    Returns:
        A LangChain retriever performing semantic similarity search
        against the "documents" table.
    """
    # HuggingFaceEmbeddings is a pydantic model: positional construction
    # raises TypeError, so the model must be passed as model_name=...
    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-mpnet-base-v2"
    )
    supabase_client = create_client(
        os.getenv("SUPABASE_URL"), os.getenv("SUPABASE_SERVICE_KEY")
    )
    store = SupabaseVectorStore(
        client=supabase_client,
        embedding=embeddings,
        table_name="documents",
        query_name="match_documents_langchain",
    )
    return store.as_retriever()
152
+
153
+
154
def get_llm(provider: str = "google"):
    """
    Factory returning a chat model for the given provider.

    Args:
        provider: One of "google", "groq", "huggingface".

    Returns:
        A configured chat model instance.

    Raises:
        ValueError: If the provider is not supported.
    """
    # These clients are pydantic models: positional construction (e.g.
    # ChatGoogleGenerativeAI("gemini-2.0-flash")) raises TypeError, so
    # every field is passed by keyword.
    if provider == "google":
        return ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
    if provider == "groq":
        return ChatGroq(model="qwen-qwq-32b", temperature=0)
    if provider == "huggingface":
        # HuggingFaceEndpoint has no `url` field; the field is
        # endpoint_url. NOTE(review): confirm this inference endpoint
        # path is still valid for the deployed model.
        return ChatHuggingFace(
            llm=HuggingFaceEndpoint(
                endpoint_url=(
                    "https://api-inference.huggingface.co/models/"
                    "Meta-DeepLearning/llama-2-7b-chat-hf"
                ),
                temperature=0,
            )
        )
    raise ValueError(f"Unsupported provider: {provider}")
180
 
181
+
182
  def build_graph(provider: str = "google"):
183
+ """
184
+ Build and compile a StateGraph for retrieval + LLM responses.
185
+
186
+ Args:
187
+ provider: LLM provider key.
188
+
189
+ Returns:
190
+ A compiled StateGraph.
191
+ """
192
+ sys_msg = load_system_prompt()
193
  retriever = build_vector_retriever()
194
  question_tool = create_retriever_tool(
195
  retriever=retriever,
196
  name="Question Search",
197
+ description="Retrieve similar Q&A from vector store.",
198
  )
199
+
200
  tools = [
201
+ add,
202
+ subtract,
203
+ multiply,
204
+ divide,
205
+ modulus,
206
+ wiki_search,
207
+ web_search,
208
+ arxiv_search,
209
+ question_tool,
210
  ]
 
 
211
  llm = get_llm(provider).bind_tools(tools)
212
 
213
+ def retriever_node(state: MessagesState) -> dict:
214
+ """
215
+ Node: retrieve most relevant doc and extract its answer.
216
+ """
 
 
 
217
  query = state["messages"][-1].content
218
  doc = retriever.similarity_search(query, k=1)[0]
219
  text = doc.page_content
220
+ ans = text.split("Final answer :")[-1].strip() if "Final answer :" in text else text
221
+ return {"messages": [AIMessage(content=ans)]}
222
+
223
+ def assistant_node(state: MessagesState) -> dict:
224
+ """
225
+ Node: call LLM with system prompt + history.
226
+ """
227
+ msgs = [sys_msg] + state["messages"]
228
+ resp = llm.invoke({"messages": msgs})
229
+ return {"messages": [resp]}
230
 
 
231
  graph = StateGraph(MessagesState)
232
  graph.add_node("retriever", retriever_node)
233
+ graph.add_node("assistant", assistant_node)
234
  graph.add_node("tools", ToolNode(tools))
235
  graph.add_edge(START, "retriever")
236
  graph.add_edge("retriever", "assistant")
 
239
  graph.set_entry_point("retriever")
240
  graph.set_finish_point("assistant")
241
 
242
+ return graph.compile()