Freddolin committed on
Commit
2f428df
·
verified ·
1 Parent(s): 15b9880

Create gaia_submission.py

Browse files
Files changed (1) hide show
  1. gaia_submission.py +164 -0
gaia_submission.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import json
import os
from typing import Any, Dict, List, Optional

import requests

from gaia_agent import GaiaAgent
6
+
7
class GaiaSubmission:
    """Fetch GAIA tasks from an HTTP API, answer them and store/submit results.

    The class wraps four endpoints relative to ``api_base_url``
    (``/questions``, ``/random-question``, ``/files/<task_id>`` and
    ``/submit``) and delegates the actual question answering to a
    ``GaiaAgent`` instance, which is called with the question text and is
    expected to return an ``(answer, reasoning_trace)`` pair.
    """

    # Seconds to wait on the server before giving up. Without a timeout,
    # ``requests`` blocks indefinitely when the server stops responding.
    DEFAULT_TIMEOUT = 30.0

    def __init__(self, api_base_url: str, api_key: Optional[str] = None,
                 timeout: float = DEFAULT_TIMEOUT):
        """Set up the client.

        Args:
            api_base_url: Root URL of the API; a trailing slash is stripped.
            api_key: Optional token sent as ``Authorization: Bearer <key>``.
            timeout: Per-request timeout in seconds for every HTTP call.
        """
        self.api_base_url = api_base_url.rstrip('/')
        self.api_key = api_key
        self.timeout = timeout
        self.agent = GaiaAgent()
        self.headers = {'Content-Type': 'application/json'}

        if api_key:
            self.headers['Authorization'] = f'Bearer {api_key}'

    def get_questions(self) -> List[Dict[str, Any]]:
        """Fetch all questions from the API.

        Returns:
            The decoded JSON body of ``GET /questions``, or an empty list if
            the request or the JSON decoding fails (the error is printed).
        """
        # Best-effort boundary: any failure is reported and turned into [].
        try:
            response = requests.get(
                f"{self.api_base_url}/questions",
                headers=self.headers,
                timeout=self.timeout,
            )
            response.raise_for_status()
            return response.json()
        except Exception as e:
            print(f"Error fetching questions: {e}")
            return []

    def get_random_question(self) -> Dict[str, Any]:
        """Fetch one random question from the API.

        Returns:
            The decoded JSON body of ``GET /random-question``, or an empty
            dict if the request fails (the error is printed).
        """
        try:
            response = requests.get(
                f"{self.api_base_url}/random-question",
                headers=self.headers,
                timeout=self.timeout,
            )
            response.raise_for_status()
            return response.json()
        except Exception as e:
            print(f"Error fetching random question: {e}")
            return {}

    def download_file(self, task_id: str, file_path: str) -> bool:
        """Download the file associated with a task and write it to disk.

        Args:
            task_id: Identifier used in ``GET /files/<task_id>``.
            file_path: Local path the response body is written to (binary).

        Returns:
            ``True`` if the file was downloaded and written, ``False`` on
            any error (the error is printed).
        """
        try:
            response = requests.get(
                f"{self.api_base_url}/files/{task_id}",
                headers=self.headers,
                timeout=self.timeout,
            )
            response.raise_for_status()

            with open(file_path, 'wb') as f:
                f.write(response.content)

            return True
        except Exception as e:
            print(f"Error downloading file for task {task_id}: {e}")
            return False

    def submit_answer(self, task_id: str, answer: str, reasoning_trace: str = "") -> Dict[str, Any]:
        """Submit one answer to the API.

        Args:
            task_id: Identifier of the task being answered.
            answer: The answer, sent as ``model_answer``.
            reasoning_trace: Optional explanation sent alongside the answer.

        Returns:
            The decoded JSON body of ``POST /submit``, or
            ``{"error": <message>}`` if the request fails.
        """
        submission = {
            "task_id": task_id,
            "model_answer": answer,
            "reasoning_trace": reasoning_trace,
        }

        try:
            response = requests.post(
                f"{self.api_base_url}/submit",
                headers=self.headers,
                json=submission,
                timeout=self.timeout,
            )
            response.raise_for_status()
            return response.json()
        except Exception as e:
            print(f"Error submitting answer for task {task_id}: {e}")
            return {"error": str(e)}

    def process_single_question(self, question_data: Dict[str, Any]) -> Dict[str, Any]:
        """Download a task's files (if any) and run the agent on its question.

        Args:
            question_data: Task record; the keys ``task_id``, ``question``
                and ``files`` (a list of dicts with a ``filename`` key) are
                read, all of them optional.

        Returns:
            A dict with ``task_id``, ``question``, ``answer``,
            ``reasoning_trace`` and ``status`` (``"success"`` or
            ``"error"``). On error ``answer`` is empty and
            ``reasoning_trace`` holds the error message.
        """
        task_id = question_data.get('task_id')
        # ``or ''`` also covers an explicit null question in the payload.
        question = question_data.get('question') or ''

        print(f"Processing task {task_id}: {question[:100]}...")

        # Download associated files; a missing or null "files" entry means none.
        for file_info in question_data.get('files') or []:
            file_name = file_info.get('filename') if isinstance(file_info, dict) else None
            if not file_name:
                continue

            # The name comes from the remote API, so keep only its final
            # component to avoid writing outside the working directory.
            safe_name = os.path.basename(str(file_name).replace('\\', '/'))
            if safe_name in ('', '.', '..'):
                print(f"Skipping unsafe file name: {file_name}")
                continue

            if self.download_file(task_id, safe_name):
                print(f"Downloaded file: {safe_name}")
            else:
                print(f"Failed to download file: {safe_name}")

        # Run the agent; any failure is captured in the result rather than
        # aborting a whole evaluation run.
        try:
            answer, reasoning_trace = self.agent(question)
        except Exception as e:
            error_msg = f"Error processing question: {str(e)}"
            print(error_msg)

            return {
                "task_id": task_id,
                "question": question,
                "answer": "",
                "reasoning_trace": error_msg,
                "status": "error",
            }

        print(f"Answer: {answer}")
        return {
            "task_id": task_id,
            "question": question,
            "answer": answer,
            "reasoning_trace": reasoning_trace,
            "status": "success",
        }

    def run_evaluation(self, submit_answers: bool = False) -> List[Dict[str, Any]]:
        """Process every question returned by the API.

        Args:
            submit_answers: When true, each successfully answered task is
                also submitted and the API response is stored under
                ``submission_result`` in that task's result dict.

        Returns:
            One result dict per question (see ``process_single_question``),
            or an empty list when no questions could be retrieved.
        """
        questions = self.get_questions()
        if not questions:
            print("No questions retrieved. Exiting.")
            return []

        total = len(questions)
        print(f"Retrieved {total} questions")
        results = []

        for i, question_data in enumerate(questions, 1):
            print(f"\n--- Question {i}/{total} ---")

            result = self.process_single_question(question_data)
            results.append(result)

            # Submit the answer if submission is enabled.
            if submit_answers and result['status'] == 'success':
                submission_result = self.submit_answer(
                    result['task_id'],
                    result['answer'],
                    result['reasoning_trace'],
                )
                result['submission_result'] = submission_result
                print(f"Submission result: {submission_result}")

        return results

    def save_results(self, results: List[Dict[str, Any]], filename: str = "gaia_results.json"):
        """Write all results to ``filename`` as one indented UTF-8 JSON array."""
        with open(filename, 'w', encoding='utf-8') as f:
            json.dump(results, f, indent=2, ensure_ascii=False)

        print(f"Results saved to {filename}")

    def save_submission_format(self, results: List[Dict[str, Any]], filename: str = "gaia_submission.jsonl"):
        """Write successful results to ``filename`` in JSON Lines form.

        Each line holds ``task_id``, ``model_answer`` and
        ``reasoning_trace``; results whose status is not ``"success"`` are
        left out.
        """
        with open(filename, 'w', encoding='utf-8') as f:
            for result in results:
                if result['status'] != 'success':
                    continue
                submission_entry = {
                    "task_id": result['task_id'],
                    "model_answer": result['answer'],
                    "reasoning_trace": result['reasoning_trace'],
                }
                f.write(json.dumps(submission_entry, ensure_ascii=False) + '\n')

        print(f"Submission file saved to {filename}")