"""LangGraph agent: look up a similar stored question, then answer with tools."""

import os

from dotenv import load_dotenv
from langgraph.graph import START, StateGraph, MessagesState
from langgraph.prebuilt import tools_condition
from langgraph.prebuilt import ToolNode
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
from langchain_community.vectorstores import SupabaseVectorStore
from langchain_core.messages import HumanMessage
from langchain.tools.retriever import create_retriever_tool
from supabase.client import Client, create_client

from utils import load_prompt
from tools import calculator, duck_web_search, wiki_search, arxiv_search

# Load environment variables before any os.getenv() lookups below.
load_dotenv()

# Create retriever
embeddings = HuggingFaceEmbeddings(model_name="Alibaba-NLP/gte-modernbert-base")  # dim=768
supabase: Client = create_client(
    os.getenv("SUPABASE_URL"),
    os.getenv("SUPABASE_SERVICE_KEY"),
)
vector_store = SupabaseVectorStore(
    client=supabase,
    embedding=embeddings,
    table_name="gaia_documents",
    query_name="match_documents_langchain",
)
# NOTE(review): this retriever tool is built but is not included in `tools`.
# retriever_node queries `vector_store` directly instead. Confirm whether
# external code imports it before removing it.
retriever = create_retriever_tool(
    retriever=vector_store.as_retriever(),
    name="ModernBERT Retriever",
    description="A retriever of similar questions from a vector store.",
)

# Tools the chat model may call; the same list backs the graph's ToolNode.
tools = [calculator, duck_web_search, wiki_search, arxiv_search]

model_id = "Qwen/Qwen3-32B"
llm = HuggingFaceEndpoint(
    repo_id=model_id,
    temperature=0,
    repetition_penalty=1.03,
    provider="auto",
    huggingfacehub_api_token=os.getenv("HF_INFERENCE_KEY"),
)
agent = ChatHuggingFace(llm=llm)
agent_with_tools = agent.bind_tools(tools)


def retriever_node(state: MessagesState):
    """RAG node: add the closest stored question/answer as reference context.

    Runs a similarity search on the content of the first message in the
    conversation. The best match is appended to the state as a HumanMessage.
    If the vector store returns no match, no message is added; previously
    this case raised IndexError.
    """
    query = state["messages"][0].content
    # Only the top hit is used, so request just one document.
    similar_question = vector_store.similarity_search(query, k=1)
    if not similar_question:
        return {"messages": []}
    response = [HumanMessage(f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}")]
    return {"messages": response}


def processor_node(state: MessagesState):
    """Agent node that answers questions.

    Prepends the system prompt loaded from ``prompt.yaml`` to the current
    messages and invokes the tool-bound chat model. The model's reply is
    returned for appending to the graph state.
    """
    # NOTE(review): the prompt file is re-read on every invocation, including
    # each pass through the tools loop. Consider caching it if prompt.yaml is
    # static at runtime.
    system_prompt = load_prompt("prompt.yaml")
    messages = state.get("messages", [])
    response = [agent_with_tools.invoke([system_prompt] + messages)]
    return {"messages": response}


def agent_graph():
    """Build and compile the agent graph.

    Flow:
    START -> retriever_node -> processor_node. ``tools_condition`` then
    routes to the "tools" node when the model requested tool calls, or ends
    otherwise. After "tools" runs, control loops back to processor_node.

    Returns:
        The compiled LangGraph graph.
    """
    workflow = StateGraph(MessagesState)

    ## Add nodes
    workflow.add_node("retriever_node", retriever_node)
    workflow.add_node("processor_node", processor_node)
    workflow.add_node("tools", ToolNode(tools))

    ## Add edges
    workflow.add_edge(START, "retriever_node")
    workflow.add_edge("retriever_node", "processor_node")
    workflow.add_conditional_edges("processor_node", tools_condition)
    workflow.add_edge("tools", "processor_node")

    # Compile graph
    graph = workflow.compile()
    return graph