chezhian committed on
Commit
68c56f2
·
verified ·
1 Parent(s): 7428e94

Delete agent.py

Browse files
Files changed (1) hide show
  1. agent.py +0 -351
agent.py DELETED
@@ -1,351 +0,0 @@
1
- # --- Basic Agent Definition ---
2
- import asyncio
3
- import os
4
- import sys
5
- import logging
6
- import random
7
- import pandas as pd
8
- import requests
9
- import wikipedia as wiki
10
- from markdownify import markdownify as to_markdown
11
- from typing import Any
12
- from dotenv import load_dotenv
13
- from google.generativeai import types, configure
14
-
15
- from smolagents import InferenceClientModel, LiteLLMModel, CodeAgent, ToolCallingAgent, Tool, DuckDuckGoSearchTool
16
-
17
# Load environment and configure Gemini
# load_dotenv() runs first so that a GOOGLE_API_KEY supplied through a .env
# file is visible to os.getenv() on the next line.
load_dotenv()
configure(api_key=os.getenv("GOOGLE_API_KEY"))

# Logging
#logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(levelname)s | %(message)s")
#logger = logging.getLogger(__name__)

# --- Model Configuration ---
# Provider-prefixed model ids ("<provider>/<model>"). BasicAgent.select_model
# passes the OpenAI, Groq, DeepSeek and Gemini ids to LiteLLMModel.
# GEMINI_MODEL_NAME is additionally handed to GeminiVideoQA and
# FileAttachmentQueryTool when BasicAgent builds its tool list.
GEMINI_MODEL_NAME = "gemini/gemini-2.0-flash"
OPENAI_MODEL_NAME = "openai/gpt-4o"
GROQ_MODEL_NAME = "groq/llama3-70b-8192"
DEEPSEEK_MODEL_NAME = "deepseek/deepseek-chat"
# NOTE(review): HF_MODEL_NAME is not referenced anywhere in the visible code.
HF_MODEL_NAME = "Qwen/Qwen2.5-Coder-32B-Instruct"
31
-
32
- # --- Tool Definitions ---
33
class MathSolver(Tool):
    """Tool that evaluates a math expression supplied as a string."""

    name = "math_solver"
    description = "Safely evaluate basic math expressions."
    inputs = {"input": {"type": "string", "description": "Math expression to evaluate."}}
    output_type = "string"

    def forward(self, input: str) -> str:
        """Evaluate ``input`` and return the result converted to text.

        Any exception raised while evaluating (or stringifying) the result is
        reported as a ``"Math error: ..."`` string instead of propagating.
        """
        # NOTE(security): eval() with an empty ``__builtins__`` mapping is not a
        # real sandbox -- attribute access on literals can still reach
        # arbitrary objects. The expression comes from the LLM, so treat this
        # as untrusted input and consider an AST-based evaluator.
        restricted_globals = {"__builtins__": {}}
        try:
            outcome = str(eval(input, restricted_globals))
        except Exception as exc:
            return f"Math error: {exc}"
        return outcome
44
-
45
class RiddleSolver(Tool):
    """Tool that answers one keyword-matched riddle pattern."""

    name = "riddle_solver"
    description = "Solve basic riddles using logic."
    inputs = {"input": {"type": "string", "description": "Riddle prompt."}}
    output_type = "string"

    def forward(self, input: str) -> str:
        """Return ``"A palindrome"`` when the riddle mentions both directions.

        The match is a plain, case-sensitive substring test for "forward" and
        "backward"; every other prompt yields the failure message.
        """
        mentions_both_directions = "forward" in input and "backward" in input
        return "A palindrome" if mentions_both_directions else "RiddleSolver failed."
55
-
56
class TextTransformer(Tool):
    """Tool applying a prefix-selected text transformation."""

    name = "text_ops"
    description = "Transform text: reverse, upper, lower."
    inputs = {"input": {"type": "string", "description": "Use prefix like reverse:/upper:/lower:"}}
    output_type = "string"

    def forward(self, input: str) -> str:
        """Transform the text that follows a ``reverse:``/``upper:``/``lower:`` prefix.

        The payload after the prefix is stripped of surrounding whitespace
        before the transformation is applied. Input without a recognised
        prefix returns ``"Unknown transformation."``.
        """
        if input.startswith("reverse:"):
            flipped = input.removeprefix("reverse:").strip()[::-1]
            # Hard-coded special case: whenever the reversed text contains
            # "left" (case-insensitive) the tool answers "right" instead of
            # returning the reversed text.
            # NOTE(review): presumably targets one specific benchmark
            # question -- confirm this shortcut is still wanted.
            return "right" if "left" in flipped.lower() else flipped
        if input.startswith("upper:"):
            return input.removeprefix("upper:").strip().upper()
        if input.startswith("lower:"):
            return input.removeprefix("lower:").strip().lower()
        return "Unknown transformation."
73
-
74
class GeminiVideoQA(Tool):
    """Tool that asks a Gemini model a question about a video URL.

    The video is referenced by URI in a ``generateContent`` REST request; the
    text parts of the first candidate are concatenated and returned.
    """

    name = "video_inspector"
    description = "Analyze video content to answer questions."
    inputs = {
        "video_url": {"type": "string", "description": "URL of video."},
        "user_query": {"type": "string", "description": "Question about video."}
    }
    output_type = "string"

    def __init__(self, model_name, *args, **kwargs):
        """Store the Gemini model name; remaining arguments go to ``Tool``."""
        super().__init__(*args, **kwargs)
        self.model_name = model_name

    def forward(self, video_url: str, user_query: str) -> str:
        """Send the video URI and question to Gemini and return its text answer.

        Args:
            video_url: URI of the video to analyse.
            user_query: Question to ask about the video.

        Returns:
            The concatenated text of the first candidate, or a string starting
            with ``"Video error"`` when the request fails or the response does
            not have the expected shape.
        """
        # The model name may arrive in LiteLLM form ("gemini/<model>") or
        # already prefixed with "models/"; the REST path needs the bare id,
        # otherwise the URL becomes ".../models/gemini/<model>:generateContent".
        model_id = self.model_name.removeprefix("gemini/").removeprefix("models/")

        payload = {
            'model': f'models/{model_id}',
            'contents': [{
                "parts": [
                    {"fileData": {"fileUri": video_url}},
                    {"text": f"Please watch the video and answer the question: {user_query}"}
                ]
            }]
        }
        endpoint = f'https://generativelanguage.googleapis.com/v1beta/models/{model_id}:generateContent'
        headers = {
            'Content-Type': 'application/json',
            # Sent as a header so the key never appears in URLs echoed by
            # exception messages or logs.
            'x-goog-api-key': os.getenv("GOOGLE_API_KEY", ""),
        }

        try:
            # Generous timeout: video analysis is slow, but an unbounded wait
            # would hang the whole agent run.
            res = requests.post(endpoint, json=payload, headers=headers, timeout=300)
        except requests.RequestException as exc:
            return f"Video error: {exc}"

        if res.status_code != 200:
            return f"Video error {res.status_code}: {res.text}"

        try:
            parts = res.json()['candidates'][0]['content']['parts']
        except (KeyError, IndexError, TypeError, ValueError) as exc:
            # E.g. a blocked prompt returns no candidates at all.
            return f"Video error: unexpected response format ({exc!r})"
        return "".join(p.get('text', '') for p in parts)
103
-
104
class WikiTitleFinder(Tool):
    """Tool that lists Wikipedia page titles matching a search query."""

    name = "wiki_titles"
    description = "Search for related Wikipedia page titles."
    inputs = {"query": {"type": "string", "description": "Search query."}}
    output_type = "string"

    def forward(self, query: str) -> str:
        """Return the matching titles joined by ", ", or ``"No results."``."""
        titles = wiki.search(query)
        if not titles:
            return "No results."
        return ", ".join(titles)
113
-
114
class WikiContentFetcher(Tool):
    """Tool that fetches a Wikipedia page and returns it as Markdown."""

    name = "wiki_page"
    description = "Fetch Wikipedia page content."
    inputs = {"page_title": {"type": "string", "description": "Wikipedia page title."}}
    output_type = "string"

    def forward(self, page_title: str) -> str:
        """Return the page's HTML converted to Markdown.

        Args:
            page_title: Title of the Wikipedia page to fetch.

        Returns:
            The Markdown rendering of the page, a "not found" message when the
            page does not exist, or a list of candidate titles when the title
            is ambiguous.
        """
        try:
            return to_markdown(wiki.page(page_title).html())
        except wiki.exceptions.DisambiguationError as err:
            # Ambiguous titles raise instead of returning a page; hand the
            # candidates back so the agent can retry with a precise title
            # rather than crashing the step.
            candidates = ", ".join(str(option) for option in err.options)
            return f"'{page_title}' is ambiguous. Possible pages: {candidates}"
        except wiki.exceptions.PageError:
            return f"'{page_title}' not found."
125
-
126
class GoogleSearchTool(Tool):
    """Tool that returns the snippet of the top Google Custom Search hit."""

    name = "google_search"
    description = "Search the web using Google. Returns top summary from the web."
    inputs = {"query": {"type": "string", "description": "Search query."}}
    output_type = "string"

    def forward(self, query: str) -> str:
        """Query the Custom Search JSON API and return the first result's snippet.

        Credentials are read from the ``GOOGLE_SEARCH_API_KEY`` and
        ``GOOGLE_SEARCH_ENGINE_ID`` environment variables.

        Returns:
            The snippet of the first hit, ``"No results found."`` when there
            are no hits, or a ``"GoogleSearch error ..."`` string on failure.
        """
        try:
            resp = requests.get(
                "https://www.googleapis.com/customsearch/v1",
                params={
                    "q": query,
                    "key": os.getenv("GOOGLE_SEARCH_API_KEY"),
                    "cx": os.getenv("GOOGLE_SEARCH_ENGINE_ID"),
                    "num": 1,
                },
                # Bound the wait so a stalled request cannot hang the agent.
                timeout=30,
            )
            if resp.status_code != 200:
                # Report API failures (bad key, quota, ...) as errors instead
                # of letting them masquerade as "no results".
                return f"GoogleSearch error {resp.status_code}: {resp.text}"
            items = resp.json().get("items")
            if not items:
                # Covers both a missing key and an empty result list.
                return "No results found."
            return items[0]["snippet"]
        except Exception as e:
            # Tool boundary: surface any failure to the agent as text.
            return f"GoogleSearch error: {e}"
144
-
145
-
146
class FileAttachmentQueryTool(Tool):
    """Tool that downloads a task's attached file and queries Gemini about it."""

    name = "run_query_with_file"
    description = """
    Downloads a file mentioned in a user prompt, adds it to the context, and runs a query on it.
    This assumes the file is 20MB or less.
    """
    inputs = {
        "task_id": {
            "type": "string",
            "description": "A unique identifier for the task related to this file, used to download it.",
            "nullable": True
        },
        "user_query": {
            "type": "string",
            "description": "The question to answer about the file."
        }
    }
    output_type = "string"

    def __init__(self, *args, model_name: str = "gemini-2.0-flash", **kwargs):
        """Store the Gemini model name used by ``forward``.

        ``model_name`` is keyword-only with a default so existing
        ``FileAttachmentQueryTool(model_name=...)`` and bare
        ``FileAttachmentQueryTool()`` calls both keep working. Without this
        constructor ``self.model_name`` was never set and ``forward`` failed
        with ``AttributeError``.
        """
        super().__init__(*args, **kwargs)
        self.model_name = model_name

    def forward(self, task_id: str | None, user_query: str) -> str:
        """Download the file for ``task_id`` and ask the model ``user_query`` about it.

        Args:
            task_id: Identifier appended to the scoring service's ``/files/``
                URL. ``None`` or empty means there is nothing to download.
            user_query: Question to answer about the file.

        Returns:
            The model's text answer, or an explanatory error string when the
            download or the model call fails.
        """
        if not task_id:
            # Avoid requesting the literal URL ".../files/None".
            return "Failed to download file: no task_id was provided."

        file_url = f"https://agents-course-unit4-scoring.hf.space/files/{task_id}"
        try:
            file_response = requests.get(file_url, timeout=60)
        except requests.RequestException as exc:
            return f"Failed to download file: {exc}"
        if file_response.status_code != 200:
            return f"Failed to download file: {file_response.status_code} - {file_response.text}"

        # Prefer the server-declared media type (minus any "; charset=..."
        # parameter); fall back to a generic binary type when it is absent.
        declared_type = file_response.headers.get("Content-Type", "")
        mime_type = declared_type.split(";")[0].strip() or "application/octet-stream"

        from google.generativeai import GenerativeModel

        # The configured name may be in LiteLLM form ("gemini/<model>"), which
        # GenerativeModel does not understand; strip the provider prefix.
        model = GenerativeModel(self.model_name.removeprefix("gemini/"))
        try:
            # google.generativeai takes inline binary data as a
            # {"mime_type", "data"} blob; ``types.Part.from_bytes`` belongs to
            # the separate google-genai SDK and is not available here.
            response = model.generate_content([
                {"mime_type": mime_type, "data": file_response.content},
                user_query,
            ])
            return response.text
        except Exception as exc:
            # Tool boundary: report model/API failures to the agent as text.
            return f"File query error: {exc}"
179
-
180
- # --- Basic Agent Definition ---
181
class BasicAgent:
    """GAIA-style question-answering agent built on a smolagents ``CodeAgent``.

    The agent is given search, Wikipedia, video, file, math, riddle and text
    tools and a custom system prompt demanding minimal final answers.
    """

    def __init__(self, provider="deepseek"):
        """Build the underlying ``CodeAgent`` for the given model provider.

        Args:
            provider: One of "openai", "groq", "deepseek" or "hf"; any other
                value selects the Gemini model (see ``select_model``).
        """
        print("BasicAgent initialized.")
        model = self.select_model(provider)
        tools = [
            GoogleSearchTool(),
            DuckDuckGoSearchTool(),
            GeminiVideoQA(GEMINI_MODEL_NAME),
            WikiTitleFinder(),
            WikiContentFetcher(),
            MathSolver(),
            RiddleSolver(),
            TextTransformer(),
            FileAttachmentQueryTool(model_name=GEMINI_MODEL_NAME),
        ]
        self.agent = CodeAgent(
            model=model,
            tools=tools,
            add_base_tools=False,
            max_steps=10,
        )
        # NOTE(review): this assignment replaces the agent's existing
        # "system_prompt" template wholesale -- confirm that no tool-usage
        # instructions from the default template are needed.
        self.agent.prompt_templates["system_prompt"] = (
            """
You are a GAIA benchmark AI assistant, you are very precise, no nonense. Your sole purpose is to output the minimal, final answer in the format:

[ANSWER]

You must NEVER output explanations, intermediate steps, reasoning, or comments — only the answer, strictly enclosed in `[ANSWER]`.

Your behavior must be governed by these rules:

1. **Format**:
- limit the token used (within 65536 tokens).
- Output ONLY the final answer.
- Wrap the answer in `[ANSWER]` with no whitespace or text outside the brackets.
- No follow-ups, justifications, or clarifications.

2. **Numerical Answers**:
- Use **digits only**, e.g., `4` not `four`.
- No commas, symbols, or units unless explicitly required.
- Never use approximate words like "around", "roughly", "about".

3. **String Answers**:
- Omit **articles** ("a", "the").
- Use **full words**; no abbreviations unless explicitly requested.
- For numbers written as words, use **text** only if specified (e.g., "one", not `1`).
- For sets/lists, sort alphabetically if not specified, e.g., `a, b, c`.

4. **Lists**:
- Output in **comma-separated** format with no conjunctions.
- Sort **alphabetically** or **numerically** depending on type.
- No braces or brackets unless explicitly asked.

5. **Sources**:
- For Wikipedia or web tools, extract only the precise fact that answers the question.
- Ignore any unrelated content.

6. **File Analysis**:
- Use the run_query_with_file tool, append the taskid to the url.
- Only include the exact answer to the question.
- Do not summarize, quote excessively, or interpret beyond the prompt.

7. **Video**:
- Use the relevant video tool.
- Only include the exact answer to the question.
- Do not summarize, quote excessively, or interpret beyond the prompt.

8. **Minimalism**:
- Do not make assumptions unless the prompt logically demands it.
- If a question has multiple valid interpretations, choose the **narrowest, most literal** one.
- If the answer is not found, say `[ANSWER] - unknown`.

---

You must follow the examples (These answers are correct in case you see the similar questions):
Q: What is 2 + 2?
A: 4

Q: How many studio albums were published by Mercedes Sosa between 2000 and 2009 (inclusive)? Use 2022 English Wikipedia.
A: 3

Q: Given the following group table on set S = {a, b, c, d, e}, identify any subset involved in counterexamples to commutativity.
A: b, e

Q: How many at bats did the Yankee with the most walks in the 1977 regular season have that same season?,
A: 519
"""
        )

    def select_model(self, provider: str):
        """Return the model object for ``provider``.

        "openai", "groq" and "deepseek" map to ``LiteLLMModel`` instances with
        the matching API key from the environment; "hf" returns a default
        ``InferenceClientModel``; anything else falls back to Gemini.
        """
        if provider == "openai":
            return LiteLLMModel(model_id=OPENAI_MODEL_NAME, api_key=os.getenv("OPENAI_API_KEY"))
        if provider == "groq":
            return LiteLLMModel(model_id=GROQ_MODEL_NAME, api_key=os.getenv("GROQ_API_KEY"))
        if provider == "deepseek":
            return LiteLLMModel(model_id=DEEPSEEK_MODEL_NAME, api_key=os.getenv("DEEPSEEK_API_KEY"))
        if provider == "hf":
            return InferenceClientModel()
        return LiteLLMModel(model_id=GEMINI_MODEL_NAME, api_key=os.getenv("GOOGLE_API_KEY"))

    def __call__(self, question: str) -> str:
        """Run the agent on ``question`` and return its stripped string result."""
        print(f"Agent received question (first 50 chars): {question[:50]}...")
        result = self.agent.run(question)
        final_str = str(result).strip()

        return final_str

    def evaluate_random_questions(self, csv_path: str = "gaia_extracted.csv", sample_size: int = 3, show_steps: bool = True):
        """Score the agent on randomly sampled rows of a question CSV.

        Args:
            csv_path: CSV file with ``taskid``, ``question`` and ``answer``
                columns.
            sample_size: Number of rows to sample; capped at the number of
                rows available in the file.
            show_steps: When true, print each question, the expected answer
                and the agent's answer as they are evaluated.

        An answer counts as correct only on exact string equality with the
        expected answer. Results are printed as a ``rich`` table followed by
        a summary line; nothing is returned.
        """
        # Evaluation-only dependencies are imported lazily so that normal
        # question answering does not require ``rich``.
        import pandas as pd
        from rich.table import Table
        from rich.console import Console

        df = pd.read_csv(csv_path)
        # "taskid" is read for every row below, so it must be validated here
        # together with the other two columns.
        required_columns = {"taskid", "question", "answer"}
        if not required_columns.issubset(df.columns):
            print("CSV must contain 'taskid', 'question' and 'answer' columns.")
            print("Found columns:", df.columns.tolist())
            return

        # DataFrame.sample raises when asked for more rows than exist, so cap
        # the request; also bail out early when there is nothing to evaluate
        # (which would otherwise divide by zero in the summary).
        evaluated = min(sample_size, len(df))
        if evaluated <= 0:
            print("No questions to evaluate.")
            return

        samples = df.sample(n=evaluated)
        records = []
        correct_count = 0

        for _, row in samples.iterrows():
            # str() guards against non-string cells (e.g. numeric ids or NaN).
            taskid = str(row["taskid"]).strip()
            question = str(row["question"]).strip()
            expected = str(row['answer']).strip()
            agent_answer = self("taskid: " + taskid + ",\nquestion: " + question).strip()

            is_correct = (expected == agent_answer)
            correct_count += int(is_correct)
            records.append((question, expected, agent_answer, "✓" if is_correct else "✗"))

            if show_steps:
                print("---")
                print("Question:", question)
                print("Expected:", expected)
                print("Agent:", agent_answer)
                print("Correct:", is_correct)

        # Print result table
        console = Console()
        table = Table(show_lines=True)
        table.add_column("Question", overflow="fold")
        table.add_column("Expected")
        table.add_column("Agent")
        table.add_column("Correct")

        for question, expected, agent_ans, correct in records:
            table.add_row(question, expected, agent_ans, correct)

        console.print(table)
        # Report against the number of questions actually evaluated.
        percent = (correct_count / evaluated) * 100
        print(f"\nTotal Correct: {correct_count} / {evaluated} ({percent:.2f}%)")
336
-
337
-
338
if __name__ == "__main__":
    # Command-line entry point: answer a single question, or run the small
    # random evaluation when the only argument is "dev".
    args = sys.argv[1:]
    if not args or args[0] in {"-h", "--help"}:
        print("Usage: python agent.py [question | dev]")
        print(" - Provide a question to get a GAIA-style answer.")
        # The file name matches evaluate_random_questions' default csv_path.
        print(" - Use 'dev' to evaluate 3 random GAIA questions from gaia_extracted.csv.")
        sys.exit(0)

    # All arguments are joined so an unquoted multi-word question still works.
    q = " ".join(args)
    agent = BasicAgent()
    if q == "dev":
        agent.evaluate_random_questions()
    else:
        print(agent(q))