mnab commited on
Commit
3e72c2b
·
verified ·
1 Parent(s): 29d27d3

Upload 2 files

Browse files
Files changed (2) hide show
  1. agent.py +240 -83
  2. app.py +2 -2
agent.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  from langchain.schema import HumanMessage, AIMessage, SystemMessage
2
  from langchain_openai import ChatOpenAI
3
  from langchain_core.messages import AnyMessage, SystemMessage
@@ -23,84 +25,208 @@ from langchain_huggingface import (
23
  HuggingFaceEmbeddings,
24
  )
25
 
 
 
 
26
 
27
  load_dotenv()
28
 
29
 
30
- # Initialize the DuckDuckGo search tool
31
- search_tool = DuckDuckGoSearchResults()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
 
34
  @tool
35
- def wiki_search(query: str) -> str:
36
- """Search Wikipedia for a query and return maximum 2 results.
 
37
 
38
  Args:
39
- query: The search query."""
40
- search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
41
- formatted_search_docs = "\n\n---\n\n".join(
42
- [
43
- f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
44
- for doc in search_docs
45
- ]
46
- )
47
- return {"wiki_results": formatted_search_docs}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
 
50
  @tool
51
- def web_search(query: str) -> str:
52
- """Search Tavily for a query and return maximum 3 results.
 
53
 
54
  Args:
55
- query: The search query."""
56
- search_docs = TavilySearchResults(max_results=3).invoke(query=query)
57
- formatted_search_docs = "\n\n---\n\n".join(
58
- [
59
- f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
60
- for doc in search_docs
61
- ]
62
- )
63
- return {"web_results": formatted_search_docs}
 
 
 
 
 
 
 
 
 
 
 
 
64
 
65
 
66
  @tool
67
- def arvix_search(query: str) -> str:
68
- """Search Arxiv for a query and return maximum 3 result.
 
69
 
70
  Args:
71
- query: The search query."""
72
- search_docs = ArxivLoader(query=query, load_max_docs=3).load()
73
- formatted_search_docs = "\n\n---\n\n".join(
74
- [
75
- f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
76
- for doc in search_docs
77
- ]
78
- )
79
- return {"arvix_results": formatted_search_docs}
80
-
81
-
82
- # Load LLM model
83
- llm = ChatOpenAI(
84
- model="gpt-4o",
85
- base_url="https://models.inference.ai.azure.com",
86
- api_key=os.environ["GITHUB_TOKEN"],
87
- temperature=0.2,
88
- max_tokens=4096,
89
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
  # llm = ChatHuggingFace(
91
  # llm=HuggingFaceEndpoint(
92
  # # repo_id="microsoft/Phi-3-mini-4k-instruct",
93
- # repo_id="Qwen/Qwen3-235B-A22B",
94
  # temperature=0,
95
  # # huggingfacehub_api_token=os.environ["HUGGINGFACEHUB_API_TOKEN"],
96
  # ),
97
  # verbose=True,
98
  # )
99
-
 
 
100
  tools = [
101
- arvix_search,
102
- wiki_search,
103
- # web_search,
 
 
104
  # search_tool,
105
  ]
106
  # Bind the tools to the LLM
@@ -108,30 +234,63 @@ model_with_tools = llm.bind_tools(tools)
108
  tool_node = ToolNode(tools)
109
 
110
 
111
- def build_agent_workflow():
 
112
 
113
- def should_continue(state: MessagesState):
114
- messages = state["messages"]
115
- last_message = messages[-1]
116
- if last_message.tool_calls:
117
- return "tools"
118
- return END
119
-
120
- def call_model(state: MessagesState):
121
- system_message = SystemMessage(
122
- content=f"""
123
- You are a helpful assistant tasked with answering questions using a set of tools.
124
- Now, I will ask you a question. Report your thoughts, and finish your answer with the following template:
125
- FINAL ANSWER: [YOUR FINAL ANSWER].
126
- YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
127
- Your answer should only start with "FINAL ANSWER: ", then follows with the answer. """
128
- )
129
 
130
- messages = [system_message] + state["messages"]
131
- print("Messages to LLM:", messages)
132
 
133
- response = model_with_tools.invoke(messages)
134
- return {"messages": [response]}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
 
136
  # Define the state graph
137
  workflow = StateGraph(MessagesState)
@@ -139,20 +298,18 @@ def build_agent_workflow():
139
  workflow.add_node("tools", tool_node)
140
 
141
  workflow.add_edge(START, "agent")
142
- workflow.add_conditional_edges("agent", should_continue, ["tools", END])
143
  workflow.add_edge("tools", "agent")
144
-
145
  app = workflow.compile()
146
-
147
  return app
148
 
149
 
150
- if __name__ == "__main__":
151
- question = "Who nominated the only Featured Article on English Wikipedia about a dinosaur that was promoted in November 2016?"
152
- # Build the graph
153
- graph = build_agent_workflow()
154
- # Run the graph
155
- messages = [HumanMessage(content=question)]
156
- messages = graph.invoke({"messages": messages})
157
- for m in messages["messages"]:
158
- m.pretty_print()
 
1
+ import tempfile
2
+ from urllib.parse import urlparse
3
  from langchain.schema import HumanMessage, AIMessage, SystemMessage
4
  from langchain_openai import ChatOpenAI
5
  from langchain_core.messages import AnyMessage, SystemMessage
 
25
  HuggingFaceEmbeddings,
26
  )
27
 
28
+ from langchain_google_genai import ChatGoogleGenerativeAI
29
+ import requests
30
+ from huggingface_hub import login
31
 
32
  load_dotenv()
33
 
34
 
35
@tool
def save_and_read_file(content: str, filename: Optional[str] = None) -> str:
    """
    Save content to a temporary file and return the path.
    Useful for processing files from the GAIA API.

    Args:
        content: The content to save to the file
        filename: Optional filename, will generate a random name if not provided

    Returns:
        Path to the saved file
    """
    if filename is None:
        # BUG FIX: the original used NamedTemporaryFile(delete=False) and never
        # closed the handle before reopening the path below — a file-handle
        # leak, and a hard failure on Windows where an open file cannot be
        # reopened. mkstemp hands us a path plus an fd we can close at once.
        fd, filepath = tempfile.mkstemp()
        os.close(fd)
    else:
        filepath = os.path.join(tempfile.gettempdir(), filename)

    # Write content to the file
    with open(filepath, "w") as f:
        f.write(content)

    return f"File saved to {filepath}. You can read this file to process its contents."
60
 
61
 
62
@tool
def download_file_from_url(url: str, filename: Optional[str] = None) -> str:
    """
    Download a file from a URL and save it to a temporary location.

    Args:
        url: The URL to download from
        filename: Optional filename, will generate one based on URL if not provided

    Returns:
        Path to the downloaded file
    """
    try:
        # Parse URL to get filename if not provided
        if not filename:
            filename = os.path.basename(urlparse(url).path)
            if not filename:
                # Generate a random name if we couldn't extract one
                import uuid

                filename = f"downloaded_{uuid.uuid4().hex[:8]}"

        filepath = os.path.join(tempfile.gettempdir(), filename)

        # BUG FIX: the original requests.get had no timeout, so one
        # unresponsive server could hang the whole agent indefinitely.
        response = requests.get(url, stream=True, timeout=30)
        response.raise_for_status()

        # Stream the body to disk in chunks to avoid loading it all in memory.
        with open(filepath, "wb") as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)

        return f"File downloaded to {filepath}. You can now process this file."
    except Exception as e:
        # Tool contract: surface failures as a message, never raise to the LLM.
        return f"Error downloading file: {str(e)}"
101
 
102
 
103
@tool
def extract_text_from_image(image_path: str) -> str:
    """
    Extract text from an image using pytesseract (if available).

    Args:
        image_path: Path to the image file

    Returns:
        Extracted text or error message
    """
    try:
        # OCR dependencies are optional; import lazily so the agent still
        # loads when they are absent.
        import pytesseract
        from PIL import Image

        ocr_text = pytesseract.image_to_string(Image.open(image_path))
        return f"Extracted text from image:\n\n{ocr_text}"
    except ImportError:
        return "Error: pytesseract is not installed. Please install it with 'pip install pytesseract' and ensure Tesseract OCR is installed on your system."
    except Exception as e:
        return f"Error extracting text from image: {str(e)}"
130
 
131
 
132
@tool
def analyze_csv_file(file_path: str, query: str) -> str:
    """
    Analyze a CSV file using pandas and answer a question about it.

    Args:
        file_path: Path to the CSV file
        query: Question about the data

    Returns:
        Analysis result or error message
    """
    try:
        import pandas as pd

        frame = pd.read_csv(file_path)

        # NOTE(review): `query` is not consulted — the tool always returns the
        # generic shape/column/describe summary below.
        report_parts = [
            f"CSV file loaded with {len(frame)} rows and {len(frame.columns)} columns.\n",
            f"Columns: {', '.join(frame.columns)}\n\n",
            "Summary statistics:\n",
            str(frame.describe()),
        ]
        return "".join(report_parts)
    except ImportError:
        return "Error: pandas is not installed. Please install it with 'pip install pandas'."
    except Exception as e:
        return f"Error analyzing CSV file: {str(e)}"
163
+
164
+
165
@tool
def analyze_excel_file(file_path: str, query: str) -> str:
    """
    Analyze an Excel file using pandas and answer a question about it.

    Args:
        file_path: Path to the Excel file
        query: Question about the data

    Returns:
        Analysis result or error message
    """
    try:
        import pandas as pd

        frame = pd.read_excel(file_path)

        # NOTE(review): `query` is not consulted — the tool always returns the
        # generic shape/column/describe summary below.
        report_parts = [
            f"Excel file loaded with {len(frame)} rows and {len(frame.columns)} columns.\n",
            f"Columns: {', '.join(frame.columns)}\n\n",
            "Summary statistics:\n",
            str(frame.describe()),
        ]
        return "".join(report_parts)
    except ImportError:
        return "Error: pandas and openpyxl are not installed. Please install them with 'pip install pandas openpyxl'."
    except Exception as e:
        return f"Error analyzing Excel file: {str(e)}"
198
+
199
+
200
# Initialize the DuckDuckGo search tool (currently not wired into `tools`).
search_tool = DuckDuckGoSearchResults()

# Active LLM backend: Google Gemini. Earlier ChatOpenAI (GitHub Models) and
# ChatHuggingFace configurations were removed from the live path.
llm = ChatGoogleGenerativeAI(
    model="gemini-2.0-flash-exp", google_api_key=os.environ["GOOGLE_API_KEY"]
)

# Tools exposed to the model; each is a @tool-decorated function above.
tools = [
    analyze_csv_file,
    analyze_excel_file,
    extract_text_from_image,
    download_file_from_url,
    save_and_read_file,
    # search_tool,
]
# Bind the tools to the LLM
model_with_tools = llm.bind_tools(tools)
tool_node = ToolNode(tools)
235
 
236
 
237
class AgentState(TypedDict):
    """State of the agent.

    Carried through the LangGraph workflow: the running conversation plus an
    optional path to a file associated with the current question.
    """

    # Path to an input file for the question, or None when there is none.
    input_file: Optional[str]
    # Conversation history; the add_messages reducer appends on each update.
    messages: Annotated[list[AnyMessage], add_messages]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
242
 
 
 
243
 
244
def build_agent_workflow():
    """Build and compile the agent workflow graph.

    Returns:
        The compiled LangGraph app. Invoke it with
        ``{"messages": [...], "input_file": path_or_None}``.
    """

    def call_model(state: AgentState):
        """Build a task-specific system prompt from the state and call the LLM."""
        print("State:", state["messages"])
        question = state["messages"][-1].content
        context = f"""
        You are a helpful assistant tasked with answering questions using a set of tools.
        """
        # If the question ships with a file, inline its content into the prompt.
        if state.get("input_file"):
            try:
                with open(state.get("input_file"), "r") as f:
                    file_content = f.read()
                print("File content:", file_content)

                # Determine file type from extension
                file_ext = os.path.splitext(state.get("input_file"))[1].lower()
                context = f"""
                Question: {question}
                This question has an associated file. Here is the file content:
                ```{file_ext}
                {file_content}
                ```
                Analyze the file content above to answer the question."""
            except Exception as file_e:
                # BUG FIX: original referenced state["message"] (missing "s"),
                # which raised KeyError instead of producing this fallback.
                context = f""" Question: {question}
                This question has an associated file at path: {state.get("input_file")}
                However, there was an error reading the file: {file_e}
                You can still try to answer the question based on the information provided.
                """

        if question.startswith(".") or ".rewsna eht sa" in question:
            # BUG FIX: original used state['message'][::-1] (KeyError); the
            # reversed text we want is simply the question itself, reversed.
            context = f"""
            This question appears to be in reversed text. Here's the reversed version:
            {question[::-1]}
            Now answer the question above. Remember to format your answer exactly as requested.
            """
        system_prompt = SystemMessage(
            f"""{context}
            When answering, provide ONLY the precise answer requested.
            Do not include explanations, steps, reasoning, or additional text.
            Be direct and specific. GAIA benchmark requires exact matching answers.
            For example, if asked "What is the capital of France?", respond simply with "Paris".
            """
        )
        return {
            "messages": [model_with_tools.invoke([system_prompt] + state["messages"])],
        }

    # Define the state graph.
    # BUG FIX: original used StateGraph(MessagesState), whose schema has no
    # "input_file" channel — the value passed by app.py was silently dropped
    # and call_model's state.get("input_file") branch could never fire.
    workflow = StateGraph(AgentState)
    workflow.add_node("agent", call_model)
    workflow.add_node("tools", tool_node)

    workflow.add_edge(START, "agent")
    # Route to the tool node when the model emitted tool calls, else END.
    workflow.add_conditional_edges("agent", tools_condition)
    workflow.add_edge("tools", "agent")
    app = workflow.compile()
    return app
305
 
306
 
307
+ # if __name__ == "__main__":
308
+ # question = "Who nominated the only Featured Article on English Wikipedia about a dinosaur that was promoted in November 2016?"
309
+ # # Build the graph
310
+ # graph = build_agent_workflow()
311
+ # # Run the graph
312
+ # messages = [HumanMessage(content=question)]
313
+ # messages = graph.invoke({"messages": messages, "input_file": None})
314
+ # for m in messages["messages"]:
315
+ # m.pretty_print()
app.py CHANGED
@@ -21,9 +21,9 @@ class BasicAgent:
21
  def __call__(self, question: str) -> str:
22
  print(f"Agent received question (first 50 chars): {question[:50]}...")
23
  messages = [HumanMessage(content=question)]
24
- messages = self.workflow.invoke({"messages": messages})
25
  answer = messages["messages"][-1].content
26
- return answer[14:]
27
  # fixed_answer = "This is a default answer."
28
  # print(f"Agent returning fixed answer: {fixed_answer}")
29
  # return fixed_answer
 
21
  def __call__(self, question: str) -> str:
22
  print(f"Agent received question (first 50 chars): {question[:50]}...")
23
  messages = [HumanMessage(content=question)]
24
+ messages = self.workflow.invoke({"messages": messages, "input_file": None})
25
  answer = messages["messages"][-1].content
26
+ return answer
27
  # fixed_answer = "This is a default answer."
28
  # print(f"Agent returning fixed answer: {fixed_answer}")
29
  # return fixed_answer