#!/usr/bin/env python3
"""
Enhanced GAIA Agent - Full GAIA Benchmark Implementation
Optimized for 30%+ performance on the GAIA benchmark with complete API integration
"""
import os
import re
import json
import base64
import logging
import requests
from typing import Dict, List, Any, Optional, Tuple
from urllib.parse import urlparse, quote
from io import BytesIO
import pandas as pd
import numpy as np
from datetime import datetime
from bs4 import BeautifulSoup
# import markdownify  # Removed for compatibility
from huggingface_hub import InferenceClient
import mimetypes
import openpyxl
import cv2
import torch
from PIL import Image
import subprocess
import tempfile

# Configure logging
logging.basicConfig(filename='gaia_agent.log', level=logging.INFO, format='%(asctime)s %(levelname)s:%(message)s')
logger = logging.getLogger(__name__)

DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
HF_TOKEN = os.environ.get("HF_TOKEN", "")
# --- Tool/LLM Wrappers ---
def llama3_chat(prompt):
    try:
        client = InferenceClient(provider="fireworks-ai", api_key=HF_TOKEN)
        completion = client.chat.completions.create(
            model="meta-llama/Llama-3.1-8B-Instruct",
            messages=[{"role": "user", "content": prompt}],
        )
        return completion.choices[0].message.content
    except Exception as e:
        logging.error(f"llama3_chat error: {e}")
        return f"LLM error: {e}"
def mixtral_chat(prompt):
    try:
        client = InferenceClient(provider="hf-inference", api_key=HF_TOKEN)
        completion = client.chat.completions.create(
            model="mistralai/Mixtral-8x7B-Instruct-v0.1",
            messages=[{"role": "user", "content": prompt}],
        )
        return completion.choices[0].message.content
    except Exception as e:
        logging.error(f"mixtral_chat error: {e}")
        return f"LLM error: {e}"
def extractive_qa(question, context):
    try:
        client = InferenceClient(provider="hf-inference", api_key=HF_TOKEN)
        answer = client.question_answering(
            question=question,
            context=context,
            model="deepset/roberta-base-squad2",
        )
        # Newer huggingface_hub releases return a dataclass; older ones return a dict.
        return answer.answer if hasattr(answer, "answer") else answer["answer"]
    except Exception as e:
        logging.error(f"extractive_qa error: {e}")
        return f"QA error: {e}"
def table_qa(query, table):
    try:
        client = InferenceClient(provider="hf-inference", api_key=HF_TOKEN)
        answer = client.table_question_answering(
            query=query,
            table=table,
            model="google/tapas-large-finetuned-wtq",
        )
        # Newer huggingface_hub releases return a dataclass; older ones return a dict.
        return answer.answer if hasattr(answer, "answer") else answer["answer"]
    except Exception as e:
        logging.error(f"table_qa error: {e}")
        return f"Table QA error: {e}"
def asr_transcribe(audio_path):
    try:
        import torchaudio
        from transformers import pipeline
        asr = pipeline("automatic-speech-recognition", model="openai/whisper-base.en")
        result = asr(audio_path)
        return result["text"]
    except Exception as e:
        logging.error(f"asr_transcribe error: {e}")
        return f"ASR error: {e}"
def image_caption(image_path):
    try:
        from transformers import BlipProcessor, BlipForConditionalGeneration
        processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
        model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
        raw_image = Image.open(image_path).convert('RGB')
        inputs = processor(raw_image, return_tensors="pt")
        out = model.generate(**inputs)
        return processor.decode(out[0], skip_special_tokens=True)
    except Exception as e:
        logging.error(f"image_caption error: {e}")
        return f"Image captioning error: {e}"
def code_analysis(py_path):
    try:
        # Hardened: run the code in a subprocess with a timeout
        with open(py_path) as f:
            code = f.read()
        with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as tmp:
            tmp.write(code)
            tmp_path = tmp.name
        try:
            result = subprocess.run(
                ["python3", tmp_path],
                capture_output=True, text=True, timeout=5,
            )
            if result.returncode == 0:
                # GAIA code questions typically ask for the final printed value,
                # so return the last line of stdout.
                output = result.stdout.strip().split('\n')
                return output[-1] if output else ''
            else:
                logging.error(f"code_analysis subprocess error: {result.stderr}")
                return f"Code error: {result.stderr}"
        except subprocess.TimeoutExpired:
            logging.error("code_analysis timeout")
            return "Code execution timed out"
        finally:
            os.remove(tmp_path)
    except Exception as e:
        logging.error(f"code_analysis error: {e}")
        return f"Code analysis error: {e}"
def youtube_video_qa(youtube_url, question):
    from transformers import pipeline
    try:
        with tempfile.TemporaryDirectory() as tmpdir:
            # Download video
            video_path = os.path.join(tmpdir, "video.mp4")
            cmd = ["yt-dlp", "-f", "mp4", "-o", video_path, youtube_url]
            subprocess.run(cmd, check=True)
            # Extract audio for ASR; use an %(ext)s output template so the
            # post-processed file ends up as audio.mp3
            audio_template = os.path.join(tmpdir, "audio.%(ext)s")
            cmd_audio = ["yt-dlp", "-f", "bestaudio", "--extract-audio", "--audio-format", "mp3", "-o", audio_template, youtube_url]
            subprocess.run(cmd_audio, check=True)
            audio_path = os.path.join(tmpdir, "audio.mp3")
            # Transcribe audio
            asr = pipeline("automatic-speech-recognition", model="openai/whisper-base.en")
            result = asr(audio_path)
            transcript = result["text"]
            # Extract roughly one frame every 5 seconds for vision QA
            cap = cv2.VideoCapture(video_path)
            frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            fps = int(cap.get(cv2.CAP_PROP_FPS))
            frames = []
            for i in range(0, frame_count, max(1, fps * 5)):
                cap.set(cv2.CAP_PROP_POS_FRAMES, i)
                ret, frame = cap.read()
                if not ret:
                    break
                img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                frames.append(img)
            cap.release()
            # Object detection (YOLOv8)
            try:
                from ultralytics import YOLO
                yolo = YOLO("yolov8n.pt")
                detections = []
                for img in frames:
                    results = yolo(np.array(img))
                    for r in results:
                        for c in r.boxes.cls:
                            detections.append(yolo.model.names[int(c)])
                detection_summary = {}
                for obj in detections:
                    detection_summary[obj] = detection_summary.get(obj, 0) + 1
            except Exception as e:
                logging.error(f"YOLOv8 error: {e}")
                detection_summary = {}
            # Image captioning (BLIP)
            try:
                from transformers import BlipProcessor, BlipForConditionalGeneration
                processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
                model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
                captions = []
                for img in frames:
                    inputs = processor(img, return_tensors="pt")
                    out = model.generate(**inputs)
                    captions.append(processor.decode(out[0], skip_special_tokens=True))
            except Exception as e:
                logging.error(f"BLIP error: {e}")
                captions = []
            # Aggregate and answer
            context = f"Transcript: {transcript}\nCaptions: {' | '.join(captions)}\nDetections: {detection_summary}"
            answer = extractive_qa(question, context)
            return answer
    except Exception as e:
        logging.error(f"YouTube video QA error: {e}")
        return f"Video analysis error: {e}"
# --- Tool Registry ---
TOOL_REGISTRY = {
    "llama3_chat": llama3_chat,
    "mixtral_chat": mixtral_chat,
    "extractive_qa": extractive_qa,
    "table_qa": table_qa,
    "asr_transcribe": asr_transcribe,
    "image_caption": image_caption,
    "code_analysis": code_analysis,
    "youtube_video_qa": youtube_video_qa,
}
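# Illustrative sketch (commented out, not executed at import time): any tool can
# be invoked by name through the registry. The agent below performs this routing
# automatically in answer_question based on the question and attached file type.
# tool = TOOL_REGISTRY["llama3_chat"]
# print(tool("In one word, what is the capital of France?"))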
class ModularGAIAAgent:
    """
    Modular GAIA Agent: fetches questions from the API, downloads files, routes them
    to tools/LLMs, chains outputs, and formats GAIA-compliant answers.
    """
    def __init__(self, api_url=DEFAULT_API_URL, tool_registry=TOOL_REGISTRY):
        self.api_url = api_url
        self.tools = tool_registry
        self.reasoning_trace = []
        self.file_cache = set(os.listdir('.'))

    def fetch_questions(self, from_api=True, questions_path="Hugging Face Questions") -> List[Dict[str, Any]]:
        if from_api:
            r = requests.get(f"{self.api_url}/questions")
            r.raise_for_status()
            return r.json()
        else:
            with open(questions_path) as f:
                data = f.read()
            start = data.find("[")
            end = data.rfind("]") + 1
            questions = json.loads(data[start:end])
            return questions
    def download_file(self, file_id, file_name=None):
        if not file_name:
            file_name = file_id
        if file_name in self.file_cache:
            return file_name
        url = f"{self.api_url}/files/{file_id}"
        r = requests.get(url)
        if r.status_code == 200:
            with open(file_name, "wb") as f:
                f.write(r.content)
            self.file_cache.add(file_name)
            return file_name
        else:
            self.reasoning_trace.append(f"Failed to download file {file_id} (status {r.status_code})")
            return None
    def detect_file_type(self, file_name):
        ext = os.path.splitext(file_name)[-1].lower()
        if ext in ['.mp3', '.wav', '.flac']:
            return 'audio'
        elif ext in ['.png', '.jpg', '.jpeg', '.bmp']:
            return 'image'
        elif ext in ['.py']:
            return 'code'
        elif ext in ['.xlsx']:
            return 'excel'
        elif ext in ['.csv']:
            return 'csv'
        elif ext in ['.json']:
            return 'json'
        elif ext in ['.txt', '.md']:
            return 'text'
        else:
            return 'unknown'
    def analyze_file(self, file_name, file_type):
        if file_type == 'audio':
            transcript = self.tools['asr_transcribe'](file_name)
            self.reasoning_trace.append(f"Transcribed audio: {transcript[:100]}...")
            return transcript
        elif file_type == 'image':
            caption = self.tools['image_caption'](file_name)
            self.reasoning_trace.append(f"Image caption: {caption}")
            return caption
        elif file_type == 'code':
            result = self.tools['code_analysis'](file_name)
            self.reasoning_trace.append(f"Code analysis result: {result}")
            return result
        elif file_type == 'excel':
            wb = openpyxl.load_workbook(file_name)
            ws = wb.active
            data = list(ws.values)
            headers = data[0]
            table = [dict(zip(headers, row)) for row in data[1:]]
            self.reasoning_trace.append(f"Excel table loaded: {table[:2]}...")
            return table
        elif file_type == 'csv':
            df = pd.read_csv(file_name)
            table = df.to_dict(orient='records')
            self.reasoning_trace.append(f"CSV table loaded: {table[:2]}...")
            return table
        elif file_type == 'json':
            with open(file_name) as f:
                data = json.load(f)
            self.reasoning_trace.append(f"JSON loaded: {str(data)[:100]}...")
            return data
        elif file_type == 'text':
            with open(file_name) as f:
                text = f.read()
            self.reasoning_trace.append(f"Text loaded: {text[:100]}...")
            return text
        else:
            self.reasoning_trace.append(f"Unknown file type: {file_name}")
            return None
    def answer_question(self, question_obj):
        self.reasoning_trace = []
        q = question_obj["question"]
        file_name = question_obj.get("file_name", "")
        file_content = None
        file_type = None
        # YouTube video question detection
        if "youtube.com" in q or "youtu.be" in q:
            url = None
            for word in q.split():
                if "youtube.com" in word or "youtu.be" in word:
                    url = word.strip().strip(',')
                    break
            if url:
                answer = self.tools['youtube_video_qa'](url, q)
                self.reasoning_trace.append(f"YouTube video analyzed: {url}")
                self.reasoning_trace.append(f"Final answer: {answer}")
                return self.format_answer(answer), self.reasoning_trace
        if file_name:
            file_id = file_name.split('.')[0]
            local_file = self.download_file(file_id, file_name)
            if local_file:
                file_type = self.detect_file_type(local_file)
                file_content = self.analyze_file(local_file, file_type)
        # Plan: choose tool based on question and file
        if file_type == 'audio' or file_type == 'text':
            if file_content:
                answer = self.tools['extractive_qa'](q, file_content)
            else:
                answer = self.tools['llama3_chat'](q)
        elif file_type == 'excel' or file_type == 'csv':
            if file_content:
                # Table-QA models expect a column-oriented table of strings, so
                # convert the row records produced by analyze_file before querying.
                columns = {key: [str(row.get(key, "")) for row in file_content]
                           for key in file_content[0]}
                answer = self.tools['table_qa'](q, columns)
            else:
                answer = self.tools['llama3_chat'](q)
        elif file_type == 'image':
            if file_content:
                answer = self.tools['llama3_chat'](f"{q}\nImage description: {file_content}")
            else:
                answer = self.tools['llama3_chat'](q)
        elif file_type == 'code':
            answer = file_content
        else:
            answer = self.tools['llama3_chat'](q)
        self.reasoning_trace.append(f"Final answer: {answer}")
        return self.format_answer(answer), self.reasoning_trace
    def format_answer(self, answer):
        # GAIA compliance: remove extra words, units, articles, etc.
        if isinstance(answer, str):
            answer = answer.strip().rstrip('.')
            # Remove common prefixes
            for prefix in ['answer:', 'result:', 'the answer is', 'final answer:', 'response:']:
                if answer.lower().startswith(prefix):
                    answer = answer[len(prefix):].strip()
            # Remove articles (re is imported at module level)
            answer = re.sub(r'\b(the|a|an)\b ', '', answer, flags=re.IGNORECASE)
            # Remove trailing punctuation
            answer = answer.strip().rstrip('.')
        return answer
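    # Illustrative before/after examples for format_answer (not executed):
    #   "Answer: the Eiffel Tower."  -> "Eiffel Tower"
    #   "Final answer: 42"           -> "42"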
    def run(self, from_api=True, questions_path="Hugging Face Questions"):
        questions = self.fetch_questions(from_api=from_api, questions_path=questions_path)
        results = []
        for qobj in questions:
            answer, trace = self.answer_question(qobj)
            results.append({
                "task_id": qobj["task_id"],
                "answer": answer,
                "reasoning_trace": trace
            })
        return results
# --- Usage Example ---
# agent = ModularGAIAAgent()
# results = agent.run()
# for r in results:
#     print(r)
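# Offline usage sketch (assumes a local questions file in the same bracketed-JSON
# format that fetch_questions expects; the results path below is illustrative):
# agent = ModularGAIAAgent()
# results = agent.run(from_api=False, questions_path="Hugging Face Questions")
# with open("results.json", "w") as f:
#     json.dump(results, f, indent=2)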