0xrushi
test
de49b6b
import os
from dotenv import load_dotenv
from langgraph.graph import START, StateGraph, MessagesState
from langgraph.prebuilt import tools_condition
from langgraph.prebuilt import ToolNode
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_groq import ChatGroq
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
from langchain_community.vectorstores import SupabaseVectorStore
from langchain_core.messages import SystemMessage, AIMessage, HumanMessage
from langchain_core.tools import tool
from langchain.tools.retriever import create_retriever_tool
from langchain_community.retrievers import BM25Retriever
from smolagents import DuckDuckGoSearchTool
from smolagents import Tool
from langchain.vectorstores import FAISS
import faiss
from langchain_community.docstore.in_memory import InMemoryDocstore
# Load environment variables from a .env file into os.environ so that later
# os.getenv lookups (e.g. HF_TOKEN in build_graph) can see them.
load_dotenv()
class QuestionRetrieverTool(Tool):
    """smolagents tool returning the stored questions most similar to a query.

    Similarity is computed lexically with a BM25 index that is built over the
    ``docs`` passed to the constructor.
    """

    # No trailing comma: ``name = "...",`` would make ``name`` a one-element
    # tuple instead of the string smolagents expects.
    # NOTE(review): recent smolagents versions may require a valid Python
    # identifier here (e.g. "question_search") — confirm against the version used.
    name = "Question Search"
    description = "Retrieve similar questions from the vector store."
    inputs = {
        "query": {
            "type": "string",
            "description": "The question you want relation about.",
        }
    }
    output_type = "string"

    def __init__(self, docs):
        """Build the BM25 index.

        Args:
            docs: LangChain ``Document`` objects to index; their
                ``page_content`` is what is matched and returned.
        """
        super().__init__()
        self.is_initialized = False
        self.retriever = BM25Retriever.from_documents(docs)

    def forward(self, query: str) -> str:
        """Return up to three matching documents' text, blank-line separated.

        Args:
            query: Free-text question to match against the indexed documents.

        Returns:
            The ``page_content`` of at most three top-ranked documents joined by
            ``"\\n\\n"``, or ``"No matching Questions found."`` when nothing matches.
        """
        # ``invoke`` is the current retriever entry point;
        # ``get_relevant_documents`` is deprecated.
        results = self.retriever.invoke(query)
        if not results:
            return "No matching Questions found."
        return "\n\n".join(doc.page_content for doc in results[:3])
@tool
def wiki_search(query: str) -> dict:
    """Search Wikipedia and return up to 2 documents."""
    # The docstring above is the tool description the LLM sees; keep it stable.
    loaded = WikipediaLoader(query=query, load_max_docs=2).load()
    chunks = []
    for doc in loaded:
        meta = doc.metadata
        header = '<Document source="{}" page="{}"/>'.format(
            meta["source"], meta.get("page", "")
        )
        chunks.append(header + "\n" + doc.page_content)
    separator = "\n---\n"
    return {"wiki_results": separator.join(chunks)}
@tool
def web_search(query: str) -> dict:
    """Search DDG and return up to 3 results."""
    # The docstring above is the tool description the LLM sees; keep it stable.
    #
    # smolagents tools are run by calling them (Tool.__call__ -> forward); they
    # have no LangChain-style ``.invoke``. DuckDuckGoSearchTool's output is an
    # already formatted string of titles/links/snippets rather than a list of
    # Documents, so it is returned as-is instead of being re-formatted.
    search_tool = DuckDuckGoSearchTool(max_results=3)
    results = search_tool(query)
    return {"web_results": results}
# --- Load system prompt ---
# Read once at import time. The path is relative to the current working
# directory, not to this module.
with open("system_prompt.txt", mode="r", encoding="utf-8") as prompt_file:
    system_prompt = prompt_file.read()
sys_msg = SystemMessage(content=system_prompt)
# --- Retriever Tool ---
# Sentence-transformers model that embeds text for the FAISS store below.
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
# Derive the index dimension from the model itself (768 for all-mpnet-base-v2)
# instead of hard-coding it, so changing model_name cannot silently produce an
# index whose dimension does not match the embeddings.
embedding_dim = len(embeddings.embed_query("hello world"))
empty_index = faiss.IndexFlatL2(embedding_dim)
docstore = InMemoryDocstore({})
# NOTE(review): nothing visible here adds documents to this store, so searches
# return no hits until something calls vector_store.add_documents(...) — confirm.
vector_store = FAISS(
    embedding_function=embeddings,
    index=empty_index,
    docstore=docstore,
    index_to_docstore_id={},
)
retriever_tool = create_retriever_tool(
    retriever=vector_store.as_retriever(),
    # The tool name goes into the function-calling schema sent to the model.
    # OpenAI-compatible APIs (and Gemini/Groq) accept only letters, digits,
    # "_" and "-", so a name containing a space can be rejected at call time.
    name="question_search",
    description="Retrieve similar questions from the vector store.",
)
# Tools bound to the LLM in build_graph and executed there by the ToolNode.
tools = [wiki_search, web_search, retriever_tool]
# --- Graph Builder ---
def build_graph():
    """Build and compile the retriever -> assistant <-> tools agent graph.

    Flow: START -> ``retriever`` (appends a reference example looked up in
    ``vector_store``) -> ``assistant`` (LLM with ``tools`` bound). From
    ``assistant``, ``tools_condition`` routes either to ``tools`` (which then
    returns to ``assistant``) or ends the run.

    Returns:
        The compiled graph; invoke it with ``{"messages": [...]}``.
    """
    llm = ChatHuggingFace(
        llm=HuggingFaceEndpoint(
            repo_id="meta-llama/Llama-2-7b-chat-hf",
            temperature=0,
            huggingfacehub_api_token=os.getenv("HF_TOKEN"),
        )
    )
    # Bind tools so the LLM can emit tool calls that the ToolNode executes.
    llm_with_tools = llm.bind_tools(tools)

    def assistant(state: MessagesState):
        """Call the LLM on the conversation with the system prompt first."""
        # The system prompt is prepended here instead of being written into
        # the state. MessagesState's add_messages reducer only replaces
        # messages by id or appends new ones; it cannot insert at the front.
        # Storing sys_msg in the state would therefore place it *after* the
        # user's question.
        return {"messages": [llm_with_tools.invoke([sys_msg] + state["messages"])]}

    def retriever(state: MessagesState):
        """Append a HumanMessage holding the most similar stored example, if any."""
        query = state["messages"][-1].content
        similar_docs = vector_store.similarity_search(query, k=1)
        if similar_docs:
            reference = similar_docs[0].page_content
            context_msg = HumanMessage(content=f"Here is a similar question and answer for reference:\n\n{reference}")
        else:
            context_msg = HumanMessage(content="No relevant example found.")
        # Return only the new message; the reducer appends it to the history.
        return {"messages": [context_msg]}

    builder = StateGraph(MessagesState)
    builder.add_node("retriever", retriever)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))
    builder.add_edge(START, "retriever")
    builder.add_edge("retriever", "assistant")
    builder.add_conditional_edges("assistant", tools_condition)
    builder.add_edge("tools", "assistant")
    # Compile graph
    return builder.compile()