Agent implementation
- .gitignore +1 -0
- .python-version +1 -0
- __pycache__/tools.cpython-312.pyc +0 -0
- __pycache__/utils.cpython-312.pyc +0 -0
- agent.py +81 -0
- app.py +6 -3
- create_vector_database.ipynb +109 -0
- hello.py +6 -0
- metadata.jsonl +0 -0
- prompt.yaml +2 -0
- pyproject.toml +7 -0
- tools.py +85 -0
- utils.py +10 -0
.gitignore
ADDED
@@ -0,0 +1 @@
.env

.python-version
ADDED
@@ -0,0 +1 @@
3.12

__pycache__/tools.cpython-312.pyc
ADDED
Binary file (3.63 kB).

__pycache__/utils.cpython-312.pyc
ADDED
Binary file (825 Bytes).

agent.py
ADDED
@@ -0,0 +1,81 @@
import os
from dotenv import load_dotenv
from langgraph.graph import START, StateGraph, MessagesState
from langgraph.prebuilt import tools_condition
from langgraph.prebuilt import ToolNode
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
from langchain_community.vectorstores import SupabaseVectorStore
from langchain_core.messages import HumanMessage
from langchain.tools.retriever import create_retriever_tool
from supabase.client import Client, create_client
from utils import load_prompt
from tools import calculator, duck_web_search, wiki_search, arxiv_search

load_dotenv()

# Create retriever
embeddings = HuggingFaceEmbeddings(model_name="Alibaba-NLP/gte-modernbert-base")  # dim=768

supabase: Client = create_client(os.getenv("SUPABASE_URL"), os.getenv("SUPABASE_SERVICE_KEY"))
vector_store = SupabaseVectorStore(
    client=supabase,
    embedding=embeddings,
    table_name="gaia_documents",
    query_name="match_documents_langchain",
)

retriever = create_retriever_tool(
    retriever=vector_store.as_retriever(),
    name="modernbert_retriever",  # identifier-style name; most providers reject tool names containing spaces
    description="A retriever of similar questions from a vector store.",
)

tools = [calculator, duck_web_search, wiki_search, arxiv_search]

model_id = "Qwen/Qwen3-32B"

llm = HuggingFaceEndpoint(
    repo_id=model_id,
    temperature=0,
    repetition_penalty=1.03,
    provider="auto",
    huggingfacehub_api_token=os.getenv("HF_INFERENCE_KEY"),
)

agent = ChatHuggingFace(llm=llm)

agent_with_tools = agent.bind_tools(tools)

def retriever_node(state: MessagesState):
    """RAG node: fetch the most similar stored question/answer pair."""
    similar_question = vector_store.similarity_search(state["messages"][0].content)
    response = [HumanMessage(f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}")]
    return {"messages": response}

def processor_node(state: MessagesState):
    """Agent node that answers questions."""
    system_prompt = load_prompt("prompt.yaml")
    messages = state.get("messages", [])
    response = [agent_with_tools.invoke([system_prompt] + messages)]
    return {"messages": response}

def agent_graph():
    builder = StateGraph(MessagesState)

    ## Add nodes
    builder.add_node("retriever_node", retriever_node)
    builder.add_node("processor_node", processor_node)
    builder.add_node("tools", ToolNode(tools))

    ## Add edges
    builder.add_edge(START, "retriever_node")
    builder.add_edge("retriever_node", "processor_node")
    builder.add_conditional_edges("processor_node", tools_condition)
    builder.add_edge("tools", "processor_node")

    # Compile the graph into a runnable and return it
    return builder.compile()

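A minimal smoke test for the graph above; a sketch only, assuming SUPABASE_URL, SUPABASE_SERVICE_KEY and HF_INFERENCE_KEY are set in .env (the question text is illustrative):

    from langchain_core.messages import HumanMessage
    from agent import agent_graph

    graph = agent_graph()
    result = graph.invoke({"messages": [HumanMessage(content="What is 2 + 2?")]})
    # the last message should end with "FINAL ANSWER: 4"
    print(result["messages"][-1].content)
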
app.py
CHANGED
@@ -3,6 +3,8 @@ import gradio as gr
 import requests
 import inspect
 import pandas as pd
+from agent import agent_graph
+from langchain_core.messages import HumanMessage  # used in BasicAgent.__call__
 
 # (Keep Constants as is)
 # --- Constants ---
@@ -12,10 +14,12 @@ DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
 # ----- THIS IS WHERE YOU CAN BUILD WHAT YOU WANT ------
 class BasicAgent:
     def __init__(self):
+        self.agent = agent_graph()
         print("BasicAgent initialized.")
     def __call__(self, question: str) -> str:
-        print(f"Agent received question (first 50 chars): {question[:50]}...")
-        fixed_answer = "This is a default answer."
+        messages = [HumanMessage(content=question)]
+        response = self.agent.invoke({"messages": messages})
+        fixed_answer = response["messages"][-1].content.split("FINAL ANSWER:")[-1].strip()  # keep only the text after the marker
         print(f"Agent returning fixed answer: {fixed_answer}")
         return fixed_answer
 
@@ -40,7 +44,7 @@ def run_and_submit_all( profile: gr.OAuthProfile | None):
 
     # 1. Instantiate Agent ( modify this part to create your agent)
     try:
-        agent = BasicAgent()
+        agent = agent_graph()
     except Exception as e:
         print(f"Error instantiating agent: {e}")
         return f"Error initializing agent: {e}", None

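The BasicAgent wrapper can also be exercised on its own; a hedged sketch, with an illustrative question:

    agent = BasicAgent()
    answer = agent("What is the capital of France?")
    print(answer)  # only the text after the FINAL ANSWER: marker
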
create_vector_database.ipynb
ADDED
@@ -0,0 +1,109 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a9f7a25f",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/kpatelis/projects/Agents_Course_Assignment/.venv/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import json\n",
    "from dotenv import load_dotenv\n",
    "from supabase.client import Client, create_client\n",
    "from langchain_huggingface import HuggingFaceEmbeddings\n",
    "from langchain.schema import Document\n",
    "\n",
    "load_dotenv()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c948d46",
   "metadata": {},
   "outputs": [],
   "source": [
    "supabase: Client = create_client(\n",
    "    os.environ.get(\"SUPABASE_URL\"),\n",
    "    os.environ.get(\"SUPABASE_SERVICE_KEY\"))\n",
    "\n",
    "embeddings = HuggingFaceEmbeddings(model_name=\"Alibaba-NLP/gte-modernbert-base\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "f2c5492b",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('metadata.jsonl', 'r') as jsonl_file:\n",
    "    json_list = list(jsonl_file)\n",
    "\n",
    "documents = []\n",
    "for json_str in json_list:\n",
    "    json_data = json.loads(json_str)\n",
    "    content = f\"Question : {json_data['Question']}\\n\\nFinal answer : {json_data['Final answer']}\"\n",
    "    embedding = embeddings.embed_query(content)\n",
    "    document = {\n",
    "        \"content\": content,\n",
    "        \"metadata\": {\n",
    "            \"source\": json_data['task_id']\n",
    "        },\n",
    "        \"embedding\": embedding,\n",
    "    }\n",
    "    documents.append(document)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "26ddbafd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# pgvector must be enabled to turn the database into a vector store\n",
    "# The table must be created beforehand in Supabase, with matching column types\n",
    "try:\n",
    "    response = (\n",
    "        supabase.table(\"gaia_documents\")\n",
    "        .insert(documents)\n",
    "        .execute()\n",
    "    )\n",
    "except Exception as exception:\n",
    "    print(\"Error inserting data into Supabase:\", exception)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}

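To confirm the rows round-trip, the populated table can be queried back through the same LangChain wrapper that agent.py uses; a sketch assuming the gaia_documents table and the match_documents_langchain SQL function already exist in Supabase:

    from langchain_community.vectorstores import SupabaseVectorStore

    vector_store = SupabaseVectorStore(
        client=supabase,
        embedding=embeddings,
        table_name="gaia_documents",
        query_name="match_documents_langchain",
    )
    # nearest stored "Question : ... / Final answer : ..." document
    print(vector_store.similarity_search("example question", k=1)[0].page_content)
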
hello.py
ADDED
@@ -0,0 +1,6 @@
def main():
    print("Hello from agents-course-assignment!")


if __name__ == "__main__":
    main()

metadata.jsonl
ADDED
The diff for this file is too large to render.
prompt.yaml
ADDED
@@ -0,0 +1,2 @@
title: "Agent"
prompt: "You are a helpful AI assistant, equipped with a set of tools. Your task is to answer questions provided by the user. You should always make use of tools to answer the questions. For any given question, think and formulate a response. You can describe your thought process and use tool calls to assist you in answering the question. The final answer should be either a number, a concise reply in as few words as possible, or a comma-separated list of numbers and/or strings. When the answer is a number, do not use commas, units, or currency signs (such as $ or the percent sign) to write it, unless specified otherwise. When the answer is a string, do not use articles or abbreviations (e.g. for city names). When the answer is a list, apply the above rules to each item, depending on whether it is a string or a number. Your answer should only start with FINAL ANSWER: <answer>."

pyproject.toml
ADDED
@@ -0,0 +1,7 @@
[project]
name = "agents-course-assignment"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.12"
dependencies = []

tools.py
ADDED
@@ -0,0 +1,85 @@
import os

from langchain_community.tools import DuckDuckGoSearchRun
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.document_loaders import WikipediaLoader
from langchain_community.document_loaders import ArxivLoader
from langchain_core.tools import tool

@tool
def calculator(a: float, b: float, type: str) -> float:
    """Performs mathematical calculations: addition, subtraction, multiplication, division, modulus.

    Args:
        a (float): first float number
        b (float): second float number
        type (str): the type of calculation to perform; one of addition, subtraction, multiplication, division, modulus
    """
    if type == "addition":
        return a + b
    elif type == "subtraction":
        return a - b
    elif type == "multiplication":
        return a * b
    elif type == "division":
        if b == 0:
            raise ValueError("Cannot divide by zero.")
        return a / b
    elif type == "modulus":
        return a % b
    else:
        raise TypeError(f"{type} is not an option for type, choose one of addition, subtraction, multiplication, division, modulus")

@tool
def duck_web_search(query: str) -> dict:
    """Use DuckDuckGo to search the web.

    Args:
        query: The search query.
    """
    search = DuckDuckGoSearchRun().invoke(query)

    return {"duckduckgo_web_search": search}

@tool
def wiki_search(query: str) -> dict:
    """Search Wikipedia for a query and return a maximum of 3 results.

    Args:
        query: The search query."""
    documents = WikipediaLoader(query=query, load_max_docs=3).load()
    processed_documents = "\n\n---\n\n".join(
        [
            f"Document title: {document.metadata.get('title', '')}. Summary: {document.metadata.get('summary', '')}. Document details: {document.page_content}"
            for document in documents
        ])
    return {"wiki_results": processed_documents}

@tool
def arxiv_search(query: str) -> dict:
    """Search Arxiv for a query and return a maximum of 3 results.

    Args:
        query: The search query."""
    documents = ArxivLoader(query=query, load_max_docs=3).load()
    processed_documents = "\n\n---\n\n".join(
        [
            # ArxivLoader capitalizes its metadata keys ("Title", "Summary")
            f"Document title: {document.metadata.get('Title', '')}. Summary: {document.metadata.get('Summary', '')}. Document details: {document.page_content}"
            for document in documents
        ])
    return {"arxiv_results": processed_documents}

@tool
def tavily_web_search(query: str) -> dict:
    """Search the web using Tavily for a query and return a maximum of 3 results.

    Args:
        query: The search query."""
    search_engine = TavilySearchResults(max_results=3)
    search_documents = search_engine.invoke(input=query)
    web_results = "\n\n---\n\n".join(
        [
            f"Document title: {document['title']}. Contents: {document['content']}. Relevance Score: {document['score']}"
            for document in search_documents
        ])
    return {"web_results": web_results}

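Since these are standard LangChain @tool objects, each can be sanity-checked in isolation by invoking it with a dict of arguments; a quick sketch:

    print(calculator.invoke({"a": 10, "b": 4, "type": "subtraction"}))  # 6.0
    result = wiki_search.invoke({"query": "Mercedes Sosa"})
    print(result["wiki_results"][:200])  # start of the joined Wikipedia documents
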
utils.py
ADDED
@@ -0,0 +1,10 @@
import yaml
from langchain_core.messages import SystemMessage

def load_prompt(prompt_location):
    with open(prompt_location) as f:
        try:
            prompt = yaml.safe_load(f)["prompt"]
            return SystemMessage(content=prompt)
        except yaml.YAMLError as exc:
            print(exc)

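A quick check of the helper, assuming prompt.yaml sits in the working directory:

    message = load_prompt("prompt.yaml")
    print(type(message).__name__)  # SystemMessage
    print(message.content[:30])    # start of the YAML "prompt" field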