File size: 2,723 Bytes
b3f9415
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8df4ab3
b3f9415
 
8df4ab3
 
 
b3f9415
 
8df4ab3
 
 
 
b3f9415
 
8df4ab3
b3f9415
8df4ab3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import os
from dotenv import load_dotenv
from langgraph.graph import START, StateGraph, MessagesState
from langgraph.prebuilt import tools_condition
from langgraph.prebuilt import ToolNode
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
from langchain_community.vectorstores import SupabaseVectorStore
from langchain_core.messages import HumanMessage
from langchain.tools.retriever import create_retriever_tool
from supabase.client import Client, create_client
from utils import load_prompt
from tools import calculator, duck_web_search, wiki_search, arxiv_search

load_dotenv()

# --- Retriever setup -------------------------------------------------------
# Embedding model must produce vectors matching the dimension of the
# "gaia_documents" table's embedding column (gte-modernbert-base: dim=768).
embeddings = HuggingFaceEmbeddings(model_name="Alibaba-NLP/gte-modernbert-base")  # dim=768

supabase: Client = create_client(os.getenv("SUPABASE_URL"), os.getenv("SUPABASE_SERVICE_KEY"))
vector_store = SupabaseVectorStore(
    client=supabase,
    embedding=embeddings,
    table_name="gaia_documents",
    query_name="match_documents_langchain",
)

retriever = create_retriever_tool(
    retriever=vector_store.as_retriever(),
    # Tool names must match ^[a-zA-Z0-9_-]+$ for OpenAI-style function calling.
    # The previous name "ModernBERT Retriever" contained a space and would be
    # rejected by providers that validate tool names.
    name="modernbert_retriever",
    description="A retriever of similar questions from a vector store.",
)

# Tools exposed to the LLM. NOTE(review): `retriever` is intentionally not in
# this list — retrieval runs as a dedicated graph node (retriever_node) instead.
tools = [calculator, duck_web_search, wiki_search, arxiv_search]

# --- LLM setup -------------------------------------------------------------
model_id = "Qwen/Qwen3-32B"

# Remote inference endpoint; temperature=0 for deterministic answers,
# provider="auto" lets Hugging Face pick an available inference provider.
llm = HuggingFaceEndpoint(
    repo_id=model_id,
    temperature=0,
    repetition_penalty=1.03,
    provider="auto",
    huggingfacehub_api_token=os.getenv("HF_INFERENCE_KEY")
)

# Chat wrapper that applies the model's chat template to message lists.
agent = ChatHuggingFace(llm=llm)

# Same model with the tool schemas attached, so it can emit tool calls.
agent_with_tools = agent.bind_tools(tools)

def retriever_node(state: MessagesState):
    """RAG node: fetch a similar solved question and append it as context.

    Uses the first message's content as the similarity-search query against
    the Supabase vector store. If no similar document is found, the message
    list is left unchanged instead of raising IndexError (the previous
    behavior on an empty search result).
    """
    query = state["messages"][0].content
    similar_docs = vector_store.similarity_search(query)
    if not similar_docs:
        # No match in the vector store — add no context rather than crash.
        return {"messages": []}
    response = [
        HumanMessage(
            f"Here I provide a similar question and answer for reference: \n\n{similar_docs[0].page_content}"
        )
    ]
    return {"messages": response}

def processor_node(state: MessagesState):
    """Agent node that answers questions.

    Prepends the system prompt to the conversation and invokes the
    tool-bound chat model; the response may contain tool calls, which the
    graph routes to the tool node.

    NOTE(review): prompt.yaml is re-read on every invocation — acceptable
    for a small file, but could be loaded once at module level.
    """
    system_prompt = load_prompt("prompt.yaml")
    messages = state.get("messages", [])
    response = [agent_with_tools.invoke([system_prompt] + messages)]
    return {"messages": response}

def agent_graph():
    """Build and compile the RAG + tool-calling agent graph.

    Flow: START -> retriever_node -> processor_node; tools_condition then
    routes either to the tool node (whose results loop back into
    processor_node) or ends the run.
    """
    builder = StateGraph(MessagesState)

    # Nodes: retrieval, the LLM agent, and the tool executor.
    builder.add_node("retriever_node", retriever_node)
    builder.add_node("processor_node", processor_node)
    builder.add_node("tools", ToolNode(tools))

    # Edges: linear prefix, then the standard tool-calling loop.
    builder.add_edge(START, "retriever_node")
    builder.add_edge("retriever_node", "processor_node")
    builder.add_conditional_edges("processor_node", tools_condition)
    builder.add_edge("tools", "processor_node")

    return builder.compile()