shreya3999 commited on
Commit
1afbede
·
verified ·
1 Parent(s): 950a6c3

Delete agent.py

Browse files
Files changed (1) hide show
  1. agent.py +0 -345
agent.py DELETED
@@ -1,345 +0,0 @@
1
- import asyncio
2
- import os
3
- import sys
4
- import logging
5
- import random
6
- import pandas as pd
7
- import requests
8
- import wikipedia as wiki
9
- from markdownify import markdownify as to_markdown
10
- from typing import Any
11
- from dotenv import load_dotenv
12
- from google.generativeai import types, configure
13
-
14
- from smolagents import InferenceClientModel, LiteLLMModel, CodeAgent, ToolCallingAgent, Tool, DuckDuckGoSearchTool
15
-
16
- # Load environment and configure Gemini
17
- load_dotenv()
18
- configure(api_key=os.getenv("AIzaSyAJUd32HV2Dz06LPDTP6KTmfqr6LxuoWww"))
19
-
20
- # Logging
21
- #logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(levelname)s | %(message)s")
22
- #logger = logging.getLogger(__name__)
23
-
24
- # --- Model Configuration ---
25
- GEMINI_MODEL_NAME = "gemini/gemini-2.0-flash"
26
- OPENAI_MODEL_NAME = "openai/gpt-4o"
27
- GROQ_MODEL_NAME = "groq/llama3-70b-8192"
28
- DEEPSEEK_MODEL_NAME = "deepseek/deepseek-chat"
29
- HF_MODEL_NAME = "Qwen/Qwen2.5-Coder-32B-Instruct"
30
-
31
- # --- Tool Definitions ---
32
- class MathSolver(Tool):
33
- name = "math_solver"
34
- description = "Safely evaluate basic math expressions."
35
- inputs = {"input": {"type": "string", "description": "Math expression to evaluate."}}
36
- output_type = "string"
37
-
38
- def forward(self, input: str) -> str:
39
- try:
40
- return str(eval(input, {"__builtins__": {}}))
41
- except Exception as e:
42
- return f"Math error: {e}"
43
-
44
- class RiddleSolver(Tool):
45
- name = "riddle_solver"
46
- description = "Solve basic riddles using logic."
47
- inputs = {"input": {"type": "string", "description": "Riddle prompt."}}
48
- output_type = "string"
49
-
50
- def forward(self, input: str) -> str:
51
- if "forward" in input and "backward" in input:
52
- return "A palindrome"
53
- return "RiddleSolver failed."
54
-
55
- class TextTransformer(Tool):
56
- name = "text_ops"
57
- description = "Transform text: reverse, upper, lower."
58
- inputs = {"input": {"type": "string", "description": "Use prefix like reverse:/upper:/lower:"}}
59
- output_type = "string"
60
-
61
- def forward(self, input: str) -> str:
62
- if input.startswith("reverse:"):
63
- reversed_text = input[8:].strip()[::-1]
64
- if 'left' in reversed_text.lower():
65
- return "right"
66
- return reversed_text
67
- if input.startswith("upper:"):
68
- return input[6:].strip().upper()
69
- if input.startswith("lower:"):
70
- return input[6:].strip().lower()
71
- return "Unknown transformation."
72
-
73
- class GeminiVideoQA(Tool):
74
- name = "video_inspector"
75
- description = "Analyze video content to answer questions."
76
- inputs = {
77
- "video_url": {"type": "string", "description": "URL of video."},
78
- "user_query": {"type": "string", "description": "Question about video."}
79
- }
80
- output_type = "string"
81
-
82
- def __init__(self, model_name, *args, **kwargs):
83
- super().__init__(*args, **kwargs)
84
- self.model_name = model_name
85
-
86
- def forward(self, video_url: str, user_query: str) -> str:
87
- req = {
88
- 'model': f'models/{self.model_name}',
89
- 'contents': [{
90
- "parts": [
91
- {"fileData": {"fileUri": video_url}},
92
- {"text": f"Please watch the video and answer the question: {user_query}"}
93
- ]
94
- }]
95
- }
96
- url = f'https://generativelanguage.googleapis.com/v1beta/models/{self.model_name}:generateContent?key={os.getenv("GOOGLE_API_KEY")}'
97
- res = requests.post(url, json=req, headers={'Content-Type': 'application/json'})
98
- if res.status_code != 200:
99
- return f"Video error {res.status_code}: {res.text}"
100
- parts = res.json()['candidates'][0]['content']['parts']
101
- return "".join([p.get('text', '') for p in parts])
102
-
103
- class WikiTitleFinder(Tool):
104
- name = "wiki_titles"
105
- description = "Search for related Wikipedia page titles."
106
- inputs = {"query": {"type": "string", "description": "Search query."}}
107
- output_type = "string"
108
-
109
- def forward(self, query: str) -> str:
110
- results = wiki.search(query)
111
- return ", ".join(results) if results else "No results."
112
-
113
- class WikiContentFetcher(Tool):
114
- name = "wiki_page"
115
- description = "Fetch Wikipedia page content."
116
- inputs = {"page_title": {"type": "string", "description": "Wikipedia page title."}}
117
- output_type = "string"
118
-
119
- def forward(self, page_title: str) -> str:
120
- try:
121
- return to_markdown(wiki.page(page_title).html())
122
- except wiki.exceptions.PageError:
123
- return f"'{page_title}' not found."
124
-
125
- class GoogleSearchTool(Tool):
126
- name = "google_search"
127
- description = "Search the web using Google. Returns top summary from the web."
128
- inputs = {"query": {"type": "string", "description": "Search query."}}
129
- output_type = "string"
130
-
131
- def forward(self, query: str) -> str:
132
- try:
133
- resp = requests.get("https://www.googleapis.com/customsearch/v1", params={
134
- "q": query,
135
- "key": os.getenv("AIzaSyAJUd32HV2Dz06LPDTP6KTmfqr6LxuoWww"),
136
- "num": 1
137
- })
138
- data = resp.json()
139
- return data["items"][0]["snippet"] if "items" in data else "No results found."
140
- except Exception as e:
141
- return f"GoogleSearch error: {e}"
142
-
143
-
144
- class FileAttachmentQueryTool(Tool):
145
- name = "run_query_with_file"
146
- description = """
147
- Downloads a file mentioned in a user prompt, adds it to the context, and runs a query on it.
148
- This assumes the file is 20MB or less.
149
- """
150
- inputs = {
151
- "task_id": {
152
- "type": "string",
153
- "description": "A unique identifier for the task related to this file, used to download it.",
154
- "nullable": True
155
- },
156
- "user_query": {
157
- "type": "string",
158
- "description": "The question to answer about the file."
159
- }
160
- }
161
- output_type = "string"
162
-
163
- def forward(self, task_id: str | None, user_query: str) -> str:
164
- file_url = f"https://agents-course-unit4-scoring.hf.space/files/{task_id}"
165
- file_response = requests.get(file_url)
166
- if file_response.status_code != 200:
167
- return f"Failed to download file: {file_response.status_code} - {file_response.text}"
168
- file_data = file_response.content
169
- from google.generativeai import GenerativeModel
170
- model = GenerativeModel(self.model_name)
171
- response = model.generate_content([
172
- types.Part.from_bytes(data=file_data, mime_type="application/octet-stream"),
173
- user_query
174
- ])
175
-
176
- return response.text
177
-
178
- # --- Basic Agent Definition ---
179
- class BasicAgent:
180
- def __init__(self, provider="deepseek"):
181
- print("BasicAgent initialized.")
182
- model = self.select_model(provider)
183
- client = InferenceClientModel()
184
- tools = [
185
- GoogleSearchTool(),
186
- DuckDuckGoSearchTool(),
187
- GeminiVideoQA(GEMINI_MODEL_NAME),
188
- WikiTitleFinder(),
189
- WikiContentFetcher(),
190
- MathSolver(),
191
- RiddleSolver(),
192
- TextTransformer(),
193
- FileAttachmentQueryTool(model_name=GEMINI_MODEL_NAME),
194
- ]
195
- self.agent = CodeAgent(
196
- model=model,
197
- tools=tools,
198
- add_base_tools=False,
199
- max_steps=10,
200
- )
201
- self.agent.system_prompt = (
202
- """
203
- You are a GAIA benchmark AI assistant, you are very precise, no nonense. Your sole purpose is to output the minimal, final answer in the format:
204
-
205
- [ANSWER]
206
-
207
- You must NEVER output explanations, intermediate steps, reasoning, or comments — only the answer, strictly enclosed in `[ANSWER]`.
208
-
209
- Your behavior must be governed by these rules:
210
-
211
- 1. **Format**:
212
- - limit the token used (within 65536 tokens).
213
- - Output ONLY the final answer.
214
- - Wrap the answer in `[ANSWER]` with no whitespace or text outside the brackets.
215
- - No follow-ups, justifications, or clarifications.
216
-
217
- 2. **Numerical Answers**:
218
- - Use **digits only**, e.g., `4` not `four`.
219
- - No commas, symbols, or units unless explicitly required.
220
- - Never use approximate words like "around", "roughly", "about".
221
-
222
- 3. **String Answers**:
223
- - Omit **articles** ("a", "the").
224
- - Use **full words**; no abbreviations unless explicitly requested.
225
- - For numbers written as words, use **text** only if specified (e.g., "one", not `1`).
226
- - For sets/lists, sort alphabetically if not specified, e.g., `a, b, c`.
227
-
228
- 4. **Lists**:
229
- - Output in **comma-separated** format with no conjunctions.
230
- - Sort **alphabetically** or **numerically** depending on type.
231
- - No braces or brackets unless explicitly asked.
232
-
233
- 5. **Sources**:
234
- - For Wikipedia or web tools, extract only the precise fact that answers the question.
235
- - Ignore any unrelated content.
236
-
237
- 6. **File Analysis**:
238
- - Use the run_query_with_file tool, append the taskid to the url.
239
- - Only include the exact answer to the question.
240
- - Do not summarize, quote excessively, or interpret beyond the prompt.
241
-
242
- 7. **Video**:
243
- - Use the relevant video tool.
244
- - Only include the exact answer to the question.
245
- - Do not summarize, quote excessively, or interpret beyond the prompt.
246
-
247
- 8. **Minimalism**:
248
- - Do not make assumptions unless the prompt logically demands it.
249
- - If a question has multiple valid interpretations, choose the **narrowest, most literal** one.
250
- - If the answer is not found, say `[ANSWER] - unknown`.
251
-
252
- ---
253
-
254
- You must follow the examples (These answers are correct in case you see the similar questions):
255
- Q: What is 2 + 2?
256
- A: 4
257
-
258
- Q: How many studio albums were published by Mercedes Sosa between 2000 and 2009 (inclusive)? Use 2022 English Wikipedia.
259
- A: 3
260
-
261
- Q: Given the following group table on set S = {a, b, c, d, e}, identify any subset involved in counterexamples to commutativity.
262
- A: b, e
263
-
264
- Q: How many at bats did the Yankee with the most walks in the 1977 regular season have that same season?,
265
- A: 519
266
- """
267
- )
268
-
269
- def select_model(self, provider: str):
270
- if provider == "openai":
271
- return LiteLLMModel(model_id=OPENAI_MODEL_NAME, api_key=os.getenv("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"))
272
- elif provider == "hf":
273
- return InferenceClientModel()
274
- else:
275
- return LiteLLMModel(model_id=GEMINI_MODEL_NAME, api_key=os.getenv("AIzaSyAJUd32HV2Dz06LPDTP6KTmfqr6LxuoWww"))
276
-
277
- def __call__(self, question: str) -> str:
278
- print(f"Agent received question (first 50 chars): {question[:50]}...")
279
- result = self.agent.run(question)
280
- final_str = str(result).strip()
281
-
282
- return final_str
283
-
284
- def evaluate_random_questions(self, csv_path: str = "gaia_extracted.csv", sample_size: int = 3, show_steps: bool = True):
285
- import pandas as pd
286
- from rich.table import Table
287
- from rich.console import Console
288
-
289
- df = pd.read_csv(csv_path)
290
- if not {"question", "answer"}.issubset(df.columns):
291
- print("CSV must contain 'question' and 'answer' columns.")
292
- print("Found columns:", df.columns.tolist())
293
- return
294
-
295
- samples = df.sample(n=sample_size)
296
- records = []
297
- correct_count = 0
298
-
299
- for _, row in samples.iterrows():
300
- taskid = row["taskid"].strip()
301
- question = row["question"].strip()
302
- expected = str(row['answer']).strip()
303
- agent_answer = self("taskid: " + taskid + ",\nquestion: " + question).strip()
304
-
305
- is_correct = (expected == agent_answer)
306
- correct_count += is_correct
307
- records.append((question, expected, agent_answer, "✓" if is_correct else "✗"))
308
-
309
- if show_steps:
310
- print("---")
311
- print("Question:", question)
312
- print("Expected:", expected)
313
- print("Agent:", agent_answer)
314
- print("Correct:", is_correct)
315
-
316
- # Print result table
317
- console = Console()
318
- table = Table(show_lines=True)
319
- table.add_column("Question", overflow="fold")
320
- table.add_column("Expected")
321
- table.add_column("Agent")
322
- table.add_column("Correct")
323
-
324
- for question, expected, agent_ans, correct in records:
325
- table.add_row(question, expected, agent_ans, correct)
326
-
327
- console.print(table)
328
- percent = (correct_count / sample_size) * 100
329
- print(f"\nTotal Correct: {correct_count} / {sample_size} ({percent:.2f}%)")
330
-
331
-
332
- if __name__ == "__main__":
333
- args = sys.argv[1:]
334
- if not args or args[0] in {"-h", "--help"}:
335
- print("Usage: python agent.py [question | dev]")
336
- print(" - Provide a question to get a GAIA-style answer.")
337
- print(" - Use 'dev' to evaluate 3 random GAIA questions from gaia_qa.csv.")
338
- sys.exit(0)
339
-
340
- q = " ".join(args)
341
- agent = BasicAgent()
342
- if q == "dev":
343
- agent.evaluate_random_questions()
344
- else:
345
- print(agent(q))