Harshana committed
Commit 372720a · 1 Parent(s): 1aee18a
add basic code
Browse files
- .env +4 -0
- agent.py +44 -0
- config.py +12 -0
- llm_provider.py +19 -0
- prompt/__init__.py +0 -0
- prompt/system_prompt.py +0 -0
- requirements.txt +17 -1
- retrievers/__init__.py +8 -0
- retrievers/custom_retriever.py +21 -0
- tools/__init__.py +7 -0
- tools/math_tools.py +45 -0
- tools/search_tools.py +33 -0
- tools/vector_tools.py +8 -0
.env
ADDED
@@ -0,0 +1,4 @@
+LLM_PROVIDER=groq
+SUPABASE_URL=https://YOUR.supabase.co
+SUPABASE_SERVICE_KEY=your-supabase-service-key
+SYSTEM_PROMPT_PATH=system_prompt.txt
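Note: the committed .env carries placeholders only, and no LLM keys at all; ChatGroq and ChatGoogleGenerativeAI read GROQ_API_KEY / GOOGLE_API_KEY from the environment. A minimal sanity check (hypothetical helper script, not part of this commit):

# check_env.py (hypothetical helper, not part of this commit)
import os
from dotenv import load_dotenv

load_dotenv()  # pull in the .env above
for var in ("LLM_PROVIDER", "SUPABASE_URL", "SUPABASE_SERVICE_KEY", "GROQ_API_KEY"):
    print(f"{var}: {'set' if os.getenv(var) else 'MISSING'}")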
agent.py
ADDED
@@ -0,0 +1,44 @@
+from config import settings
+from llm_provider import get_llm
+from tools import ALL_TOOLS
+from retrievers import custom_retriever
+from langgraph.graph import START, StateGraph, MessagesState
+from langgraph.prebuilt import tools_condition, ToolNode
+from langchain_core.messages import SystemMessage, HumanMessage
+
+# Load system prompt
+with open(settings.system_prompt_path, "r", encoding="utf-8") as f:
+    system_prompt = f.read()
+
+sys_msg = SystemMessage(content=system_prompt)
+
+def build_graph():
+    llm = get_llm(settings.llm_provider)
+    llm_with_tools = llm.bind_tools(ALL_TOOLS)
+
+    def assistant(state: MessagesState):
+        return {"messages": [llm_with_tools.invoke(state["messages"])]}
+
+    def retriever(state: MessagesState):
+        similar_q = custom_retriever.retrieve(state["messages"][0].content)
+        example_msg = HumanMessage(content=f"Similar Q&A:\n\n{similar_q}")
+        return {"messages": [sys_msg] + state["messages"] + [example_msg]}
+
+    builder = StateGraph(MessagesState)
+    builder.add_node("retriever", retriever)
+    builder.add_node("assistant", assistant)
+    builder.add_node("tools", ToolNode(ALL_TOOLS))
+    builder.add_edge(START, "retriever")
+    builder.add_edge("retriever", "assistant")
+    builder.add_conditional_edges("assistant", tools_condition)
+    builder.add_edge("tools", "assistant")
+
+    return builder.compile()
+
+if __name__ == "__main__":
+    graph = build_graph()
+    question = input("Ask your question: ")
+    messages = [HumanMessage(content=question)]
+    results = graph.invoke({"messages": messages})
+    for m in results["messages"]:
+        print(m.content)
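The graph runs retrieval once before the assistant, then loops assistant and tools until tools_condition sees no tool calls and routes to END. A smoke test, assuming system_prompt.txt exists next to agent.py and a Groq key is set:

# hypothetical smoke test, not part of this commit
from langchain_core.messages import HumanMessage
from agent import build_graph

graph = build_graph()
out = graph.invoke({"messages": [HumanMessage(content="What is 12 * 7?")]})
print(out["messages"][-1].content)  # the model should call multiply, then answer 84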
config.py
ADDED
@@ -0,0 +1,12 @@
+import os
+from dotenv import load_dotenv
+
+load_dotenv()
+
+class Settings:
+    llm_provider = os.getenv("LLM_PROVIDER", "groq")
+    supabase_url = os.getenv("SUPABASE_URL")
+    supabase_key = os.getenv("SUPABASE_SERVICE_KEY")
+    system_prompt_path = os.getenv("SYSTEM_PROMPT_PATH", "system_prompt.txt")
+
+settings = Settings()
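Settings are resolved once at import time, so the .env must be present before anything imports config. A quick check, assuming the .env above sits in the working directory:

# hypothetical REPL check, not part of this commit
from config import settings

print(settings.llm_provider)        # "groq" unless LLM_PROVIDER overrides it
print(settings.system_prompt_path)  # "system_prompt.txt" by default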
llm_provider.py
ADDED
@@ -0,0 +1,19 @@
+from langchain_google_genai import ChatGoogleGenerativeAI
+from langchain_groq import ChatGroq
+from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
+from config import settings
+
+def get_llm(provider: str):
+    if provider == "google":
+        return ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
+    elif provider == "groq":
+        return ChatGroq(model="qwen-qwq-32b", temperature=0)
+    elif provider == "huggingface":
+        return ChatHuggingFace(
+            llm=HuggingFaceEndpoint(
+                endpoint_url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
+                temperature=0,
+            ),
+        )
+    else:
+        raise ValueError(f"Unknown provider: {provider}")
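Every branch returns a chat model that supports bind_tools, which build_graph relies on. A usage sketch, assuming the matching API key (for example GROQ_API_KEY) is exported:

# hypothetical usage sketch, not part of this commit
from llm_provider import get_llm

llm = get_llm("groq")
print(llm.invoke("Reply with one word: hello").content)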
prompt/__init__.py
ADDED
File without changes
prompt/system_prompt.py
ADDED
File without changes
requirements.txt
CHANGED
@@ -1,2 +1,18 @@
 gradio
-requests
+requests
+langchain
+langchain-community
+langchain-core
+langchain-google-genai
+langchain-huggingface
+langchain-groq
+langchain-tavily
+langchain-chroma
+langgraph
+huggingface_hub
+supabase
+arxiv
+pymupdf
+wikipedia
+pgvector
+python-dotenv
retrievers/__init__.py
ADDED
@@ -0,0 +1,8 @@
+from .custom_retriever import retrieve
+
+class CustomRetriever:
+    @staticmethod
+    def retrieve(query):
+        return retrieve(query)
+
+custom_retriever = CustomRetriever()
retrievers/custom_retriever.py
ADDED
@@ -0,0 +1,21 @@
+import os
+from langchain_huggingface import HuggingFaceEmbeddings
+from langchain_community.vectorstores import SupabaseVectorStore
+from supabase.client import create_client
+from config import settings
+
+embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
+supabase = create_client(settings.supabase_url, settings.supabase_key)
+vector_store = SupabaseVectorStore(
+    client=supabase,
+    embedding=embeddings,
+    table_name="documents",
+    query_name="match_documents_langchain",
+)
+
+def retrieve(query: str) -> str:
+    results = vector_store.similarity_search(query)
+    if results:
+        return results[0].page_content
+    else:
+        return "No similar questions found."
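SupabaseVectorStore assumes the documents table and the match_documents_langchain SQL function already exist in the Supabase project; the LangChain Supabase integration docs provide the DDL for both. A seeding sketch under that assumption (hypothetical, not part of this commit):

# seed.py (hypothetical, assumes the table and SQL function above exist)
from retrievers.custom_retriever import vector_store

vector_store.add_texts(
    ["Q: What is the capital of France?\nA: Paris"],
    metadatas=[{"source": "seed"}],
)
print(vector_store.similarity_search("capital of France")[0].page_content)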
tools/__init__.py
ADDED
@@ -0,0 +1,7 @@
+# tools/__init__.py
+# Tools exposed to the agent; extend this list when new tools are added.
+
+from .math_tools import add, multiply
+from .search_tools import wiki_search
+
+ALL_TOOLS = [add, multiply, wiki_search]
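Only three tools are registered here; subtract, divide, modulus, power, sqrt, abs_val, web_search, arxiv_search and vector_tools.similar_question are defined below but never reach the agent. If the intent is to expose the full set, a sketch (hypothetical, would change agent behavior):

# hypothetical wider registration, not part of this commit
from .math_tools import add, subtract, multiply, divide, modulus, power, sqrt, abs_val
from .search_tools import wiki_search, web_search, arxiv_search

ALL_TOOLS = [add, subtract, multiply, divide, modulus, power, sqrt, abs_val,
             wiki_search, web_search, arxiv_search]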
tools/math_tools.py
ADDED
@@ -0,0 +1,45 @@
+from langchain_core.tools import tool
+
+@tool
+def add(a: float, b: float) -> float:
+    """Add two numbers."""
+    return a + b
+
+@tool
+def subtract(a: float, b: float) -> float:
+    """Subtract b from a."""
+    return a - b
+
+@tool
+def multiply(a: float, b: float) -> float:
+    """Multiply two numbers."""
+    return a * b
+
+@tool
+def divide(a: float, b: float) -> float:
+    """Divide a by b. Raises an error if b is zero."""
+    if b == 0:
+        raise ValueError("Cannot divide by zero.")
+    return a / b
+
+@tool
+def modulus(a: float, b: float) -> float:
+    """Return the modulus (remainder) of a divided by b."""
+    return a % b
+
+@tool
+def power(a: float, b: float) -> float:
+    """Return a raised to the power of b."""
+    return a ** b
+
+@tool
+def sqrt(x: float) -> float:
+    """Return the square root of x. Raises error if x is negative."""
+    if x < 0:
+        raise ValueError("Cannot compute the square root of a negative number.")
+    return x ** 0.5
+
+@tool
+def abs_val(x: float) -> float:
+    """Return the absolute value of x."""
+    return abs(x)
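The @tool decorator wraps each function as a LangChain structured tool, so direct calls go through .invoke with a dict of arguments:

# hypothetical REPL check, not part of this commit
from tools.math_tools import divide, sqrt

print(divide.invoke({"a": 10, "b": 4}))  # 2.5
print(sqrt.invoke({"x": 2}))             # 1.4142135623730951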
tools/search_tools.py
ADDED
@@ -0,0 +1,33 @@
+from langchain_core.tools import tool
+from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
+from langchain_community.tools.tavily_search import TavilySearchResults
+
+@tool
+def wiki_search(query: str) -> dict:
+    """Search Wikipedia for a query and return maximum 2 results."""
+    search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
+    formatted = "\n\n---\n\n".join(
+        f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
+        for doc in search_docs
+    )
+    return {"wiki_results": formatted}
+
+@tool
+def web_search(query: str) -> dict:
+    """Search Tavily for a query and return maximum 3 results."""
+    search_docs = TavilySearchResults(max_results=3).invoke(query)
+    formatted = "\n\n---\n\n".join(
+        f'<Document source="{doc["url"]}"/>\n{doc["content"]}\n</Document>'
+        for doc in search_docs
+    )
+    return {"web_results": formatted}
+
+@tool
+def arxiv_search(query: str) -> dict:
+    """Search Arxiv for a query and return maximum 3 results."""
+    search_docs = ArxivLoader(query=query, load_max_docs=3).load()
+    formatted = "\n\n---\n\n".join(
+        f'<Document source="{doc.metadata.get("source", doc.metadata.get("Title", ""))}"/>\n{doc.page_content[:1000]}\n</Document>'
+        for doc in search_docs
+    )
+    return {"arxiv_results": formatted}
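wiki_search needs no credentials; web_search requires TAVILY_API_KEY in the environment; and only wiki_search is registered in ALL_TOOLS. A quick check (hypothetical):

# hypothetical REPL check, not part of this commit
from tools.search_tools import wiki_search

out = wiki_search.invoke({"query": "Alan Turing"})
print(out["wiki_results"][:300])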
tools/vector_tools.py
ADDED
@@ -0,0 +1,8 @@
+from langchain_core.tools import tool
+# Example vector tool using retriever
+from retrievers import custom_retriever
+
+@tool
+def similar_question(query: str) -> str:
+    """Retrieve a similar question from the vector store."""
+    return custom_retriever.retrieve(query)