KPatelis committed
Commit b3f9415 · 1 Parent(s): 81917a3

Agent implementation
.gitignore ADDED
@@ -0,0 +1 @@
+ .env
.python-version ADDED
@@ -0,0 +1 @@
+ 3.12
__pycache__/tools.cpython-312.pyc ADDED
Binary file (3.63 kB)
 
__pycache__/utils.cpython-312.pyc ADDED
Binary file (825 Bytes)
 
agent.py ADDED
@@ -0,0 +1,79 @@
+ import os
+ from dotenv import load_dotenv
+ from langgraph.graph import START, StateGraph, MessagesState
+ from langgraph.prebuilt import ToolNode, tools_condition
+ from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
+ from langchain_community.vectorstores import SupabaseVectorStore
+ from langchain_core.messages import HumanMessage
+ from langchain.tools.retriever import create_retriever_tool
+ from supabase.client import Client, create_client
+ from utils import load_prompt
+ from tools import calculator, duck_web_search, wiki_search, arxiv_search
+ 
+ load_dotenv()
+ 
+ # Create retriever over the Supabase vector store
+ embeddings = HuggingFaceEmbeddings(model_name="Alibaba-NLP/gte-modernbert-base")  # dim=768
+ 
+ supabase: Client = create_client(os.getenv("SUPABASE_URL"), os.getenv("SUPABASE_SERVICE_KEY"))
+ vector_store = SupabaseVectorStore(
+     client=supabase,
+     embedding=embeddings,
+     table_name="gaia_documents",
+     query_name="match_documents_langchain",
+ )
+ 
+ # Note: this tool is not bound to the agent below; retriever_node
+ # queries the vector store directly instead.
+ retriever = create_retriever_tool(
+     retriever=vector_store.as_retriever(),
+     name="modernbert_retriever",  # tool names should be valid identifiers (no spaces)
+     description="A retriever of similar questions from a vector store.",
+ )
+ 
+ tools = [calculator, duck_web_search, wiki_search, arxiv_search]
+ 
+ model_id = "Qwen/Qwen3-32B"
+ 
+ llm = HuggingFaceEndpoint(
+     repo_id=model_id,
+     temperature=0,
+     repetition_penalty=1.03,
+     provider="auto",
+     huggingfacehub_api_token=os.getenv("HF_INFERENCE_KEY"),
+ )
+ 
+ agent = ChatHuggingFace(llm=llm)
+ agent_with_tools = agent.bind_tools(tools)
+ 
+ def retriever_node(state: MessagesState):
+     """RAG node: surface the most similar stored question/answer pair."""
+     similar = vector_store.similarity_search(state["messages"][0].content)
+     if not similar:  # nothing retrieved; pass the state through unchanged
+         return {"messages": []}
+     response = [HumanMessage(content=f"Here I provide a similar question and answer for reference:\n\n{similar[0].page_content}")]
+     return {"messages": response}
+ 
+ def processor_node(state: MessagesState):
+     """Agent node that answers questions, calling tools as needed."""
+     system_prompt = load_prompt("prompt.yaml")
+     messages = state.get("messages", [])
+     response = [agent_with_tools.invoke([system_prompt] + messages)]
+     return {"messages": response}
+ 
+ def agent_graph():
+     builder = StateGraph(MessagesState)
+ 
+     ## Add nodes
+     builder.add_node("retriever_node", retriever_node)
+     builder.add_node("processor_node", processor_node)
+     builder.add_node("tools", ToolNode(tools))
+ 
+     ## Add edges
+     builder.add_edge(START, "retriever_node")
+     builder.add_edge("retriever_node", "processor_node")
+     builder.add_conditional_edges("processor_node", tools_condition)
+     builder.add_edge("tools", "processor_node")
+ 
+     # Compile the graph; the compiled graph (not the builder) exposes .invoke()
+     return builder.compile()
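
A minimal smoke test for this graph might look like the following (a hypothetical sketch, not part of the commit; it assumes the .env variables used above are set and that agent_graph() returns the compiled graph, as in the version shown here):

    # smoke_test.py -- hypothetical usage sketch
    from langchain_core.messages import HumanMessage
    from agent import agent_graph

    graph = agent_graph()  # compiled StateGraph: retriever -> processor <-> tools
    result = graph.invoke({"messages": [HumanMessage(content="What is 2 + 2?")]})
    print(result["messages"][-1].content)  # per prompt.yaml, should end with "FINAL ANSWER: 4"
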
app.py CHANGED
@@ -3,6 +3,8 @@ import gradio as gr
import requests
import inspect
import pandas as pd
+ from agent import agent_graph
+ from langchain_core.messages import HumanMessage

# (Keep Constants as is)
# --- Constants ---
@@ -12,10 +14,13 @@ DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
# ----- THIS IS WERE YOU CAN BUILD WHAT YOU WANT ------
class BasicAgent:
    def __init__(self):
+         self.agent = agent_graph()
        print("BasicAgent initialized.")
    def __call__(self, question: str) -> str:
-         print(f"Agent received question (first 50 chars): {question[:50]}...")
-         fixed_answer = "This is a default answer."
+         messages = [HumanMessage(content=question)]
+         response = self.agent.invoke({"messages": messages})
+         # Strip the "FINAL ANSWER: " prefix from the model's reply
+         fixed_answer = response["messages"][-1].content.split("FINAL ANSWER:")[-1].strip()
        print(f"Agent returning fixed answer: {fixed_answer}")
        return fixed_answer
@@ -40,7 +45,7 @@ def run_and_submit_all( profile: gr.OAuthProfile | None):

    # 1. Instantiate Agent ( modify this part to create your agent)
    try:
-         agent = BasicAgent()
+         agent = agent_graph()
    except Exception as e:
        print(f"Error instantiating agent: {e}")
        return f"Error initializing agent: {e}", None
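
For a quick local check of the updated BasicAgent, something like this could work (a hypothetical sketch, not part of the commit; it assumes a populated .env and that app.py keeps its Gradio launch behind an if __name__ == "__main__" guard, as in the course template):

    # local_check.py -- hypothetical usage sketch
    from app import BasicAgent

    agent = BasicAgent()
    print(agent("What is the capital of France?"))  # expect a bare answer string such as "Paris"
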
create_vector_database.ipynb ADDED
@@ -0,0 +1,109 @@
+ {
+  "cells": [
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "a9f7a25f",
+    "metadata": {},
+    "outputs": [
+     {
+      "name": "stderr",
+      "output_type": "stream",
+      "text": [
+       "/home/kpatelis/projects/Agents_Course_Assignment/.venv/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+       "  from .autonotebook import tqdm as notebook_tqdm\n"
+      ]
+     }
+    ],
+    "source": [
+     "import os\n",
+     "import json\n",
+     "from dotenv import load_dotenv\n",
+     "from supabase.client import Client, create_client\n",
+     "from langchain_huggingface import HuggingFaceEmbeddings\n",
+     "from langchain.schema import Document\n",
+     "\n",
+     "load_dotenv()"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "2c948d46",
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "supabase: Client = create_client(\n",
+     "    os.environ.get(\"SUPABASE_URL\"),\n",
+     "    os.environ.get(\"SUPABASE_SERVICE_KEY\"))\n",
+     "\n",
+     "embeddings = HuggingFaceEmbeddings(model_name=\"Alibaba-NLP/gte-modernbert-base\")"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 15,
+    "id": "f2c5492b",
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "with open('metadata.jsonl', 'r') as jsonl_file:\n",
+     "    json_list = list(jsonl_file)\n",
+     "\n",
+     "documents = []\n",
+     "for json_str in json_list:\n",
+     "    json_data = json.loads(json_str)\n",
+     "    content = f\"Question : {json_data['Question']}\\n\\nFinal answer : {json_data['Final answer']}\"\n",
+     "    embedding = embeddings.embed_query(content)\n",
+     "    document = {\n",
+     "        \"content\": content,\n",
+     "        \"metadata\": {\n",
+     "            \"source\": json_data['task_id']\n",
+     "        },\n",
+     "        \"embedding\": embedding,\n",
+     "    }\n",
+     "    documents.append(document)"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "26ddbafd",
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "# pgvector must be enabled in Supabase to use it as a vector database\n",
+     "# The table must be created in Supabase beforehand, with the right column types\n",
+     "try:\n",
+     "    response = (\n",
+     "        supabase.table(\"gaia_documents\")\n",
+     "        .insert(documents)\n",
+     "        .execute()\n",
+     "    )\n",
+     "except Exception as exception:\n",
+     "    print(\"Error inserting data into Supabase:\", exception)"
+    ]
+   }
+  ],
+  "metadata": {
+   "kernelspec": {
+    "display_name": ".venv",
+    "language": "python",
+    "name": "python3"
+   },
+   "language_info": {
+    "codemirror_mode": {
+     "name": "ipython",
+     "version": 3
+    },
+    "file_extension": ".py",
+    "mimetype": "text/x-python",
+    "name": "python",
+    "nbconvert_exporter": "python",
+    "pygments_lexer": "ipython3",
+    "version": "3.12.3"
+   }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 5
+ }
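
Once the rows are inserted, a quick retrieval sanity check could look like this (a hypothetical sketch, not part of the commit; it reuses the same .env values, table, and match_documents_langchain query that agent.py expects):

    # sanity_check.py -- hypothetical usage sketch
    import os
    from dotenv import load_dotenv
    from supabase.client import create_client
    from langchain_huggingface import HuggingFaceEmbeddings
    from langchain_community.vectorstores import SupabaseVectorStore

    load_dotenv()
    supabase = create_client(os.environ["SUPABASE_URL"], os.environ["SUPABASE_SERVICE_KEY"])
    embeddings = HuggingFaceEmbeddings(model_name="Alibaba-NLP/gte-modernbert-base")
    store = SupabaseVectorStore(
        client=supabase,
        embedding=embeddings,
        table_name="gaia_documents",
        query_name="match_documents_langchain",
    )
    hits = store.similarity_search("example question", k=1)
    print(hits[0].page_content)  # should print a stored "Question : ... Final answer : ..." pair
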
hello.py ADDED
@@ -0,0 +1,6 @@
+ def main():
+     print("Hello from agents-course-assignment!")
+ 
+ 
+ if __name__ == "__main__":
+     main()
metadata.jsonl ADDED
The diff for this file is too large to render.
prompt.yaml ADDED
@@ -0,0 +1,2 @@
+ title: "Agent"
+ prompt: "You are a helpful AI assistant, equipped with a set of tools. Your task is to answer questions provided by the user. You should always make use of tools to answer the questions. For any given question, think and formulate a response. You can describe your thought process and use tool calls to assist you in answering the question. The final answer should be either a number, a concise reply in as few words as possible, or a comma-separated list of numbers and/or strings. When the answer is a number, do not use commas, units, or currency signs (such as $ or the percent sign) unless specified otherwise. When the answer is a string, do not use articles or abbreviations (e.g. for city names). When the answer is a list, apply the above rules to each item, depending on whether it is a string or a number. Your answer should start with FINAL ANSWER: <answer>."
pyproject.toml ADDED
@@ -0,0 +1,7 @@
+ [project]
+ name = "agents-course-assignment"
+ version = "0.1.0"
+ description = "Add your description here"
+ readme = "README.md"
+ requires-python = ">=3.12"
+ dependencies = []
tools.py ADDED
@@ -0,0 +1,85 @@
+ import os
+ 
+ from langchain_community.tools import DuckDuckGoSearchRun
+ from langchain_community.tools.tavily_search import TavilySearchResults
+ from langchain_community.document_loaders import WikipediaLoader
+ from langchain_community.document_loaders import ArxivLoader
+ from langchain_core.tools import tool
+ 
+ @tool
+ def calculator(a: float, b: float, type: str) -> float:
+     """Performs mathematical calculations: addition, subtraction, multiplication, division, modulus.
+     Args:
+         a (float): first float number
+         b (float): second float number
+         type (str): the type of calculation to perform; one of addition, subtraction, multiplication, division, modulus
+     """
+ 
+     if type == "addition":
+         return a + b
+     elif type == "subtraction":
+         return a - b
+     elif type == "multiplication":
+         return a * b
+     elif type == "division":
+         if b == 0:
+             raise ValueError("Cannot divide by zero.")
+         return a / b
+     elif type == "modulus":
+         return a % b
+     else:
+         raise ValueError(f"{type} is not an option for type; choose one of addition, subtraction, multiplication, division, modulus")
+ 
+ @tool
+ def duck_web_search(query: str) -> dict:
+     """Use DuckDuckGo to search the web.
+ 
+     Args:
+         query: The search query.
+     """
+     search = DuckDuckGoSearchRun().invoke(query)
+ 
+     return {"duckduckgo_web_search": search}
+ 
+ @tool
+ def wiki_search(query: str) -> dict:
+     """Search Wikipedia for a query and return a maximum of 3 results.
+ 
+     Args:
+         query: The search query."""
+     documents = WikipediaLoader(query=query, load_max_docs=3).load()
+     processed_documents = "\n\n---\n\n".join(
+         [
+             f"Document title: {document.metadata.get('title', '')}. Summary: {document.metadata.get('summary', '')}. Document details: {document.page_content}"
+             for document in documents
+         ])
+     return {"wiki_results": processed_documents}
+ 
+ @tool
+ def arxiv_search(query: str) -> dict:
+     """Search Arxiv for a query and return a maximum of 3 results.
+ 
+     Args:
+         query: The search query."""
+     documents = ArxivLoader(query=query, load_max_docs=3).load()  # ArxivLoader capitalizes its metadata keys
+     processed_documents = "\n\n---\n\n".join(
+         [
+             f"Document title: {document.metadata.get('Title', '')}. Summary: {document.metadata.get('Summary', '')}. Document details: {document.page_content}"
+             for document in documents
+         ])
+     return {"arxiv_results": processed_documents}
+ 
+ @tool
+ def tavily_web_search(query: str) -> dict:
+     """Search the web using Tavily for a query and return a maximum of 3 results.
+ 
+     Args:
+         query: The search query."""
+     search_engine = TavilySearchResults(max_results=3)
+     search_documents = search_engine.invoke(input=query)
+     web_results = "\n\n---\n\n".join(
+         [
+             f"Document title: {document['title']}. Contents: {document['content']}. Relevance Score: {document['score']}"
+             for document in search_documents
+         ])
+     return {"web_results": web_results}
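
Each @tool-decorated function above is a LangChain structured tool, so it can be exercised on its own before being wired into the graph. A quick check might look like this (a hypothetical sketch, not part of the commit; tool inputs are passed as a dict matching the function signature):

    # tool_check.py -- hypothetical usage sketch
    from tools import calculator, wiki_search

    print(calculator.invoke({"a": 7.0, "b": 3.0, "type": "modulus"}))  # -> 1.0
    print(wiki_search.invoke({"query": "Alan Turing"})["wiki_results"][:200])
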
utils.py ADDED
@@ -0,0 +1,11 @@
+ import yaml
+ from langchain_core.messages import SystemMessage
+ 
+ def load_prompt(prompt_location):
+     """Load the system prompt from a YAML file and wrap it in a SystemMessage."""
+     with open(prompt_location) as f:
+         try:
+             prompt = yaml.safe_load(f)["prompt"]
+             return SystemMessage(content=prompt)
+         except yaml.YAMLError as exc:
+             print(exc)  # on a parse error this logs and returns None