Spaces:
Sleeping
Sleeping
import os | |
from dotenv import load_dotenv | |
from langgraph.graph import START, StateGraph, MessagesState | |
from langgraph.prebuilt import tools_condition | |
from langgraph.prebuilt import ToolNode | |
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings | |
from langchain_community.vectorstores import SupabaseVectorStore | |
from langchain_core.messages import HumanMessage | |
from langchain.tools.retriever import create_retriever_tool | |
from supabase.client import Client, create_client | |
from utils import load_prompt | |
from tools import calculator, duck_web_search, wiki_search, arxiv_search | |
load_dotenv()

# ---------------------------------------------------------------------------
# Vector store used for retrieval-augmented prompting.
# ---------------------------------------------------------------------------
# Embedding model for both indexing and querying; the original note says it
# produces 768-dim vectors (presumably matching the `gaia_documents` embedding
# column — confirm against the table schema).
embeddings = HuggingFaceEmbeddings(
    model_name="Alibaba-NLP/gte-modernbert-base",
)

# Supabase credentials come from the environment (populated by load_dotenv above).
supabase: Client = create_client(
    supabase_url=os.getenv("SUPABASE_URL"),
    supabase_key=os.getenv("SUPABASE_SERVICE_KEY"),
)

vector_store = SupabaseVectorStore(
    embedding=embeddings,
    client=supabase,
    query_name="match_documents_langchain",
    table_name="gaia_documents",
)

# NOTE(review): this tool is not included in `tools`, so the model never sees
# it; retrieval is done directly by `retriever_node` via `vector_store`.
# Its name also contains a space — many function-calling backends require
# tool names matching [A-Za-z0-9_-]; rename before binding it to the model.
retriever = create_retriever_tool(
    name="ModernBERT Retriever",
    description="A retriever of similar questions from a vector store.",
    retriever=vector_store.as_retriever(),
)
# ---------------------------------------------------------------------------
# Chat model and the tools it is allowed to call.
# ---------------------------------------------------------------------------
# Tools bound to the model below; the same list also backs the graph's
# ToolNode, so anything the model can request can actually be executed.
tools = [
    calculator,
    duck_web_search,
    wiki_search,
    arxiv_search,
]

model_id = "Qwen/Qwen3-32B"

# Hosted inference endpoint; the API token is read from HF_INFERENCE_KEY.
llm = HuggingFaceEndpoint(
    repo_id=model_id,
    huggingfacehub_api_token=os.getenv("HF_INFERENCE_KEY"),
    provider="auto",
    temperature=0,
    repetition_penalty=1.03,
)

# Chat wrapper around the endpoint, then a variant that advertises `tools`
# so the model can emit tool calls.
agent = ChatHuggingFace(llm=llm)
agent_with_tools = agent.bind_tools(tools)
def retriever_node(state: MessagesState):
    """RAG node: add the closest stored question/answer as reference context.

    Uses the content of the first message in the conversation as the query
    for a similarity search on ``vector_store`` and wraps the best match in a
    ``HumanMessage`` for the downstream processor node.

    Args:
        state: Graph state; ``state["messages"][0].content`` is the query.

    Returns:
        ``{"messages": [HumanMessage(...)]}`` containing the reference text,
        or ``{"messages": []}`` when the store returns no match (so nothing
        is appended to the conversation).
    """
    query = state["messages"][0].content
    # Only the top hit is used, so request exactly one result instead of the
    # store's default k.
    matches = vector_store.similarity_search(query, k=1)
    if not matches:
        # No reference available: continue without extra context instead of
        # failing the whole run with an IndexError.
        return {"messages": []}
    reference = matches[0].page_content
    response = [
        HumanMessage(
            f"Here I provide a similar question and answer for reference: \n\n{reference}"
        )
    ]
    return {"messages": response}
def processor_node(state: MessagesState):
    """Agent node that answers questions.

    Loads the system prompt with ``load_prompt("prompt.yaml")``, puts it in
    front of the conversation so far, and invokes the tool-bound chat model.
    The prompt file is re-read on every call, including each pass through
    the tools -> processor_node loop.

    Args:
        state: Graph state; ``state["messages"]`` holds the conversation and
            is treated as empty if the key is missing.

    Returns:
        ``{"messages": [reply]}`` where ``reply`` is the model output; the
        graph routes on it with ``tools_condition``.
    """
    # NOTE(review): presumably load_prompt returns a system message (or
    # something the chat model accepts as one) — confirm in utils.load_prompt.
    system_prompt = load_prompt("prompt.yaml")
    messages = state.get("messages", [])
    response = [agent_with_tools.invoke([system_prompt] + messages)]
    return {"messages": response}
def agent_graph():
    """Build and compile the agent's LangGraph workflow.

    Topology: START -> retriever_node -> processor_node. From
    processor_node, ``tools_condition`` decides whether to go to the
    "tools" node, which always loops back to processor_node.

    Returns:
        The compiled graph from ``StateGraph.compile()``.
    """
    builder = StateGraph(MessagesState)

    # Nodes: retrieval context, the LLM agent, and the tool executor.
    node_table = (
        ("retriever_node", retriever_node),
        ("processor_node", processor_node),
        ("tools", ToolNode(tools)),
    )
    for node_name, node_impl in node_table:
        builder.add_node(node_name, node_impl)

    # Fixed edges leading into the agent.
    for source, target in (
        (START, "retriever_node"),
        ("retriever_node", "processor_node"),
    ):
        builder.add_edge(source, target)

    # Agent <-> tools loop.
    builder.add_conditional_edges("processor_node", tools_condition)
    builder.add_edge("tools", "processor_node")

    return builder.compile()