Harshana committed on
Commit
372720a
·
1 Parent(s): 1aee18a

add basic code

Browse files
.env ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ LLM_PROVIDER=groq
2
+ SUPABASE_URL=https://YOUR.supabase.co
3
+ SUPABASE_SERVICE_KEY=your-supabase-service-key
4
+ SYSTEM_PROMPT_PATH=system_prompt.txt
agent.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from config import settings
2
+ from llm_provider import get_llm
3
+ from tools import ALL_TOOLS
4
+ from retrievers import custom_retriever
5
+ from langgraph.graph import START, StateGraph, MessagesState
6
+ from langgraph.prebuilt import tools_condition, ToolNode
7
+ from langchain_core.messages import SystemMessage, HumanMessage
8
+
9
# Read the system prompt text once, at import time, from the configured path.
with open(settings.system_prompt_path, mode="r", encoding="utf-8") as f:
    system_prompt = f.read()

# Wrap the prompt text in a SystemMessage for use by build_graph().
sys_msg = SystemMessage(content=system_prompt)
14
+
15
def build_graph():
    """Build and compile the agent graph.

    Flow: START -> retriever -> assistant, then assistant <-> tools for as
    long as the model keeps requesting tool calls (``tools_condition``).

    Returns:
        The compiled graph. Invoke it with ``{"messages": [HumanMessage(...)]}``.
    """
    llm = get_llm(settings.llm_provider)
    llm_with_tools = llm.bind_tools(ALL_TOOLS)

    def assistant(state: MessagesState):
        """Run the tool-enabled LLM on the conversation so far."""
        # Prepend the system prompt at call time so it is always the first
        # message the model sees. Putting it into the state from a node does
        # not achieve that: the MessagesState reducer appends new messages
        # after the ones already stored (the user's question).
        prompt = [sys_msg] + state["messages"]
        return {"messages": [llm_with_tools.invoke(prompt)]}

    def retriever(state: MessagesState):
        """Add a similar stored Q&A to the conversation as an example."""
        # Assumes the first message in the state is the user's question.
        similar_q = custom_retriever.retrieve(state["messages"][0].content)
        example_msg = HumanMessage(content=f"Similar Q&A:\n\n{similar_q}")
        # Return only the new message; the reducer merges it into the
        # existing history, so re-sending state["messages"] is unnecessary.
        return {"messages": [example_msg]}

    builder = StateGraph(MessagesState)
    builder.add_node("retriever", retriever)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(ALL_TOOLS))
    builder.add_edge(START, "retriever")
    builder.add_edge("retriever", "assistant")
    # Route to "tools" when the last assistant message has tool calls,
    # otherwise finish.
    builder.add_conditional_edges("assistant", tools_condition)
    builder.add_edge("tools", "assistant")

    return builder.compile()
37
+
38
if __name__ == "__main__":
    # Interactive smoke test: ask one question, then print every message
    # in the resulting conversation.
    graph = build_graph()
    question = input("Ask your question: ")
    results = graph.invoke({"messages": [HumanMessage(content=question)]})
    for message in results["messages"]:
        print(message.content)
config.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from dotenv import load_dotenv
3
+
4
+ load_dotenv()
5
+
6
class Settings:
    """Application settings read from environment variables.

    Values are resolved once, when this class body executes, so the
    environment (including any ``.env`` file) must be loaded beforehand.
    """

    # LLM backend name; falls back to "groq" when LLM_PROVIDER is unset.
    llm_provider = os.getenv("LLM_PROVIDER", "groq")
    # Backward-compatible alias for code that still reads ``settings.provider``.
    provider = llm_provider
    # Supabase connection details; None when the variables are not set.
    supabase_url = os.getenv("SUPABASE_URL")
    supabase_key = os.getenv("SUPABASE_SERVICE_KEY")
    # Location of the system prompt text file read by the agent.
    system_prompt_path = os.getenv("SYSTEM_PROMPT_PATH", "system_prompt.txt")
    # Add other settings
11
+
12
+ settings = Settings()
llm_provider.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_google_genai import ChatGoogleGenerativeAI
2
+ from langchain_groq import ChatGroq
3
+ from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
4
+ from config import settings
5
+
6
def get_llm(provider: str):
    """Create the chat model for the given provider name.

    Args:
        provider: One of ``"google"``, ``"groq"`` or ``"huggingface"``.

    Returns:
        A chat model instance for the selected backend, with temperature 0.

    Raises:
        ValueError: If ``provider`` is not one of the supported names.
    """
    if provider == "google":
        return ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
    if provider == "groq":
        return ChatGroq(model="qwen-qwq-32b", temperature=0)
    if provider == "huggingface":
        # HuggingFaceEndpoint has no ``url`` field; the endpoint address must
        # be passed as ``endpoint_url`` (or the model as ``repo_id``).
        endpoint = HuggingFaceEndpoint(
            endpoint_url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
            temperature=0,
        )
        return ChatHuggingFace(llm=endpoint)
    raise ValueError(f"Unknown provider: {provider}")
prompt/__init__.py ADDED
File without changes
prompt/system_prompt.py ADDED
File without changes
requirements.txt CHANGED
@@ -1,2 +1,18 @@
1
  gradio
2
- requests
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  gradio
2
+ requests
3
+ langchain
4
+ langchain-community
5
+ langchain-core
6
+ langchain-google-genai
7
+ langchain-huggingface
8
+ langchain-groq
9
+ langchain-tavily
10
+ langchain-chroma
11
+ langgraph
12
+ huggingface_hub
13
+ supabase
14
+ arxiv
15
+ pymupdf
16
+ wikipedia
17
+ pgvector
18
+ python-dotenv
retrievers/__init__.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ from .custom_retriever import retrieve
2
+
3
class CustomRetriever:
    """Object-style wrapper around the ``retrieve`` function imported above."""

    @staticmethod
    def retrieve(query):
        """Delegate to the imported ``retrieve`` and return its result."""
        result = retrieve(query)
        return result
7
+
8
+ custom_retriever = CustomRetriever()
retrievers/custom_retriever.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from langchain_huggingface import HuggingFaceEmbeddings
3
+ from langchain_community.vectorstores import SupabaseVectorStore
4
+ from supabase.client import create_client
5
+ from config import settings
6
+
7
# Embedding model handed to the vector store below.
embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-mpnet-base-v2",
)

# Supabase client built from the URL and service key held in settings.
supabase = create_client(settings.supabase_url, settings.supabase_key)

# Vector store over the "documents" table, queried through the
# "match_documents_langchain" function name.
vector_store = SupabaseVectorStore(
    embedding=embeddings,
    client=supabase,
    query_name="match_documents_langchain",
    table_name="documents",
)
15
+
16
def retrieve(query: str) -> str:
    """Return the content of the first similarity-search hit for ``query``.

    Args:
        query: Text to search the vector store with.

    Returns:
        ``page_content`` of the first result, or the fixed fallback string
        when the search returns nothing.
    """
    matches = vector_store.similarity_search(query)
    if not matches:
        return "No similar questions found."
    return matches[0].page_content
tools/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # tools/__init__.py
2
+ # Explicitly import the @tool-decorated tools to register and list them in ALL_TOOLS
3
+
4
+ from .math_tools import add, multiply
5
+ from .search_tools import wiki_search
6
+
7
+ ALL_TOOLS = [add, multiply, wiki_search]
tools/math_tools.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_core.tools import tool
2
+
3
@tool
def add(a: float, b: float) -> float:
    """Add two numbers."""
    # Sum of the two operands.
    total = a + b
    return total
7
+
8
@tool
def subtract(a: float, b: float) -> float:
    """Subtract b from a."""
    # Operand order matters: the result is a minus b.
    difference = a - b
    return difference
12
+
13
@tool
def multiply(a: float, b: float) -> float:
    """Multiply two numbers."""
    # Product of the two operands.
    product = a * b
    return product
17
+
18
@tool
def divide(a: float, b: float) -> float:
    """Divide a by b. Raises an error if b is zero."""
    # Happy path first; a zero divisor falls through to the explicit error.
    if b != 0:
        return a / b
    raise ValueError("Cannot divide by zero.")
24
+
25
@tool
def modulus(a: float, b: float) -> float:
    """Return the modulus (remainder) of a divided by b. Raises an error if b is zero."""
    # Reject a zero divisor explicitly, as divide() does, instead of leaking
    # Python's bare ZeroDivisionError.
    if b == 0:
        raise ValueError("Cannot compute the modulus with a zero divisor.")
    return a % b
29
+
30
@tool
def power(a: float, b: float) -> float:
    """Return a raised to the power of b."""
    # Two-argument pow() is the function form of the ** operator.
    return pow(a, b)
34
+
35
@tool
def sqrt(x: float) -> float:
    """Return the square root of x. Raises error if x is negative."""
    # Only strictly negative inputs are rejected; everything else is passed
    # to the exponentiation below.
    is_negative = x < 0
    if is_negative:
        raise ValueError("Cannot compute the square root of a negative number.")
    root = x ** 0.5
    return root
41
+
42
@tool
def abs_val(x: float) -> float:
    """Return the absolute value of x."""
    # Delegates to the built-in abs().
    magnitude = abs(x)
    return magnitude
tools/search_tools.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_core.tools import tool
2
+ from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
3
+ from langchain_community.tools.tavily_search import TavilySearchResults
4
+
5
@tool
def wiki_search(query: str) -> dict:
    """Search Wikipedia for a query and return maximum 2 results."""
    loader = WikipediaLoader(query=query, load_max_docs=2)

    # Render each loaded document as one tagged text chunk.
    chunks = []
    for doc in loader.load():
        source = doc.metadata["source"]
        page = doc.metadata.get("page", "")
        chunks.append(
            f'<Document source="{source}" page="{page}"/>\n{doc.page_content}\n</Document>'
        )

    # Chunks are separated by a horizontal-rule style divider.
    return {"wiki_results": "\n\n---\n\n".join(chunks)}
14
+
15
@tool
def web_search(query: str) -> dict:
    """Search Tavily for a query and return maximum 3 results."""
    # invoke() takes the tool input as its first positional argument;
    # calling it as invoke(query=query) raises TypeError.
    search_results = TavilySearchResults(max_results=3).invoke(query)

    # NOTE(review): the Tavily tool appears to return an error string instead
    # of a result list when the request fails; pass that through unchanged.
    if isinstance(search_results, str):
        return {"web_results": search_results}

    # Each hit is a plain dict (expected keys: "url" and "content"), not a
    # Document, so there is no .metadata / .page_content to read. The page
    # attribute is left empty to keep the same layout as the other tools.
    formatted = "\n\n---\n\n".join(
        f'<Document source="{hit.get("url", "")}" page=""/>\n{hit.get("content", "")}\n</Document>'
        for hit in search_results
    )
    return {"web_results": formatted}
24
+
25
@tool
def arxiv_search(query: str) -> dict:
    """Search Arxiv for a query and return maximum 3 results."""
    search_docs = ArxivLoader(query=query, load_max_docs=3).load()

    chunks = []
    for doc in search_docs:
        meta = doc.metadata
        # ArxivLoader's default metadata is documented as Published / Title /
        # Authors / Summary, with no "source" key, so indexing ["source"]
        # raises KeyError. Fall back to other identifying fields instead.
        source = meta.get("source") or meta.get("entry_id") or meta.get("Title", "")
        page = meta.get("page", "")
        # Only the first 1000 characters of each paper are included.
        chunks.append(
            f'<Document source="{source}" page="{page}"/>\n{doc.page_content[:1000]}\n</Document>'
        )

    return {"arxiv_results": "\n\n---\n\n".join(chunks)}
tools/vector_tools.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ from langchain_core.tools import tool
2
+ # Example vector tool using retriever
3
+ from retrievers import custom_retriever
4
+
5
@tool
def similar_question(query: str) -> str:
    """Retrieve a similar question from the vector store."""
    # Thin tool wrapper over the shared retriever instance.
    match = custom_retriever.retrieve(query)
    return match