import os
import sys
import json
import tempfile
from typing import List, Dict, Any, Optional
import traceback

# import dotenv
# Load environment variables from .env file
# dotenv.load_dotenv()

# Import our agent
from agent import QAgent

# Simulation of GAIA benchmark questions
SAMPLE_QUESTIONS = [
    {
        "task_id": "task_002",
        "question": "What is the square root of 144?",
        "expected_answer": "12",
        "has_file": False,
        "file_content": None
    }
]
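# Note: the "_OUT" suffix below is read here as marking questions parked outside
# the active run (assumption); move entries back into SAMPLE_QUESTIONS above to
# include them in the simulation.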
SAMPLE_QUESTIONS_OUT = [
    {
        "task_id": "task_001",
        "question": "What is the capital of France?",
        "expected_answer": "Paris",
        "has_file": False,
        "file_content": None
    },
    {
        "task_id": "task_003",
        "question": "If a train travels at 60 miles per hour, how far will it travel in 2.5 hours?",
        "expected_answer": "150 miles",
        "has_file": False,
        "file_content": None
    },
    {
        "task_id": "task_004",
        "question": ".rewsna eht sa 'thgir' drow eht etirw ,tfel fo etisoppo eht si tahW",
        "expected_answer": "right",
        "has_file": False,
        "file_content": None
    },
    {
        "task_id": "task_005",
        "question": "Analyze the data in the attached CSV file and tell me the total sales for the month of January.",
        "expected_answer": "$10,250.75",
        "has_file": True,
        "file_content": """Date,Product,Quantity,Price,Total
2023-01-05,Widget A,10,25.99,259.90
2023-01-12,Widget B,5,45.50,227.50
2023-01-15,Widget C,20,50.25,1005.00
2023-01-20,Widget A,15,25.99,389.85
2023-01-25,Widget B,8,45.50,364.00
2023-01-28,Widget D,100,80.04,8004.50"""
    },
    {
        "task_id": "task_006",
        "question": "I'm making a grocery list for my mom, but she's a picky eater. She only eats foods that don't contain the letter 'e'. List 5 common fruits and vegetables she can eat.",
        "expected_answer": "Banana, Kiwi, Corn, Fig, Taro",
        "has_file": False,
        "file_content": None
    },
    {
        "task_id": "task_007",
        "question": "How many studio albums were published by Mercedes Sosa between 1972 and 1985?",
        "expected_answer": "12",
        "has_file": False,
        "file_content": None
    },
    {
        "task_id": "task_008",
        "question": "In the video https://www.youtube.com/watch?v=L1vXC1KMRd0, what color is primarily associated with the main character?",
        "expected_answer": "Blue",
        "has_file": False,
        "file_content": None
    }
]
def init_agent() -> Optional[QAgent]:
    """Initialize the QAgent, returning None if instantiation fails."""
    print("Initializing QAgent...")
    try:
        agent = QAgent()
        return agent
    except Exception as e:
        print(f"Error instantiating agent for GAIA simulation: {e}")
        return None


def save_test_file(task_id: str, content: str) -> str:
    """Save a test file to a temporary location."""
    temp_dir = tempfile.gettempdir()
    file_path = os.path.join(temp_dir, f"test_file_{task_id}.csv")
    with open(file_path, 'w') as f:
        f.write(content)
    return file_path
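# Note: save_test_file backs the file-handling path in run_GAIA_questions_simu
# below; it is unused while that path remains commented out.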
def run_GAIA_questions_simu() -> List[Dict[str, Any]]:
    """
    Used only during development for tests that simulate GAIA questions.
    """
    # 1. Instantiate the agent
    agent = init_agent()
    if agent is None:
        print("Agent could not be initialized; aborting simulation.")
        return []

    results = []
    correct_count = 0
    total_count = len(SAMPLE_QUESTIONS)

    for idx, question_data in enumerate(SAMPLE_QUESTIONS):
        task_id = question_data["task_id"]
        question = question_data["question"]
        expected = question_data["expected_answer"]

        print(f"\n{'='*80}")
        print(f"Question {idx+1}/{total_count}: {question}")
        print(f"Expected: {expected}")

        # Process any attached file
        # file_path = None
        # if question_data["has_file"] and question_data["file_content"]:
        #     file_path = save_test_file(task_id, question_data["file_content"])
        #     print(f"Created test file: {file_path}")

        # Get answer from agent
        try:
            answer = agent.invoke(question)  # , file_path)
print(f"Agent answer: {answer}") | |
# Check if answer matches expected | |
is_correct = answer.lower() == expected.lower() | |
if is_correct: | |
correct_count += 1 | |
print(f"✅ CORRECT") | |
else: | |
print(f"❌ INCORRECT - Expected: {expected}") | |
results.append({ | |
"task_id": task_id, | |
"question": question, | |
"expected": expected, | |
"answer": answer, | |
"is_correct": is_correct | |
}) | |
except Exception as e: | |
error_details = traceback.format_exc() | |
print(f"Error processing question: {e}\n{error_details}") | |
results.append({ | |
"task_id": task_id, | |
"question": question, | |
"expected": expected, | |
"answer": f"ERROR: {str(e)}", | |
"is_correct": False | |
}) | |
# Print summary | |
accuracy = (correct_count / total_count) * 100 | |
print(f"\n{'='*80}") | |
print(f"Test Results: {correct_count}/{total_count} correct ({accuracy:.1f}%)") | |
return results | |
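

# Minimal entry point added as a sketch, assuming this module is meant to be run
# directly as a script (not stated in the original); it simply runs the
# simulation and dumps the result list as JSON for inspection.
if __name__ == "__main__":
    simulation_results = run_GAIA_questions_simu()
    print(json.dumps(simulation_results, indent=2))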