vtony committed on
Commit
7f6f6b7
·
verified ·
1 Parent(s): ab60a9f

Upload 2 files

Browse files
Files changed (2) hide show
  1. System_Prompt.txt +18 -0
  2. agent.py +241 -150
System_Prompt.txt ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ You are GAIA (General AI Assistant). Follow these rules:
2
+
3
+ 1. Always use available tools when needed
4
+ 2. Format responses as:
5
+ REASONING: <step-by-step logic>
6
+ FINAL ANSWER: <concise result>
7
+
8
+ Available Tools:
9
+ - Math tools: multiply, add, subtract, divide, modulus
10
+ - Search tools: wiki_search (Wikipedia), arxiv_search (arXiv), web_search (web)
11
+
12
+ Example for Mercedes Sosa question:
13
+ USER: How many studio albums were published by Mercedes Sosa between 2000 and 2009?
14
+ REASONING:
15
+ 1. Use wiki_search to find Mercedes Sosa discography
16
+ 2. Filter results to studio albums between 2000-2009
17
+ 3. Count the matching albums
18
+ FINAL ANSWER: 3 studio albums
agent.py CHANGED
@@ -1,150 +1,241 @@
1
- import os
2
- from dotenv import load_dotenv
3
- from langgraph.graph import START, StateGraph, MessagesState
4
- from langgraph.prebuilt import tools_condition
5
- from langgraph.prebuilt import ToolNode
6
- from langchain_google_genai import ChatGoogleGenerativeAI
7
- from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
8
- from langchain_community.tools.tavily_search import TavilySearchResults
9
- from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
10
- from langchain_community.vectorstores import SupabaseVectorStore
11
- from langchain_core.messages import SystemMessage, AIMessage, HumanMessage
12
- from langchain_core.tools import tool
13
- from langchain.tools.retriever import create_retriever_tool
14
- from supabase.client import Client, create_client
15
-
16
- # Load environment variables
17
- load_dotenv()
18
-
19
- # --- Math Tools ---
20
- @tool
21
- def multiply(a: int, b: int) -> int:
22
- """Multiply two integers."""
23
- return a * b
24
-
25
- @tool
26
- def add(a: int, b: int) -> int:
27
- """Add two integers."""
28
- return a + b
29
-
30
- @tool
31
- def subtract(a: int, b: int) -> int:
32
- """Subtract b from a."""
33
- return a - b
34
-
35
- @tool
36
- def divide(a: int, b: int) -> float:
37
- """Divide a by b, error on zero."""
38
- if b == 0:
39
- raise ValueError("Cannot divide by zero.")
40
- return a / b
41
-
42
- @tool
43
- def modulus(a: int, b: int) -> int:
44
- """Compute a mod b."""
45
- return a % b
46
-
47
- # --- Browser Tools ---
48
- @tool
49
- def wiki_search(query: str) -> dict:
50
- """Search Wikipedia and return up to 2 documents."""
51
- docs = WikipediaLoader(query=query, load_max_docs=2).load()
52
- results = [f"<Document source=\"{d.metadata['source']}\" page=\"{d.metadata.get('page','')}\"/>\n{d.page_content}" for d in docs]
53
- return {"wiki_results": "\n---\n".join(results)}
54
-
55
- @tool
56
- def web_search(query: str) -> dict:
57
- """Search Tavily and return up to 3 results."""
58
- docs = TavilySearchResults(max_results=3).invoke(query=query)
59
- results = [f"<Document source=\"{d.metadata['source']}\" page=\"{d.metadata.get('page','')}\"/>\n{d.page_content}" for d in docs]
60
- return {"web_results": "\n---\n".join(results)}
61
-
62
- @tool
63
- def arxiv_search(query: str) -> dict:
64
- """Search Arxiv and return up to 3 docs."""
65
- docs = ArxivLoader(query=query, load_max_docs=3).load()
66
- results = [f"<Document source=\"{d.metadata['source']}\" page=\"{d.metadata.get('page','')}\"/>\n{d.page_content[:1000]}" for d in docs]
67
- return {"arxiv_results": "\n---\n".join(results)}
68
-
69
- # --- Load system prompt ---
70
- with open("system_prompt.txt", "r", encoding="utf-8") as f:
71
- system_prompt = f.read()
72
-
73
- # --- System message ---
74
- sys_msg = SystemMessage(content=system_prompt)
75
-
76
- # --- Retriever Tool ---
77
- embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
78
- supabase = create_client(os.getenv("SUPABASE_URL"), os.getenv("SUPABASE_SERVICE_KEY"))
79
-
80
- vector_store = SupabaseVectorStore(
81
- client=supabase,
82
- embedding=embeddings, table_name="documents",
83
- query_name="match_documents_langchain")
84
-
85
- retriever_tool = create_retriever_tool(
86
- retriever=vector_store.as_retriever(
87
-
88
- search_type="similarity",
89
- search_kwargs={"k": 5}
90
- ),
91
- name="Question Search",
92
- description="A tool to retrieve similar questions from the vector store."
93
- )
94
-
95
- tools = [
96
- multiply,
97
- add,
98
- subtract,
99
- divide,
100
- modulus,
101
- wiki_search,
102
- web_search,
103
- arxiv_search,
104
- ]
105
-
106
- # --- Graph Builder ---
107
- def build_graph(provider: str = "huggingface"):
108
- llm = ChatHuggingFace(
109
- llm=HuggingFaceEndpoint(
110
- repo_id = "Qwen/Qwen2.5-Coder-32B-Instruct"
111
- ),
112
- )
113
-
114
- # Bind tools to LLM
115
- llm_with_tools = llm.bind_tools(tools)
116
-
117
- # Define no def assistant(state: MessagesState):
118
  """Assistant node"""
119
-
120
- return {"messages [ [llm_with_tools.invoke(state["messages"])]}se]}
121
-
122
-
123
- # Retriever returns AIMessage def retriever(state: MessagesState):
124
- """Retriever node"""
125
- similar_question = vector_store.similarity_search(state["messages"][0].content)
126
- print('Similar questions:')
127
- print(similar_question)
128
- if len(similar_question) > 0:
129
- example_msg = HumanMessage(
130
- content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
131
- ntent}]}
132
- return {"messages": [sys_msg] + state["messages"] + [example_msg]}
133
- return {"messages": [sys_msg] + state["m
134
- # Add nodesessages"]}
135
-
136
- builder = StateGraph(MessagesState)
137
- builder.add_node("retriever", retriever)
138
- builder.add_node("assistant", assistant)
139
- builder.add_node("tools",
140
-
141
- # Add edgesToolNode(tools))
142
- builder.add_edge(START, "retriever")
143
- builder.add_edge("retriever", "assistant")
144
- builder.add_conditional_edges(
145
- "assistant",
146
- tools_condition,
147
- )
148
- builder.add_edge("tools", "assistant")ever")
149
-
150
- # Compile graph
151
- return builder.compile()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  """Assistant node"""
2
+ import os
3
+ import time
4
+ from dotenv import load_dotenv
5
+ from langgraph.graph import StateGraph, END
6
+ from langgraph.prebuilt import ToolNode, tools_condition
7
+ from langchain_google_genai import ChatGoogleGenerativeAI
8
+ from langchain_community.tools import DuckDuckGoSearchRun
9
+ from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
10
+ from langchain_core.messages import SystemMessage, AIMessage, HumanMessage
11
+ from langchain_core.tools import tool
12
+ from tenacity import retry, stop_after_attempt, wait_exponential
13
+
14
# Load environment variables from a local .env file (if present) into os.environ.
load_dotenv()

# os.getenv() already reads os.environ; the original chained an identical
# os.environ.get() lookup behind `or`, which could never return anything new.
google_api_key = os.getenv("GOOGLE_API_KEY")
if not google_api_key:
    # Fail fast at import time: the Gemini client below is useless without a key.
    raise ValueError("Missing GOOGLE_API_KEY environment variable")
+
20
# --- Math Tools ---
@tool
def multiply(a: int, b: int) -> int:
    """Return the product of the two integers *a* and *b*."""
    return a * b
+
26
@tool
def add(a: int, b: int) -> int:
    """Return the sum of the two integers *a* and *b*."""
    total = a + b
    return total
+
31
@tool
def subtract(a: int, b: int) -> int:
    """Return the difference *a* - *b*."""
    difference = a - b
    return difference
+
36
@tool
def divide(a: int, b: int) -> float:
    """Return the quotient *a* / *b*.

    Raises:
        ValueError: if *b* is zero.
    """
    if not b:
        raise ValueError("Cannot divide by zero.")
    return a / b
+
43
@tool
def modulus(a: int, b: int) -> int:
    """Return the remainder of *a* divided by *b* (Python `%` semantics)."""
    remainder = a % b
    return remainder
+
48
# --- Browser Tools ---
@tool
def wiki_search(query: str) -> str:
    """Search Wikipedia and return up to 3 relevant documents."""
    try:
        pages = WikipediaLoader(query=query, load_max_docs=3).load()
        if not pages:
            return "No Wikipedia results found."

        # One formatted entry per page; content is capped at 2000 chars to
        # keep the tool output small enough for the model context.
        entries = [
            f"Title: {page.metadata.get('title', 'Unknown Title')}\n"
            f"Content: {page.page_content[:2000]}"
            for page in pages
        ]
        return "\n\n---\n\n".join(entries)
    except Exception as e:
        # Surface the failure to the agent as text instead of raising.
        return f"Wikipedia search error: {str(e)}"
+
67
@tool
def arxiv_search(query: str) -> str:
    """Search Arxiv and return up to 3 relevant papers.

    Returns a "---"-separated list of Title/Authors/Content entries, or an
    error string on failure (errors are reported as text, not raised).
    """
    try:
        docs = ArxivLoader(query=query, load_max_docs=3).load()
        if not docs:
            return "No arXiv papers found."

        results = []
        for doc in docs:
            title = doc.metadata.get('Title', 'Unknown Title')
            # BUG FIX: ArxivLoader reports 'Authors' as a single comma-separated
            # string; the original `", ".join(...)` with a list default would
            # interleave ", " between every CHARACTER of that string.
            raw_authors = doc.metadata.get('Authors', '')
            if isinstance(raw_authors, (list, tuple)):
                authors = ", ".join(raw_authors)
            else:
                authors = str(raw_authors)
            content = doc.page_content[:2000]  # Limit content length
            results.append(f"Title: {title}\nAuthors: {authors}\nContent: {content}")

        return "\n\n---\n\n".join(results)
    except Exception as e:
        return f"arXiv search error: {str(e)}"
+
86
@tool
def web_search(query: str) -> str:
    """Search the web using DuckDuckGo and return top results."""
    try:
        raw = DuckDuckGoSearchRun().run(query)
        snippet = raw[:2000]  # Limit content length
        return f"Web search results for '{query}':\n{snippet}"
    except Exception as e:
        # Report failures as text so the agent can react instead of crashing.
        return f"Web search error: {str(e)}"
+
96
# --- Load system prompt ---
# BUG FIX: the repo uploads "System_Prompt.txt" while the code opened
# "system_prompt.txt"; on a case-sensitive filesystem (e.g. a Linux Space)
# only one of those names exists. Try both before giving up.
_PROMPT_CANDIDATES = ("system_prompt.txt", "System_Prompt.txt")
for _prompt_path in _PROMPT_CANDIDATES:
    if os.path.exists(_prompt_path):
        with open(_prompt_path, "r", encoding="utf-8") as f:
            system_prompt = f.read()
        break
else:
    raise FileNotFoundError(
        f"System prompt file not found; tried {_PROMPT_CANDIDATES}"
    )

# --- System message ---
# Prepended to every conversation sent to the model.
sys_msg = SystemMessage(content=system_prompt)
102
+
103
# --- Tool Setup ---
# Registry of every tool the agent may call: math helpers plus the three
# search tools. Order is not significant; lookup is by tool name.
tools = [
    multiply, add, subtract, divide, modulus,
    wiki_search, arxiv_search, web_search,
]
114
+
115
# --- Graph Builder ---
def build_graph():
    """Build and compile the LangGraph agent workflow.

    The compiled graph's state is a dict with a single "messages" key (a list
    of chat messages). Nodes: "agent" (Gemini with bound tools) and "tools"
    (executes the tool calls the agent requested).

    Returns:
        The compiled LangGraph application.
    """
    from typing import TypedDict  # stdlib; local import keeps module imports unchanged
    import json

    # Initialize model (Gemini 2.0 Flash)
    llm = ChatGoogleGenerativeAI(
        model="gemini-2.0-flash-exp",
        temperature=0.3,
        google_api_key=google_api_key,
        max_retries=3,
    )

    # Bind tools to LLM so it can emit structured tool calls
    llm_with_tools = llm.bind_tools(tools)

    # BUG FIX: StateGraph needs an annotated state schema (e.g. a TypedDict)
    # so it can derive its channels; the original plain class with only an
    # __init__ exposes no annotations, and nodes must return dict updates,
    # not fresh instances of that class.
    class AgentState(TypedDict):
        messages: list

    def agent_node(state: AgentState):
        """Main agent node: invoke the LLM with rate limiting and retries."""
        try:
            # Crude rate limiting between requests (free-tier quota friendly)
            time.sleep(1)

            # Retry with exponential backoff for transient API/quota errors
            @retry(stop=stop_after_attempt(3),
                   wait=wait_exponential(multiplier=1, min=4, max=10))
            def invoke_llm_with_retry():
                return llm_with_tools.invoke(state["messages"])

            response = invoke_llm_with_retry()
            return {"messages": state["messages"] + [response]}

        except Exception as e:
            # Classify the most common API failures by status code in the text
            error_type = "UNKNOWN"
            if "429" in str(e):
                error_type = "QUOTA_EXCEEDED"
            elif "400" in str(e):
                error_type = "INVALID_REQUEST"

            error_msg = f"AGENT ERROR ({error_type}): {str(e)[:200]}"
            return {"messages": state["messages"] + [AIMessage(content=error_msg)]}

    def tool_node(state: AgentState):
        """Execute every tool call requested by the last AI message."""
        last_message = state["messages"][-1]
        # CONSISTENCY FIX: should_continue routes on the normalized
        # `.tool_calls` attribute, but this node only read the raw
        # additional_kwargs payload. Prefer the normalized form and fall
        # back to the provider-specific one.
        tool_calls = (getattr(last_message, "tool_calls", None)
                      or last_message.additional_kwargs.get("tool_calls", []))

        tool_responses = []
        for tool_call in tool_calls:
            if "function" in tool_call:
                # OpenAI-style raw payload: {"function": {"name", "arguments"}}
                tool_name = tool_call["function"]["name"]
                tool_args = tool_call["function"].get("arguments", {})
            else:
                # Normalized LangChain ToolCall dict: {"name", "args", "id"}
                tool_name = tool_call.get("name")
                tool_args = tool_call.get("args", {})

            # Find the tool by name in the module-level registry
            tool_func = next((t for t in tools if t.name == tool_name), None)
            if not tool_func:
                tool_responses.append(f"Tool {tool_name} not found")
                continue

            try:
                # Arguments may arrive JSON-encoded as a string
                if isinstance(tool_args, str):
                    tool_args = json.loads(tool_args)
                result = tool_func.invoke(tool_args)
                tool_responses.append(f"Tool {tool_name} result: {result}")
            except Exception as e:
                tool_responses.append(f"Tool {tool_name} error: {str(e)}")

        # BUG FIX: the original line was missing a closing parenthesis on the
        # AIMessage(...) call, which made the whole module a SyntaxError.
        return {"messages": state["messages"] + [AIMessage(content="\n".join(tool_responses))]}

    def should_continue(state: AgentState):
        """Route after the agent node: run tools, loop the agent, or finish."""
        last_message = state["messages"][-1]
        # content may be a list of parts (Gemini); coerce for substring checks
        content = str(last_message.content)

        # If there was an error, end
        if "AGENT ERROR" in content:
            return "end"

        # Check for tool calls
        if getattr(last_message, "tool_calls", None):
            return "tools"

        # Check for final answer
        if "FINAL ANSWER" in content:
            return "end"

        # Otherwise, loop back to the agent
        return "agent"

    # Build the graph
    workflow = StateGraph(AgentState)
    workflow.add_node("agent", agent_node)
    workflow.add_node("tools", tool_node)
    workflow.set_entry_point("agent")

    workflow.add_conditional_edges(
        "agent",
        should_continue,
        {
            "agent": "agent",
            "tools": "tools",
            "end": END,
        },
    )

    # Tools always hand control back to the agent (the original expressed this
    # as a conditional edge with a constant lambda; a plain edge is equivalent).
    workflow.add_edge("tools", "agent")

    return workflow.compile()
+
241
# Initialize the agent graph once at import time so consumers can simply
# `from agent import agent_graph`.
# NOTE(review): this constructs the Gemini client as an import side effect —
# confirm that is acceptable for all callers.
agent_graph = build_graph()