dawid-lorek committed (verified)
Commit 4695b90 · 1 Parent(s): 42cbc8d

Update agent.py

Files changed (1):
  agent.py (+187, -97)
agent.py CHANGED
@@ -1,13 +1,155 @@
  import os
+ import io
  import requests
  import mimetypes
+ import subprocess
+ import tempfile
  from openai import OpenAI
  from duckduckgo_search import DDGS
  from PIL import Image
  import pytesseract
- import io
  import openpyxl

+ try:
+     import whisper
+ except ImportError:
+     whisper = None
+
+ try:
+     import pdfplumber
+ except ImportError:
+     pdfplumber = None
+
+ AGENT_API_URL = "https://agents-course-unit4-scoring.hf.space"
+
+ def safe_strip(text):
+     if not text:
+         return ""
+     if isinstance(text, bytes):
+         text = text.decode(errors="ignore")
+     return str(text).replace("\r", "").strip()
+
+ def parse_final_answer(text):
+     """
+     Extracts only the final answer from an LLM reply, no explanations, no 'Final Answer:' prefix
+     """
+     for line in reversed(text.splitlines()):
+         if "Final Answer:" in line:
+             return line.split("Final Answer:")[-1].strip()
+     return safe_strip(text.splitlines()[-1])
+
+ def run_web_search(query, max_results=3):
+     try:
+         ddgs = DDGS()
+         results = ddgs.text(query)
+         for i, r in enumerate(results):
+             if i >= max_results:
+                 break
+             if r.get('body'):
+                 return r['body']
+             elif r.get('title'):
+                 return r['title']
+         return ""
+     except Exception:
+         return ""
+
+ def fetch_file(task_id):
+     url = f"{AGENT_API_URL}/files/{task_id}"
+     try:
+         resp = requests.get(url, timeout=30)
+         resp.raise_for_status()
+         content_type = resp.headers.get("Content-Type", "")
+         return resp.content, content_type
+     except Exception:
+         return None, None
+
+ def ocr_image(img_bytes):
+     try:
+         img = Image.open(io.BytesIO(img_bytes))
+         return safe_strip(pytesseract.image_to_string(img))
+     except Exception:
+         return ""
+
+ def read_excel(file_bytes):
+     try:
+         wb = openpyxl.load_workbook(io.BytesIO(file_bytes), data_only=True)
+         sheet = wb.active
+         rows = list(sheet.iter_rows(values_only=True))
+         text = "\n".join(["\t".join(str(cell) if cell is not None else "" for cell in row) for row in rows])
+         return safe_strip(text)
+     except Exception:
+         return ""
+
+ def read_pdf(file_bytes):
+     if not pdfplumber:
+         return ""
+     try:
+         with pdfplumber.open(io.BytesIO(file_bytes)) as pdf:
+             return safe_strip("\n".join(page.extract_text() or "" for page in pdf.pages))
+     except Exception:
+         return ""
+
+ def transcribe_audio(audio_bytes):
+     if not whisper:
+         return ""
+     try:
+         with tempfile.NamedTemporaryFile(suffix=".mp3", delete=True) as tmpfile:
+             tmpfile.write(audio_bytes)
+             tmpfile.flush()
+             model = whisper.load_model("base")
+             result = model.transcribe(tmpfile.name)
+             return safe_strip(result.get("text", ""))
+     except Exception:
+         return ""
+
+ def transcribe_youtube_audio(youtube_url):
+     """
+     Download audio from YouTube, transcribe using whisper
+     """
+     if not whisper:
+         return ""
+     try:
+         with tempfile.TemporaryDirectory() as tmpdir:
+             audio_path = os.path.join(tmpdir, "audio.mp3")
+             cmd = [
+                 "yt-dlp", "-f", "bestaudio[ext=m4a]/bestaudio/best",
+                 "--extract-audio", "--audio-format", "mp3",
+                 "-o", audio_path, youtube_url
+             ]
+             subprocess.run(cmd, check=True, capture_output=True)
+             model = whisper.load_model("base")
+             result = model.transcribe(audio_path)
+             return safe_strip(result.get("text", ""))
+     except Exception:
+         return ""
+
+ def extract_file_text(file_bytes, content_type, task_id=""):
+     # Images
+     if "image" in content_type:
+         return ocr_image(file_bytes)
+     # Excel
+     if "spreadsheet" in content_type or "excel" in content_type or task_id.endswith(".xlsx"):
+         return read_excel(file_bytes)
+     # PDF
+     if "pdf" in content_type or task_id.endswith(".pdf"):
+         return read_pdf(file_bytes)
+     # Audio
+     if "audio" in content_type or task_id.endswith(".mp3") or task_id.endswith(".wav"):
+         return transcribe_audio(file_bytes)
+     # Text, CSV, JSON
+     if "text" in content_type or "csv" in content_type or "json" in content_type or task_id.endswith(".csv") or task_id.endswith(".json") or task_id.endswith(".txt"):
+         return safe_strip(file_bytes[:10000])
+     return ""
+
+ def guess_youtube_link(question):
+     # If the question mentions YouTube or a video link, try to extract it
+     import re
+     matches = re.findall(r"(https?://[^\s]+)", question)
+     for url in matches:
+         if "youtube.com" in url or "youtu.be" in url:
+             return url
+     return None
+
  class GaiaAgent:
      def __init__(self):
          self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
@@ -15,50 +157,43 @@ class GaiaAgent:
              "You are a top-tier research assistant for the GAIA benchmark. "
              "You analyze documents, reason step by step, and always provide a single, concise, and correct answer. "
              "If a file is provided, extract all relevant information. Use only information from the question and file. "
-             "Always output only 'Final Answer: <answer>' as the last line, no explanation after."
+             "If the question refers to a video/audio file or YouTube link, always try to transcribe it. "
+             "If you need additional facts, summarize web search results provided. "
+             "Never apologize, never say you are unable, never output placeholders. "
+             "Always output the answer only—no explanations, no extra text."
          )
-         self.api_url = "https://agents-course-unit4-scoring.hf.space"
-
-     def fetch_file(self, task_id: str):
-         try:
-             url = f"{self.api_url}/files/{task_id}"
-             resp = requests.get(url, timeout=15)
-             resp.raise_for_status()
-             content_type = resp.headers.get("Content-Type", "")
-             return resp.content, content_type
-         except Exception as e:
-             return None, None
-
-     def ocr_image(self, img_bytes):
-         try:
-             img = Image.open(io.BytesIO(img_bytes))
-             return pytesseract.image_to_string(img)
-         except Exception as e:
-             return "[ERROR: Unable to OCR image]"
-
-     def read_excel(self, file_bytes):
-         try:
-             wb = openpyxl.load_workbook(io.BytesIO(file_bytes), data_only=True)
-             sheet = wb.active
-             rows = list(sheet.iter_rows(values_only=True))
-             text = "\n".join(["\t".join(str(cell) if cell is not None else "" for cell in row) for row in rows])
-             return text
-         except Exception as e:
-             return "[ERROR: Unable to read Excel file]"
-
-     def web_search(self, query, max_results=3):
-         try:
-             ddgs = DDGS()
-             results = ddgs.text(query)
-             summaries = []
-             for i, r in enumerate(results):
-                 if i >= max_results: break
-                 summaries.append(f"{r['title']}: {r['body']}")
-             return "\n".join(summaries)
-         except Exception as e:
-             return f"[ERROR: Web search failed: {e}]"
-
-     def call_llm(self, prompt):
+
+     def __call__(self, question: str, task_id: str = None) -> str:
+         file_text = ""
+         web_context = ""
+         video_transcript = ""
+         prompt_parts = [self.instructions]
+         # 1. File handling (image, Excel, CSV, PDF, text, audio)
+         if task_id:
+             file_bytes, content_type = fetch_file(task_id)
+             if file_bytes and content_type:
+                 file_text = extract_file_text(file_bytes, content_type, task_id)
+             if file_text:
+                 prompt_parts.append(f"Here is the extracted file content:\n{file_text}\n")
+         # 2. YouTube/video handling (by URL in question)
+         youtube_url = guess_youtube_link(question)
+         if youtube_url:
+             transcript = transcribe_youtube_audio(youtube_url)
+             if transcript:
+                 prompt_parts.append(f"Here is the transcript of the video:\n{transcript}\n")
+         # 3. Web search fallback for open-world/factoid questions or if no file info
+         search_keywords = [
+             "who", "what", "when", "where", "name", "number", "how many",
+             "first", "last", "award", "recipient", "code", "surname", "year", "album", "actor", "winner"
+         ]
+         if (not file_text and not youtube_url) or any(kw in question.lower() for kw in search_keywords):
+             search_results = run_web_search(question)
+             if search_results:
+                 prompt_parts.append(f"Here are relevant web search results:\n{search_results}\n")
+         # 4. Compose prompt
+         prompt_parts.append(f"Question: {question}\nAnswer strictly and concisely.")
+         prompt = "\n".join(prompt_parts)
+         # 5. Call LLM
          response = self.client.chat.completions.create(
              model="gpt-4o",
              messages=[
@@ -66,58 +201,13 @@ class GaiaAgent:
                  {"role": "user", "content": prompt}
              ],
              temperature=0.0,
-             max_tokens=1024,
-         )
-         return response.choices[0].message.content.strip()
-
-     def parse_final_answer(self, text):
-         for line in reversed(text.splitlines()):
-             if "Final Answer:" in line:
-                 return line.replace("Final Answer:", "").strip()
-         # fallback
-         return text.strip()
-
-     def __call__(self, question: str, task_id: str = None) -> str:
-         file_context = ""
-         file_text = ""
-         file_type = None
-
-         # Step 1: Download and process file if provided
-         if task_id:
-             file_bytes, content_type = self.fetch_file(task_id)
-             if not file_bytes or not content_type:
-                 file_context = "[ERROR: Could not download file]"
-             elif "image" in content_type:
-                 file_text = self.ocr_image(file_bytes)
-                 file_context = f"Extracted text from image:\n{file_text}\n"
-             elif "spreadsheet" in content_type or "excel" in content_type or task_id.endswith(".xlsx"):
-                 file_text = self.read_excel(file_bytes)
-                 file_context = f"Extracted text from Excel:\n{file_text}\n"
-             elif "text" in content_type or "csv" in content_type or "json" in content_type:
-                 file_text = file_bytes.decode(errors="ignore")[:6000]
-                 file_context = f"File content:\n{file_text}\n"
-             else:
-                 file_context = "[Unsupported or unknown file type]\n"
-
-         # Step 2: Use web search for open-domain/factual questions
-         # Basic heuristics: if the question is about a person, place, number, award, year, etc., try a search
-         search_needed = False
-         search_keywords = ["who", "what", "when", "where", "name", "number", "how many", "first", "last", "award", "recipient"]
-         if any(kw in question.lower() for kw in search_keywords):
-             search_results = self.web_search(question)
-             if search_results and "ERROR" not in search_results:
-                 file_context += f"\nHere are relevant web search results:\n{search_results}\n"
-                 search_needed = True
-
-         # Step 3: Build LLM prompt
-         prompt = (
-             f"{self.instructions}\n\n"
-             f"{file_context}"
-             f"Question: {question}\n"
-             "Show your reasoning step by step, then provide the final answer as 'Final Answer: <answer>'."
+             max_tokens=512,
          )
-         llm_response = self.call_llm(prompt)
-         answer = self.parse_final_answer(llm_response)
+         raw_output = safe_strip(response.choices[0].message.content)
+         # 6. Only return the single-line answer, with no prefix
+         return parse_final_answer(raw_output)

-         # Step 4: Enforce strict output: only final answer, no extra lines
-         return answer
+ # For compatibility with older interface (for "answer_question" import)
+ def answer_question(question, task_id=None):
+     agent = GaiaAgent()
+     return agent(question, task_id)
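
For reference, the snippet below is a minimal usage sketch of the updated module; it is not part of the commit. It assumes OPENAI_API_KEY is already set in the environment, and the question strings and task_id are hypothetical placeholders. The optional extractors degrade gracefully, since whisper and pdfplumber are imported inside try/except blocks, while the YouTube path additionally shells out to yt-dlp.

    # Minimal usage sketch (illustrative only, not part of agent.py).
    # Assumes OPENAI_API_KEY is set; the questions and task_id are placeholders.
    from agent import answer_question

    # Without a task_id, no file is fetched: the agent answers from the question,
    # an optional DuckDuckGo snippet, and gpt-4o.
    print(answer_question("What is the capital of Australia?"))

    # With a task_id, the agent first downloads the attachment from
    # AGENT_API_URL + "/files/<task_id>" and prepends the extracted text to the prompt.
    print(answer_question(
        "How many rows does the attached spreadsheet contain?",
        task_id="00000000-0000-0000-0000-000000000000",  # hypothetical task id
    ))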