dawid-lorek commited on
Commit
188a166
·
verified ·
1 Parent(s): 99134fe

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +162 -265
agent.py CHANGED
@@ -1,276 +1,173 @@
1
  import os
2
- import io
3
  import re
4
- import requests
5
- import tempfile
6
- import subprocess
7
- from openai import OpenAI
8
- from duckduckgo_search import DDGS
9
- from PIL import Image
10
- import pytesseract
11
- import openpyxl
12
-
13
- try:
14
- import whisper
15
- except ImportError:
16
- whisper = None
17
-
18
- try:
19
- import pdfplumber
20
- except ImportError:
21
- pdfplumber = None
22
-
23
- AGENT_API_URL = "https://agents-course-unit4-scoring.hf.space"
24
-
25
- def safe_strip(text):
26
- if not text:
27
- return ""
28
- if isinstance(text, bytes):
29
- text = text.decode(errors="ignore")
30
- return str(text).replace("\r", "").strip()
31
-
32
- def run_web_search(query, max_results=3):
33
  try:
34
- ddgs = DDGS()
35
- results = ddgs.text(query)
36
- bodies = []
37
- for i, r in enumerate(results):
38
- if i >= max_results:
39
- break
40
- if r.get('body'):
41
- bodies.append(r['body'])
42
- elif r.get('title'):
43
- bodies.append(r['title'])
44
- return "\n".join(bodies)
45
- except Exception:
46
- return ""
47
-
48
- def fetch_file(task_id):
49
- url = f"{AGENT_API_URL}/files/{task_id}"
50
- try:
51
- resp = requests.get(url, timeout=30)
52
- resp.raise_for_status()
53
- content_type = resp.headers.get("Content-Type", "")
54
- return resp.content, content_type
55
- except Exception:
56
- return None, None
57
-
58
- def ocr_image(img_bytes):
59
- try:
60
- img = Image.open(io.BytesIO(img_bytes))
61
- return safe_strip(pytesseract.image_to_string(img))
62
- except Exception:
63
  return ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
 
65
- def read_excel(file_bytes):
 
 
 
 
66
  try:
67
- wb = openpyxl.load_workbook(io.BytesIO(file_bytes), data_only=True)
68
- sheet = wb.active
69
- rows = list(sheet.iter_rows(values_only=True))
70
- text = "\n".join(["\t".join(str(cell) if cell is not None else "" for cell in row) for row in rows])
71
- return safe_strip(text)
72
- except Exception:
73
- return ""
74
-
75
- def read_pdf(file_bytes):
76
- if not pdfplumber:
77
- return ""
78
  try:
79
- with pdfplumber.open(io.BytesIO(file_bytes)) as pdf:
80
- return safe_strip("\n".join(page.extract_text() or "" for page in pdf.pages))
81
- except Exception:
82
- return ""
83
-
84
- def transcribe_audio(audio_bytes):
85
- if not whisper:
86
- return ""
87
- try:
88
- with tempfile.NamedTemporaryFile(suffix=".mp3", delete=True) as tmpfile:
89
- tmpfile.write(audio_bytes)
90
- tmpfile.flush()
91
- model = whisper.load_model("base")
92
- result = model.transcribe(tmpfile.name)
93
- return safe_strip(result.get("text", ""))
94
- except Exception:
95
- return ""
96
-
97
- def transcribe_youtube_audio(youtube_url):
98
- if not whisper:
99
- return ""
100
  try:
 
 
101
  with tempfile.TemporaryDirectory() as tmpdir:
102
- audio_path = os.path.join(tmpdir, "audio.mp3")
103
- cmd = [
104
- "yt-dlp", "-f", "bestaudio[ext=m4a]/bestaudio/best",
105
- "--extract-audio", "--audio-format", "mp3",
106
- "-o", audio_path, youtube_url
107
- ]
108
- subprocess.run(cmd, check=True, capture_output=True)
109
  model = whisper.load_model("base")
110
  result = model.transcribe(audio_path)
111
- return safe_strip(result.get("text", ""))
112
- except Exception:
113
- return ""
114
-
115
- def extract_file_text(file_bytes, content_type, task_id=""):
116
- if "image" in content_type:
117
- return ocr_image(file_bytes)
118
- if "spreadsheet" in content_type or "excel" in content_type or task_id.endswith(".xlsx"):
119
- return read_excel(file_bytes)
120
- if "pdf" in content_type or task_id.endswith(".pdf"):
121
- return read_pdf(file_bytes)
122
- if "audio" in content_type or task_id.endswith(".mp3") or task_id.endswith(".wav"):
123
- return transcribe_audio(file_bytes)
124
- if "text" in content_type or "csv" in content_type or "json" in content_type or task_id.endswith(".csv") or task_id.endswith(".json") or task_id.endswith(".txt"):
125
- return safe_strip(file_bytes[:10000])
126
- return ""
127
-
128
- def guess_youtube_link(question):
129
- matches = re.findall(r"(https?://[^\s]+)", question)
130
- for url in matches:
131
- if "youtube.com" in url or "youtu.be" in url:
132
- return url
133
- return None
134
-
135
- def format_gaia_answer(answer, question=None):
136
- if not answer or not isinstance(answer, str):
137
- return ""
138
- answer = re.sub(r"(?i)i'?m sorry[,\.]?|i cannot|i can't|unable to|please provide.*|information not available|I can't assist.*|I'm unable.*|process the file directly", "", answer)
139
- answer = re.sub(r'(?i)final answer:?\s*', '', answer).strip()
140
- if answer.startswith('"') and answer.endswith('"'):
141
- answer = answer[1:-1]
142
- if answer.startswith('[') and answer.endswith(']'):
143
- answer = answer[1:-1]
144
- if not re.match(r'^[A-Za-z]+\.$', answer):
145
- answer = re.sub(r'\.$', '', answer)
146
-
147
- if question:
148
- # Pure number answers
149
- if re.search(r'how many|number of|at bats|total sales|albums|output.*python|highest number', question, re.I):
150
- num_match = re.search(r'(\$?\d[\d,\.]*)', answer)
151
- if num_match:
152
- return num_match.group(1).replace(',', '')
153
-
154
- if re.search(r'first name', question, re.I):
155
- return answer.strip().split()[0]
156
- if re.search(r'surname', question, re.I):
157
- return answer.strip().split()[-1]
158
- if re.search(r'city', question, re.I):
159
- return answer.strip().split()[0]
160
- if re.search(r'IOC country code|award number|NASA', question, re.I):
161
- code_match = re.search(r'[A-Z0-9]{3,}', answer)
162
- if code_match:
163
- return code_match.group(0)
164
- if 'algebraic notation' in question or 'chess' in question:
165
- move_match = re.search(r'[A-Za-z0-9]+[#\+]?$', answer)
166
- if move_match:
167
- return move_match.group(0)
168
- if "what does teal'c say" in question.lower():
169
- qmatch = re.search(r'"(Indeed\.)"', answer)
170
- if qmatch:
171
- return qmatch.group(1)
172
- if "Indeed." in answer:
173
- return "Indeed."
174
- return answer
175
- if re.search(r'list|comma.*separated|page numbers', question, re.I):
176
- items = [x.strip('",.').lower() for x in re.split(r'[,\n]', answer) if x.strip()]
177
- items = [item for item in items if item and not re.match(r'(and|or|to|with|for|a|the)$', item)]
178
- if 'page numbers' in question:
179
- nums = [int(x) for x in re.findall(r'\d+', answer)]
180
- return ', '.join(str(n) for n in sorted(nums))
181
- if 'ingredient' in question or 'vegetable' in question or 'grocery' in question:
182
- merged = []
183
- skip = False
184
- for i, item in enumerate(items):
185
- if skip:
186
- skip = False
187
- continue
188
- if i+1 < len(items) and item in ['sweet', 'green', 'lemon', 'ripe', 'whole', 'fresh']:
189
- merged.append(f"{item} {items[i+1]}")
190
- skip = True
191
- else:
192
- merged.append(item)
193
- merged = sorted(set(merged))
194
- return ', '.join(merged)
195
- return ', '.join(items)
196
- if re.search(r'pitcher.*before.*after', question, re.I):
197
- names = re.findall(r'\b[A-Z][a-z]+', answer)
198
- return ', '.join(names[:2])
199
-
200
- return answer.strip().rstrip('.').strip()
201
-
202
- class GaiaAgent:
203
- def __init__(self):
204
- self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
205
- self.instructions = (
206
- "You are a top-tier research assistant for the GAIA benchmark. "
207
- "You analyze documents, reason step by step, and always provide a single, concise, and correct answer. "
208
- "If a file is provided, extract all relevant information. Use only information from the question and file. "
209
- "If the question refers to a video/audio file or YouTube link, always try to transcribe it. "
210
- "If you need additional facts, summarize web search results provided. "
211
- "Never apologize, never say you are unable, never output placeholders. "
212
- "Always output the answer only—no explanations, no extra text."
213
- )
214
-
215
- def __call__(self, question: str, task_id: str = None) -> str:
216
- file_text = ""
217
- prompt_parts = [self.instructions]
218
- # 1. File (image, Excel, etc)
219
- if task_id:
220
- file_bytes, content_type = fetch_file(task_id)
221
- if file_bytes and content_type:
222
- file_text = extract_file_text(file_bytes, content_type, task_id)
223
- if file_text:
224
- prompt_parts.append(f"Here is the extracted file content:\n{file_text}\n")
225
- # 2. YouTube/video
226
- youtube_url = guess_youtube_link(question)
227
- if youtube_url:
228
- transcript = transcribe_youtube_audio(youtube_url)
229
- if transcript:
230
- prompt_parts.append(f"Here is the transcript of the video:\n{transcript}\n")
231
- # 3. Web search for open facts
232
- search_needed = not file_text and not youtube_url
233
- search_keywords = [
234
- "who", "what", "when", "where", "name", "number", "how many",
235
- "first", "last", "award", "recipient", "code", "surname", "year", "album", "actor", "winner"
236
- ]
237
- if search_needed or any(kw in question.lower() for kw in search_keywords):
238
- search_results = run_web_search(question)
239
- if search_results:
240
- prompt_parts.append(f"Here are relevant web search results:\n{search_results}\n")
241
- prompt_parts.append(f"Question: {question}\nAnswer strictly and concisely.")
242
- prompt = "\n".join(prompt_parts)
243
- response = self.client.chat.completions.create(
244
- model="gpt-4o",
245
- messages=[
246
- {"role": "system", "content": self.instructions},
247
- {"role": "user", "content": prompt}
248
- ],
249
- temperature=0.0,
250
- max_tokens=512,
251
- )
252
- raw_output = safe_strip(response.choices[0].message.content)
253
- formatted = format_gaia_answer(raw_output, question)
254
- if not formatted or formatted.lower() in ('', 'unknown', 'none', 'n/a') or 'apolog' in formatted.lower():
255
- web_info = run_web_search(question)
256
- if web_info:
257
- prompt2 = (
258
- f"{self.instructions}\n\n"
259
- f"Here are relevant web search results:\n{web_info}\n"
260
- f"Question: {question}\nAnswer strictly and concisely."
261
- )
262
- response2 = self.client.chat.completions.create(
263
- model="gpt-4o",
264
- messages=[
265
- {"role": "system", "content": self.instructions},
266
- {"role": "user", "content": prompt2}
267
- ],
268
- temperature=0.0,
269
- max_tokens=256,
270
- )
271
- formatted = format_gaia_answer(safe_strip(response2.choices[0].message.content), question)
272
- return formatted
273
-
274
- def answer_question(question, task_id=None):
275
- agent = GaiaAgent()
276
- return agent(question, task_id)
 
1
  import os
2
+ import asyncio
3
  import re
4
+ from typing import Any
5
+
6
+ from llama_index.llms.openai import OpenAI
7
+ from llama_index.core.agent.react import ReActAgent
8
+ from llama_index.core.agent.workflow import AgentWorkflow
9
+ from llama_index.core.tools import FunctionTool, ToolMetadata
10
+
11
+ # Tool: DuckDuckGo Web Search
12
+ from llama_index.tools.duckduckgo import DuckDuckGoSearchTool
13
+
14
+ # Tool: Python code eval (for simple code/number/output questions)
15
+ def eval_python_code(code: str) -> str:
16
+ """
17
+ Evaluate simple Python code and return result as string.
18
+ Use for 'What is the output of this code?' or math.
19
+ """
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  try:
21
+ # Only eval expressions (NOT exec for safety!)
22
+ return str(eval(code, {"__builtins__": {}}))
23
+ except Exception as e:
24
+ return f"ERROR: {e}"
25
+
26
+ # Tool: Strict output formatting
27
+ def format_gaia_answer(answer: str, question: str = "") -> str:
28
+ """Postprocess: GAIA strict answer format enforcement."""
29
+ if not answer:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  return ""
31
+ # Remove quotes/brackets/periods, apologies, "Final Answer:"
32
+ answer = re.sub(r'(?i)final answer:?\s*', '', answer).strip()
33
+ answer = re.sub(r'(?i)i(\'?m| cannot| can\'t| unable to| apologize| not available|process the file).*', '', answer).strip()
34
+ if answer.startswith('"') and answer.endswith('"'): answer = answer[1:-1]
35
+ if answer.startswith('[') and answer.endswith(']'): answer = answer[1:-1]
36
+ if not re.match(r'^[A-Za-z]+\.$', answer): answer = re.sub(r'\.$', '', answer)
37
+ # Numeric
38
+ if re.search(r'how many|number of|at bats|total sales|albums|output.*python|highest number', question, re.I):
39
+ num = re.search(r'(\$?\d[\d,\.]*)', answer)
40
+ if num: return num.group(1).replace(',', '')
41
+ # Surname/first name/code/city
42
+ if 'first name' in question: return answer.split()[0]
43
+ if 'surname' in question: return answer.split()[-1]
44
+ if 'city' in question: return answer.split()[0]
45
+ if re.search(r'IOC country code|award number|NASA', question, re.I):
46
+ code = re.search(r'[A-Z0-9]{3,}', answer)
47
+ if code: return code.group(0)
48
+ if re.search(r'list|comma.*separated|page numbers', question, re.I):
49
+ items = [x.strip('",.').lower() for x in re.split(r'[,\n]', answer) if x.strip()]
50
+ if 'page numbers' in question:
51
+ nums = [int(x) for x in re.findall(r'\d+', answer)]
52
+ return ', '.join(str(n) for n in sorted(nums))
53
+ if 'ingredient' in question or 'vegetable' in question:
54
+ merged = []
55
+ skip = False
56
+ for i, item in enumerate(items):
57
+ if skip: skip = False; continue
58
+ if i+1 < len(items) and item in ['sweet', 'green', 'lemon', 'ripe', 'whole', 'fresh']:
59
+ merged.append(f"{item} {items[i+1]}")
60
+ skip = True
61
+ else: merged.append(item)
62
+ merged = sorted(set(merged))
63
+ return ', '.join(merged)
64
+ return ', '.join(items)
65
+ return answer.strip().rstrip('.').strip()
66
 
67
+ # Tool: OCR for images (incl. chessboards/screenshots)
68
+ def ocr_image(file_path: str) -> str:
69
+ """Extract text from image file."""
70
+ from PIL import Image
71
+ import pytesseract
72
  try:
73
+ img = Image.open(file_path)
74
+ return pytesseract.image_to_string(img)
75
+ except Exception as e:
76
+ return f"ERROR: {e}"
77
+
78
+ # Tool: Audio transcription (Whisper)
79
+ def transcribe_audio(file_path: str) -> str:
80
+ """Transcribe audio file with Whisper."""
 
 
 
81
  try:
82
+ import whisper
83
+ model = whisper.load_model("base")
84
+ result = model.transcribe(file_path)
85
+ return result.get("text", "")
86
+ except Exception as e:
87
+ return f"ERROR: {e}"
88
+
89
+ # Tool: YouTube video transcription
90
+ def transcribe_youtube(url: str) -> str:
91
+ """Download and transcribe a YouTube video (audio only)."""
92
+ import tempfile, os
 
 
 
 
 
 
 
 
 
 
93
  try:
94
+ import whisper
95
+ import yt_dlp
96
  with tempfile.TemporaryDirectory() as tmpdir:
97
+ ydl_opts = {'format': 'bestaudio/best', 'outtmpl': os.path.join(tmpdir, 'audio.%(ext)s')}
98
+ with yt_dlp.YoutubeDL(ydl_opts) as ydl:
99
+ ydl.download([url])
100
+ audio_path = [os.path.join(tmpdir, f) for f in os.listdir(tmpdir) if f.startswith("audio")][0]
 
 
 
101
  model = whisper.load_model("base")
102
  result = model.transcribe(audio_path)
103
+ return result.get("text", "")
104
+ except Exception as e:
105
+ return f"ERROR: {e}"
106
+
107
+ # ---- LlamaIndex agent and workflow setup ----
108
+
109
+ # 1. Initialize LLM
110
+ llm = OpenAI(model="gpt-4o", api_key=os.environ.get("OPENAI_API_KEY"))
111
+
112
+ # 2. Register tools
113
+ tools = [
114
+ DuckDuckGoSearchTool(),
115
+ FunctionTool.from_defaults(
116
+ eval_python_code,
117
+ name="python_eval",
118
+ description="Evaluate simple Python code and return result as string. Use for math or code output."
119
+ ),
120
+ FunctionTool.from_defaults(
121
+ ocr_image,
122
+ name="ocr_image",
123
+ description="Extract text from an image file (provide file path)."
124
+ ),
125
+ FunctionTool.from_defaults(
126
+ transcribe_audio,
127
+ name="transcribe_audio",
128
+ description="Transcribe an audio file using Whisper (provide file path)."
129
+ ),
130
+ FunctionTool.from_defaults(
131
+ transcribe_youtube,
132
+ name="transcribe_youtube",
133
+ description="Download a YouTube video, extract and transcribe its audio using Whisper."
134
+ ),
135
+ FunctionTool.from_defaults(
136
+ format_gaia_answer,
137
+ name="format_gaia_answer",
138
+ description="Postprocess and enforce strict GAIA format on answers given a question."
139
+ ),
140
+ ]
141
+
142
+ # 3. Agent setup (ReAct, so can reason with tools)
143
+ agent = ReActAgent.from_tools(
144
+ tools=tools,
145
+ llm=llm,
146
+ system_prompt="You are a helpful GAIA benchmark agent. For every question, use the best tools available and always return only the final answer in the strict GAIA-required format—never explain, never apologize.",
147
+ verbose=False
148
+ )
149
+
150
+ # 4. Async entrypoint, suitable for HuggingFace Spaces or Gradio
151
+ async def answer_question(question: str, task_id: str = None, file_path: str = None) -> str:
152
+ """
153
+ Main async function for the agent.
154
+ Passes the question and uses tools as needed.
155
+ - task_id: for future use, if you want to fetch files from a remote API.
156
+ - file_path: if a file (image, audio, etc) is present locally, pass it.
157
+ """
158
+ # Example: if you want to always try OCR/audio on a file before reasoning, you could do:
159
+ # If question contains "image" or "chess" and file_path is set, run OCR first
160
+ if file_path and any(word in question.lower() for word in ["image", "chess", "screenshot"]):
161
+ ocr_text = ocr_image(file_path)
162
+ question = f"Extracted text from image: {ocr_text}\n\n{question}"
163
+ if file_path and any(word in question.lower() for word in ["audio", "mp3", "transcribe"]):
164
+ audio_text = transcribe_audio(file_path)
165
+ question = f"Transcribed audio: {audio_text}\n\n{question}"
166
+
167
+ # Run agent
168
+ result = await agent.achat(question)
169
+ return result.response
170
+
171
+ # Synchronous wrapper for legacy compat
172
+ def answer_question_sync(question: str, task_id: str = None, file_path: str = None) -> str:
173
+ return asyncio.run(answer_question(question, task_id, file_path))