import os
from dotenv import load_dotenv
from supabase.client import create_client
from langgraph.graph import START, StateGraph, MessagesState
from langgraph.prebuilt import ToolNode, tools_condition
from langchain_core.tools import tool
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
from langchain_core.documents import Document
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_groq import ChatGroq
from langchain_huggingface import (
ChatHuggingFace,
HuggingFaceEndpoint,
HuggingFaceEmbeddings,
)
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
from langchain_community.vectorstores import SupabaseVectorStore
from langchain.tools.retriever import create_retriever_tool
load_dotenv()
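# Environment variables expected (typically supplied via a .env file). The
# Supabase names below are used directly in this module; the provider keys are
# the libraries' usual defaults and are listed here as assumptions:
#   SUPABASE_URL, SUPABASE_SERVICE_KEY                        - vector store
#   GOOGLE_API_KEY / GROQ_API_KEY / HUGGINGFACEHUB_API_TOKEN  - LLM providers
#   TAVILY_API_KEY                                            - web search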
def load_system_prompt(path: str = "system_prompt.txt") -> SystemMessage:
"""
Load system prompt from a file, fallback to a default if missing.
Args:
path: File path to the system prompt.
Returns:
SystemMessage containing the loaded or default prompt.
"""
try:
with open(path, encoding="utf-8") as f:
content = f.read()
except FileNotFoundError:
content = "You are a helpful assistant."
return SystemMessage(content=content)
def math_tool(func):
"""
Wrap a Python function as a LangChain tool.
Args:
func: Callable to wrap.
Returns:
A LangChain tool.
"""
return tool(func)
@math_tool
def add(a: int, b: int) -> int:
"""Return a + b."""
return a + b
@math_tool
def subtract(a: int, b: int) -> int:
"""Return a - b."""
return a - b
@math_tool
def multiply(a: int, b: int) -> int:
"""Return a * b."""
return a * b
@math_tool
def divide(a: int, b: int) -> float:
"""
Return a / b.
Raises:
ValueError: If b is zero.
"""
if b == 0:
raise ValueError("Cannot divide by zero.")
return a / b
@math_tool
def modulus(a: int, b: int) -> int:
"""Return a % b."""
return a % b
def format_docs(docs, key: str, max_chars: int | None = None) -> dict:
"""
Convert document list into labeled XML-style chunks.
Args:
docs: Iterable of Document objects.
key: Dict key for formatted results.
max_chars: Optionally truncate content.
Returns:
{key: formatted_string}
"""
entries = []
for d in docs:
content = d.page_content if max_chars is None else d.page_content[:max_chars]
entries.append(
f'<Document source="{d.metadata.get("source","")}" page="{d.metadata.get("page","")}">\n'
f"{content}\n</Document>"
)
return {key: "\n\n---\n\n".join(entries)}
@tool
def wiki_search(query: str) -> dict:
"""Search Wikipedia (2 docs) and format results."""
docs = WikipediaLoader(query=query, load_max_docs=2).load()
return format_docs(docs, "wiki_results")
@tool
def web_search(query: str) -> dict:
    """Search the web via Tavily (3 results) and format them."""
    # Tavily returns a list of result dicts, not Documents, so wrap each one first.
    results = TavilySearchResults(max_results=3).invoke(query)
    docs = [Document(page_content=r.get("content", ""), metadata={"source": r.get("url", "")}) for r in results]
    return format_docs(docs, "web_results")
@tool
def arxiv_search(query: str) -> dict:
"""Search ArXiv (3 docs) and format results (truncate to 1k chars)."""
docs = ArxivLoader(query=query, load_max_docs=3).load()
return format_docs(docs, "arxiv_results", max_chars=1000)
def build_vector_retriever():
"""
Create and return a Supabase-based vector retriever.
Returns:
Retriever for semantic similarity queries.
"""
    embed = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
supa = create_client(
os.getenv("SUPABASE_URL"), os.getenv("SUPABASE_SERVICE_KEY")
)
store = SupabaseVectorStore(
client=supa,
embedding=embed,
table_name="documents",
query_name="match_documents_langchain",
)
return store.as_retriever()
def get_llm(provider: str = "google"):
"""
Factory to select and return an LLM client.
Args:
provider: One of "google", "groq", "huggingface".
Returns:
Configured LLM client.
Raises:
ValueError: On unsupported provider.
"""
if provider == "google":
        return ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
if provider == "groq":
        return ChatGroq(model="qwen-qwq-32b", temperature=0)
if provider == "huggingface":
        return ChatHuggingFace(
            llm=HuggingFaceEndpoint(
                # HuggingFaceEndpoint expects "endpoint_url" (or "repo_id"), not "url".
                endpoint_url="https://api-inference.huggingface.co/models/"
                "Meta-DeepLearning/llama-2-7b-chat-hf",
                temperature=0,
            )
        )
raise ValueError(f"Unsupported provider: {provider}")
def build_graph(provider: str = "google"):
"""
Build and compile a StateGraph for retrieval + LLM responses.
Args:
provider: LLM provider key.
Returns:
A compiled StateGraph.
"""
sys_msg = load_system_prompt()
retriever = build_vector_retriever()
    question_tool = create_retriever_tool(
        retriever=retriever,
        # Tool names must be bare identifiers (no spaces) for most providers.
        name="question_search",
        description="Retrieve similar Q&A pairs from the vector store.",
    )
tools = [
add,
subtract,
multiply,
divide,
modulus,
wiki_search,
web_search,
arxiv_search,
question_tool,
]
llm = get_llm(provider).bind_tools(tools)
    def retriever_node(state: MessagesState) -> dict:
        """
        Node: retrieve the most relevant doc and extract its answer.
        """
        query = state["messages"][-1].content
        # VectorStoreRetriever has no similarity_search; query the underlying store.
        docs = retriever.vectorstore.similarity_search(query, k=1)
        if not docs:
            return {"messages": []}
        text = docs[0].page_content
        ans = text.split("Final answer :")[-1].strip() if "Final answer :" in text else text
        return {"messages": [AIMessage(content=ans)]}
    def assistant_node(state: MessagesState) -> dict:
        """
        Node: call the tool-bound LLM with the system prompt plus history.
        """
        msgs = [sys_msg] + state["messages"]
        resp = llm.invoke(msgs)
        return {"messages": [resp]}
graph = StateGraph(MessagesState)
graph.add_node("retriever", retriever_node)
graph.add_node("assistant", assistant_node)
graph.add_node("tools", ToolNode(tools))
graph.add_edge(START, "retriever")
graph.add_edge("retriever", "assistant")
    # tools_condition routes to "tools" when the assistant requests a tool,
    # otherwise to END, so no separate entry/finish points are needed here.
    graph.add_conditional_edges("assistant", tools_condition)
    graph.add_edge("tools", "assistant")
return graph.compile()
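if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the original module). It assumes
    # GOOGLE_API_KEY, SUPABASE_URL and SUPABASE_SERVICE_KEY are set, plus
    # TAVILY_API_KEY if the web search tool ends up being called.
    compiled = build_graph()
    result = compiled.invoke({"messages": [HumanMessage(content="What is 12 * 7?")]})
    for message in result["messages"]:
        message.pretty_print()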