import os
from dotenv import load_dotenv
from supabase.client import create_client
from langgraph.graph import START, StateGraph, MessagesState
from langgraph.prebuilt import ToolNode, tools_condition
from langchain_core.tools import tool
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_groq import ChatGroq
from langchain_huggingface import (
    ChatHuggingFace,
    HuggingFaceEndpoint,
    HuggingFaceEmbeddings,
)
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
from langchain_community.vectorstores import SupabaseVectorStore
from langchain.tools.retriever import create_retriever_tool

load_dotenv()


def load_system_prompt(path: str = "system_prompt.txt") -> SystemMessage:
    """
    Load system prompt from a file, fallback to a default if missing.

    Args:
        path: File path to the system prompt.

    Returns:
        SystemMessage containing the loaded or default prompt.
    """
    try:
        with open(path, encoding="utf-8") as f:
            content = f.read()
    except FileNotFoundError:
        content = "You are a helpful assistant."
    return SystemMessage(content=content)


def math_tool(func):
    """
    Wrap a Python function as a LangChain tool.

    Args:
        func: Callable to wrap.

    Returns:
        A LangChain tool.
    """
    return tool(func)


@math_tool
def add(a: int, b: int) -> int:
    """Return a + b."""
    return a + b


@math_tool
def subtract(a: int, b: int) -> int:
    """Return a - b."""
    return a - b


@math_tool
def multiply(a: int, b: int) -> int:
    """Return a * b."""
    return a * b


@math_tool
def divide(a: int, b: int) -> float:
    """
    Return a / b.

    Raises:
        ValueError: If b is zero.
    """
    if b == 0:
        raise ValueError("Cannot divide by zero.")
    return a / b


@math_tool
def modulus(a: int, b: int) -> int:
    """Return a % b."""
    return a % b


def format_docs(docs, key: str, max_chars: int | None = None) -> dict:
    """
    Convert document list into labeled XML-style chunks.

    Args:
        docs: Iterable of Document objects.
        key: Dict key for formatted results.
        max_chars: Optionally truncate content.

    Returns:
        {key: formatted_string}
    """
    entries = []
    for d in docs:
        content = d.page_content if max_chars is None else d.page_content[:max_chars]
        entries.append(
            f'<Document source="{d.metadata.get("source", "")}" page="{d.metadata.get("page", "")}">\n'
            f"{content}\n</Document>"
        )
    return {key: "\n\n---\n\n".join(entries)}
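

# Illustrative shape of format_docs output (an example added for clarity,
# not part of the original file):
#   {"wiki_results": '<Document source="https://example.org" page="1">\n'
#                    '...page text...\n</Document>\n\n---\n\n<Document ...>'}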


@tool
def wiki_search(query: str) -> dict:
    """Search Wikipedia (2 docs) and format results."""
    docs = WikipediaLoader(query=query, load_max_docs=2).load()
    return format_docs(docs, "wiki_results")


@tool
def web_search(query: str) -> dict:
    """Search the web via Tavily (3 results) and format them."""
    # Tavily returns a list of result dicts (not Documents); format directly.
    results = TavilySearchResults(max_results=3).invoke(query)
    entries = [
        f'<Document source="{r.get("url", "")}">\n{r.get("content", "")}\n</Document>'
        for r in results
    ]
    return {"web_results": "\n\n---\n\n".join(entries)}


@tool
def arxiv_search(query: str) -> dict:
    """Search ArXiv (3 docs) and format results (truncate to 1k chars)."""
    docs = ArxivLoader(query=query, load_max_docs=3).load()
    return format_docs(docs, "arxiv_results", max_chars=1000)


def build_vector_retriever():
    """
    Create and return a Supabase-based vector retriever.

    Returns:
        Retriever for semantic similarity queries.
    """
    embed = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
    # Requires SUPABASE_URL and SUPABASE_SERVICE_KEY in the environment.
    supa = create_client(
        os.getenv("SUPABASE_URL"), os.getenv("SUPABASE_SERVICE_KEY")
    )
    store = SupabaseVectorStore(
        client=supa,
        embedding=embed,
        table_name="documents",
        query_name="match_documents_langchain",
    )
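    # Assumes the "documents" table and the "match_documents_langchain" SQL
    # function already exist in the Supabase project (the standard
    # SupabaseVectorStore setup).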
    return store.as_retriever()


def get_llm(provider: str = "google"):
    """
    Factory to select and return an LLM client.

    Args:
        provider: One of "google", "groq", "huggingface".

    Returns:
        Configured LLM client.

    Raises:
        ValueError: On unsupported provider.
    """
    if provider == "google":
        return ChatGoogleGenerativeAI("gemini-2.0-flash", temperature=0)
    if provider == "groq":
        return ChatGroq("qwen-qwq-32b", temperature=0)
    if provider == "huggingface":
        return ChatHuggingFace(
            llm=HuggingFaceEndpoint(
                url="https://api-inference.huggingface.co/models/"
                "Meta-DeepLearning/llama-2-7b-chat-hf",
                temperature=0,
            )
        )
    raise ValueError(f"Unsupported provider: {provider}")
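

# Note on credentials (added guidance, not from the original file): each
# provider reads its own key from the environment: GOOGLE_API_KEY for
# Gemini, GROQ_API_KEY for Groq, and HUGGINGFACEHUB_API_TOKEN for the
# HuggingFace endpoint.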


def build_graph(provider: str = "google"):
    """
    Build and compile a StateGraph for retrieval + LLM responses.

    Args:
        provider: LLM provider key.

    Returns:
        A compiled StateGraph.
    """
    sys_msg = load_system_prompt()
    retriever = build_vector_retriever()
    question_tool = create_retriever_tool(
        retriever=retriever,
        name="question_search",  # tool names must be provider-safe identifiers
        description="Retrieve similar Q&A pairs from the vector store.",
    )

    tools = [
        add,
        subtract,
        multiply,
        divide,
        modulus,
        wiki_search,
        web_search,
        arxiv_search,
        question_tool,
    ]
    llm = get_llm(provider).bind_tools(tools)

    def retriever_node(state: MessagesState) -> dict:
        """
        Node: retrieve the most relevant doc and extract its answer.
        """
        query = state["messages"][-1].content
        # Retrievers expose invoke(), not similarity_search(); take the top hit.
        matches = retriever.invoke(query)
        if not matches:
            return {"messages": []}
        text = matches[0].page_content
        ans = text.split("Final answer :")[-1].strip() if "Final answer :" in text else text
        return {"messages": [AIMessage(content=ans)]}

    def assistant_node(state: MessagesState) -> dict:
        """
        Node: call LLM with system prompt + history.
        """
        msgs = [sys_msg] + state["messages"]
        resp = llm.invoke(msgs)  # chat models take a message list, not a dict
        return {"messages": [resp]}

    graph = StateGraph(MessagesState)
    graph.add_node("retriever", retriever_node)
    graph.add_node("assistant", assistant_node)
    graph.add_node("tools", ToolNode(tools))
    graph.add_edge(START, "retriever")
    graph.add_edge("retriever", "assistant")
    # tools_condition routes to "tools" when the LLM requests a tool call and
    # straight to END otherwise, so no explicit finish point is needed (the
    # START edge above already marks the entry point).
    graph.add_conditional_edges("assistant", tools_condition)
    graph.add_edge("tools", "assistant")

    return graph.compile()
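

# Minimal usage sketch (illustrative, not part of the original module). It
# assumes the needed credentials (e.g. GOOGLE_API_KEY, TAVILY_API_KEY,
# SUPABASE_URL, SUPABASE_SERVICE_KEY) are available via the environment or
# the .env file loaded above.
if __name__ == "__main__":
    agent = build_graph(provider="google")
    result = agent.invoke(
        {"messages": [HumanMessage(content="What is 17 * 24?")]}
    )
    print(result["messages"][-1].content)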