# Author: KPatelis
# Commit 8df4ab3: "fix langgraph workflow"
import os
from dotenv import load_dotenv
from langgraph.graph import START, StateGraph, MessagesState
from langgraph.prebuilt import tools_condition
from langgraph.prebuilt import ToolNode
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
from langchain_community.vectorstores import SupabaseVectorStore
from langchain_core.messages import HumanMessage
from langchain.tools.retriever import create_retriever_tool
from supabase.client import Client, create_client
from utils import load_prompt
from tools import calculator, duck_web_search, wiki_search, arxiv_search
# Load variables from a local .env file (if present) into os.environ so the
# os.getenv() calls below can see SUPABASE_URL, SUPABASE_SERVICE_KEY and
# HF_INFERENCE_KEY.
load_dotenv()
# Create retriever
# Embedding model used to embed search queries for the Supabase vector store.
# NOTE(review): presumably the vectors stored in "gaia_documents" were produced
# with this same model (768-dim) — confirm, otherwise similarity is meaningless.
embeddings = HuggingFaceEmbeddings(model_name="Alibaba-NLP/gte-modernbert-base") # dim=768
# Supabase client built from environment credentials. os.getenv() returns None
# for an unset variable, and that None is passed straight to create_client().
supabase: Client = create_client(os.getenv("SUPABASE_URL"), os.getenv("SUPABASE_SERVICE_KEY"))
# LangChain vector store over the "gaia_documents" table; similarity searches
# are executed through the "match_documents_langchain" database function.
vector_store = SupabaseVectorStore(
    client=supabase,
    embedding= embeddings,
    table_name="gaia_documents",
    query_name="match_documents_langchain",
)
# Wraps the vector store's retriever as a LangChain tool.
# NOTE(review): this tool is not included in `tools` below, so the model is
# never offered it (retriever_node queries vector_store directly instead). It
# may be imported elsewhere — verify before removing. If it is ever bound,
# note that "ModernBERT Retriever" contains a space, which OpenAI-style
# tool-calling APIs reject; a name like "modernbert_retriever" would be needed.
retriever = create_retriever_tool(
    retriever=vector_store.as_retriever(),
    name="ModernBERT Retriever",
    description="A retriever of similar questions from a vector store.",
)
# Tools the chat model may call; the same list backs the graph's ToolNode.
tools = [calculator, duck_web_search, wiki_search, arxiv_search]
# Hugging Face Hub model id passed to the inference endpoint below.
model_id: str = "Qwen/Qwen3-32B"
llm = HuggingFaceEndpoint(
    repo_id=model_id,
    temperature=0,  # intended as deterministic decoding
    repetition_penalty=1.03,  # mild penalty against repeated tokens
    provider="auto",  # let Hugging Face choose the inference provider
    huggingfacehub_api_token=os.getenv("HF_INFERENCE_KEY")
)
# Chat-model wrapper around the endpoint; bind_tools attaches the tool schemas
# so the model can emit tool calls for ToolNode to execute.
agent = ChatHuggingFace(llm=llm)
agent_with_tools = agent.bind_tools(tools)
def retriever_node(state: MessagesState):
    """RAG node: attach the closest stored question/answer as extra context.

    Uses the content of the first message in ``state["messages"]`` as the
    query, runs a similarity search against ``vector_store`` and appends the
    top hit's ``page_content`` as a ``HumanMessage`` for the agent to use as
    a reference.

    Args:
        state: Graph state holding the conversation so far.

    Returns:
        ``{"messages": [HumanMessage(...)]}`` carrying the reference text, or
        ``{"messages": []}`` (no change to the conversation) when there is no
        message to query with or the search finds no documents.
    """
    messages = state.get("messages", [])
    if not messages:
        # Nothing to search for; previously this raised IndexError.
        return {"messages": []}

    # This node runs first in the graph, so messages[0] is presumably the
    # caller's original question — confirm against how the graph is invoked.
    query = messages[0].content
    similar_docs = vector_store.similarity_search(query)
    if not similar_docs:
        # Empty table / no match: previously crashed on similar_docs[0].
        return {"messages": []}

    reference = similar_docs[0].page_content
    response = [HumanMessage(f"Here I provide a similar question and answer for reference: \n\n{reference}")]
    return {"messages": response}
def processor_node(state: MessagesState):
    """Agent node that answers questions.

    Prepends the system prompt loaded from ``prompt.yaml`` to the current
    conversation and invokes the tool-bound chat model. The model's reply may
    contain tool calls, which the graph routes to the ``tools`` node.

    Args:
        state: Graph state holding the conversation so far.

    Returns:
        ``{"messages": [reply]}`` containing the single model response.
    """
    # Local import: langchain_core is already a dependency of this module.
    from langchain_core.messages import SystemMessage

    # Loaded on every call (assuming load_prompt does not cache), so edits to
    # prompt.yaml are picked up on the next invocation.
    system_prompt = load_prompt("prompt.yaml")
    if isinstance(system_prompt, str):
        # LangChain coerces a bare string inside a message list into a
        # HumanMessage, which would demote the instructions to a user turn.
        system_prompt = SystemMessage(content=system_prompt)

    messages = state.get("messages", [])
    reply = agent_with_tools.invoke([system_prompt, *messages])
    return {"messages": [reply]}
def agent_graph():
    """Build and compile the question-answering workflow.

    Flow: START -> retriever_node -> processor_node. From processor_node,
    ``tools_condition`` routes to the ``tools`` node when the model requested
    tool calls, or ends the run otherwise; tool output loops back into
    processor_node.

    Returns:
        The compiled LangGraph graph.
    """
    builder = StateGraph(MessagesState)

    # Register nodes (same names and order as before).
    node_specs = (
        ("retriever_node", retriever_node),
        ("processor_node", processor_node),
        ("tools", ToolNode(tools)),
    )
    for node_name, node_impl in node_specs:
        builder.add_node(node_name, node_impl)

    # Linear prelude: fetch similar-question context, then run the agent.
    for source, target in ((START, "retriever_node"), ("retriever_node", "processor_node")):
        builder.add_edge(source, target)

    # Agent either calls tools or finishes; tool results feed back to it.
    builder.add_conditional_edges("processor_node", tools_condition)
    builder.add_edge("tools", "processor_node")

    return builder.compile()