iQuentin commited on
Commit
2b967a2
·
verified ·
1 Parent(s): 3d525e8

First draft of GAIA simu

Browse files
Files changed (1) hide show
  1. simuGAIA.py +165 -0
simuGAIA.py CHANGED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import os
3
+ import sys
4
+ import json
5
+ import tempfile
6
+ from typing import List, Dict, Any, Optional
7
+ import traceback
8
+ # vimport dotenv
9
+
10
+ # Load environment variables from .env file
11
+ # dotenv.load_dotenv()
12
+
13
+ # Import our agent
14
+ from agent import QAgent
15
+
16
+ # Simulation of GAIA benchmark questions
17
+ SAMPLE_QUESTIONS = [
18
+ {
19
+ "task_id": "task_001",
20
+ "question": "What is the capital of France?",
21
+ "expected_answer": "Paris",
22
+ "has_file": False,
23
+ "file_content": None
24
+ }
25
+ ]
26
+
27
+ SAMPLE_QUESTIONS_OUT = [
28
+ {
29
+ "task_id": "task_002",
30
+ "question": "What is the square root of 144?",
31
+ "expected_answer": "12",
32
+ "has_file": False,
33
+ "file_content": None
34
+ },
35
+ {
36
+ "task_id": "task_003",
37
+ "question": "If a train travels at 60 miles per hour, how far will it travel in 2.5 hours?",
38
+ "expected_answer": "150 miles",
39
+ "has_file": False,
40
+ "file_content": None
41
+ },
42
+ {
43
+ "task_id": "task_004",
44
+ "question": ".rewsna eht sa 'thgir' drow eht etirw ,tfel fo etisoppo eht si tahW",
45
+ "expected_answer": "right",
46
+ "has_file": False,
47
+ "file_content": None
48
+ },
49
+ {
50
+ "task_id": "task_005",
51
+ "question": "Analyze the data in the attached CSV file and tell me the total sales for the month of January.",
52
+ "expected_answer": "$10,250.75",
53
+ "has_file": True,
54
+ "file_content": """Date,Product,Quantity,Price,Total
55
+ 2023-01-05,Widget A,10,25.99,259.90
56
+ 2023-01-12,Widget B,5,45.50,227.50
57
+ 2023-01-15,Widget C,20,50.25,1005.00
58
+ 2023-01-20,Widget A,15,25.99,389.85
59
+ 2023-01-25,Widget B,8,45.50,364.00
60
+ 2023-01-28,Widget D,100,80.04,8004.50"""
61
+ },
62
+ {
63
+ "task_id": "task_006",
64
+ "question": "I'm making a grocery list for my mom, but she's a picky eater. She only eats foods that don't contain the letter 'e'. List 5 common fruits and vegetables she can eat.",
65
+ "expected_answer": "Banana, Kiwi, Corn, Fig, Taro",
66
+ "has_file": False,
67
+ "file_content": None
68
+ },
69
+ {
70
+ "task_id": "task_007",
71
+ "question": "How many studio albums were published by Mercedes Sosa between 1972 and 1985?",
72
+ "expected_answer": "12",
73
+ "has_file": False,
74
+ "file_content": None
75
+ },
76
+ {
77
+ "task_id": "task_008",
78
+ "question": "In the video https://www.youtube.com/watch?v=L1vXC1KMRd0, what color is primarily associated with the main character?",
79
+ "expected_answer": "Blue",
80
+ "has_file": False,
81
+ "file_content": None
82
+ }
83
+ ]
84
+
85
+
86
+ def save_test_file(task_id: str, content: str) -> str:
87
+ """Save a test file to a temporary location."""
88
+ temp_dir = tempfile.gettempdir()
89
+ file_path = os.path.join(temp_dir, f"test_file_{task_id}.csv")
90
+
91
+ with open(file_path, 'w') as f:
92
+ f.write(content)
93
+
94
+ return file_path
95
+
96
+
97
+
98
+ def run_GAIA_questions_simu():
99
+ """
100
+ Used only during development for test that simulate GAIA questions.
101
+ """
102
+ # 1. Instantiate Agent
103
+ try:
104
+ agent = QAgent()
105
+ except Exception as e:
106
+ print(f"Error instantiating agent for GAIA simulation: {e}")
107
+ return f"Error initializing agent for GAIA simulation: {e}", None
108
+
109
+ results = []
110
+ correct_count = 0
111
+ total_count = len(SAMPLE_QUESTIONS)
112
+
113
+ for idx, question_data in enumerate(SAMPLE_QUESTIONS):
114
+ task_id = question_data["task_id"]
115
+ question = question_data["question"]
116
+ expected = question_data["expected_answer"]
117
+
118
+ print(f"\n{'='*80}")
119
+ print(f"Question {idx+1}/{total_count}: {question}")
120
+ print(f"Expected: {expected}")
121
+
122
+ # Process any attached file
123
+ # file_path = None
124
+ # if question_data["has_file"] and question_data["file_content"]:
125
+ # file_path = save_test_file(task_id, question_data["file_content"])
126
+ # print(f"Created test file: {file_path}")
127
+
128
+ # Get answer from agent
129
+ try:
130
+ answer = agent.invoke(question) # , file_path)
131
+ print(f"Agent answer: {answer}")
132
+
133
+ # Check if answer matches expected
134
+ is_correct = answer.lower() == expected.lower()
135
+ if is_correct:
136
+ correct_count += 1
137
+ print(f"✅ CORRECT")
138
+ else:
139
+ print(f"❌ INCORRECT - Expected: {expected}")
140
+
141
+ results.append({
142
+ "task_id": task_id,
143
+ "question": question,
144
+ "expected": expected,
145
+ "answer": answer,
146
+ "is_correct": is_correct
147
+ })
148
+ except Exception as e:
149
+ error_details = traceback.format_exc()
150
+ print(f"Error processing question: {e}\n{error_details}")
151
+ results.append({
152
+ "task_id": task_id,
153
+ "question": question,
154
+ "expected": expected,
155
+ "answer": f"ERROR: {str(e)}",
156
+ "is_correct": False
157
+ })
158
+
159
+ # Print summary
160
+ accuracy = (correct_count / total_count) * 100
161
+ print(f"\n{'='*80}")
162
+ print(f"Test Results: {correct_count}/{total_count} correct ({accuracy:.1f}%)")
163
+
164
+ return results
165
+