Files changed (1) hide show
  1. agent.py +120 -0
agent.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """LangGraph Agent using Mistral"""
2
+ import os
3
+ from dotenv import load_dotenv
4
+ from langgraph.graph import START, StateGraph, MessagesState
5
+ from langgraph.prebuilt import tools_condition
6
+ from langgraph.prebuilt import ToolNode
7
+ from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
8
+ from langchain_community.tools.tavily_search import TavilySearchResults
9
+ from langchain_community.vectorstores import SupabaseVectorStore
10
+ from langchain_core.messages import SystemMessage, HumanMessage
11
+ from langchain_core.tools import tool
12
+ from transformers import pipeline
13
+ from langchain.embeddings.huggingface import HuggingFaceEmbeddings
14
+ from supabase.client import Client, create_client
15
+
16
+ load_dotenv()
17
+
18
+ # Tools
19
+ @tool
20
+ def multiply(a: int, b: int) -> int:
21
+ return a * b
22
+
23
+ @tool
24
+ def add(a: int, b: int) -> int:
25
+ return a + b
26
+
27
+ @tool
28
+ def subtract(a: int, b: int) -> int:
29
+ return a - b
30
+
31
+ @tool
32
+ def divide(a: int, b: int) -> float:
33
+ if b == 0:
34
+ raise ValueError("Cannot divide by zero.")
35
+ return a / b
36
+
37
+ @tool
38
+ def modulus(a: int, b: int) -> int:
39
+ return a % b
40
+
41
+ @tool
42
+ def wiki_search(query: str) -> str:
43
+ search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
44
+ return "\n\n---\n\n".join([doc.page_content for doc in search_docs])
45
+
46
+ @tool
47
+ def web_search(query: str) -> str:
48
+ search_docs = TavilySearchResults(max_results=3).invoke(query=query)
49
+ return "\n\n---\n\n".join([doc.page_content for doc in search_docs])
50
+
51
+ @tool
52
+ def arvix_search(query: str) -> str:
53
+ search_docs = ArxivLoader(query=query, load_max_docs=3).load()
54
+ return "\n\n---\n\n".join([doc.page_content[:1000] for doc in search_docs])
55
+
56
+ tools = [multiply, add, subtract, divide, modulus, wiki_search, web_search, arvix_search]
57
+
58
+ # Load system prompt
59
+ with open("system_prompt.txt", "r", encoding="utf-8") as f:
60
+ system_prompt = f.read()
61
+ sys_msg = SystemMessage(content=system_prompt)
62
+
63
+ # Vector store setup
64
+ embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
65
+ supabase: Client = create_client(
66
+ os.environ.get("SUPABASE_URL"),
67
+ os.environ.get("SUPABASE_SERVICE_KEY")
68
+ )
69
+ vector_store = SupabaseVectorStore(
70
+ client=supabase,
71
+ embedding=embeddings,
72
+ table_name="documents",
73
+ query_name="match_documents_langchain"
74
+ )
75
+
76
+ # Mistral agent
77
+ class MistralAgent:
78
+ def __init__(self):
79
+ self.generator = pipeline("text-generation", model="mistralai/Mistral-7B-v0.1", device=0)
80
+ print("Mistral model loaded.")
81
+
82
+ def invoke(self, messages):
83
+ question = messages[-1].content
84
+ result = self.generator(question, max_length=300, do_sample=True)[0]["generated_text"]
85
+ return HumanMessage(content=result.strip())
86
+
87
+ mistral_agent = MistralAgent()
88
+
89
+ # LangGraph builder
90
+ def build_graph():
91
+ def assistant(state: MessagesState):
92
+ return {"messages": [mistral_agent.invoke(state["messages"])]}
93
+
94
+ def retriever(state: MessagesState):
95
+ similar = vector_store.similarity_search(state["messages"][-1].content)
96
+ example = HumanMessage(content=f"Similar Q&A:\n\n{similar[0].page_content}")
97
+ return {"messages": [sys_msg] + state["messages"] + [example]}
98
+
99
+ builder = StateGraph(MessagesState)
100
+ builder.add_node("retriever", retriever)
101
+ builder.add_node("assistant", assistant)
102
+ builder.add_node("tools", ToolNode(tools))
103
+ builder.add_edge(START, "retriever")
104
+ builder.add_edge("retriever", "assistant")
105
+ builder.add_conditional_edges("assistant", tools_condition)
106
+ builder.add_edge("tools", "assistant")
107
+
108
+ return builder.compile()
109
+
110
+ # Run the agent
111
+ def run_agent(question: str) -> str:
112
+ graph = build_graph()
113
+ messages = [HumanMessage(content=question)]
114
+ result = graph.invoke({"messages": messages})
115
+ return result["messages"][-1].content.strip()
116
+
117
+ if __name__ == "__main__":
118
+ question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
119
+ answer = run_agent(question)
120
+ print("ANSWER:", answer)