Update gaia_agent.py
gaia_agent.py +98 -12
gaia_agent.py
CHANGED
@@ -1,5 +1,5 @@
 """
-Enhanced GAIA Agent with
+Enhanced GAIA Agent with Strict Output Formatting for Hugging Face Course
 """

 import os
@@ -15,7 +15,7 @@ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
 class EnhancedGAIAAgent:
     """
     An enhanced agent designed to pass the GAIA evaluation by combining rule-based precision
-    with LLM-powered flexibility
+    with LLM-powered flexibility and strict output formatting.
     """

     def __init__(self, model_name="google/flan-t5-large", device=None):
@@ -64,21 +64,85 @@ class EnhancedGAIAAgent:
         self.tokenizer = None
         self.model = None

-    def __call__(self, question: str) -> str:
-        """
+    def __call__(self, question: str, task_id: str = None) -> str:
+        """
+        Process a question and return a formatted answer according to GAIA benchmark requirements.
+
+        Args:
+            question: The question to answer
+            task_id: Optional task ID for the GAIA benchmark
+
+        Returns:
+            JSON string with the required GAIA format
+        """
         print(f"Processing question: {question}")

         # Determine question type
         question_type = self._classify_question(question)
         print(f"Classified as: {question_type}")

-        #
-
+        # Generate reasoning trace if appropriate
+        reasoning_trace = self._generate_reasoning_trace(question, question_type)
+
+        # Use the appropriate handler to get the answer
+        model_answer = self.handlers[question_type](question)

         # Ensure answer is concise and specific
-
+        model_answer = self._ensure_concise_answer(model_answer, question_type)

-
+        # Format the response according to GAIA requirements
+        response = {
+            "task_id": task_id if task_id else "unknown_task",
+            "model_answer": model_answer,
+            "reasoning_trace": reasoning_trace
+        }
+
+        # Return the formatted JSON response
+        return json.dumps(response, ensure_ascii=False)
+
+    def _generate_reasoning_trace(self, question: str, question_type: str) -> str:
+        """Generate a reasoning trace for the question if appropriate."""
+        # For calculation and reasoning questions, provide a trace
+        if question_type == 'calculation':
+            # Extract numbers and operation from the question
+            numbers = re.findall(r'\d+', question)
+
+            if len(numbers) >= 2:
+                if re.search(r'(sum|add|plus|\+)', question.lower()):
+                    return f"To find the sum, I add the numbers: {' + '.join(numbers)} = {sum(int(num) for num in numbers)}"
+                elif re.search(r'(difference|subtract|minus|\-)', question.lower()) and len(numbers) >= 2:
+                    return f"To find the difference, I subtract: {numbers[0]} - {numbers[1]} = {int(numbers[0]) - int(numbers[1])}"
+                elif re.search(r'(product|multiply|times|\*)', question.lower()) and len(numbers) >= 2:
+                    return f"To find the product, I multiply: {numbers[0]} × {numbers[1]} = {int(numbers[0]) * int(numbers[1])}"
+                elif re.search(r'(divide|division|\/)', question.lower()) and len(numbers) >= 2:
+                    if int(numbers[1]) != 0:
+                        return f"To find the quotient, I divide: {numbers[0]} ÷ {numbers[1]} = {int(numbers[0]) / int(numbers[1])}"
+
+            # If we can't generate a specific trace, use a generic one
+            return "I need to identify the numbers and operations in the question, then perform the calculation step by step."
+
+        elif question_type in ['factual', 'general'] and self.llm_available:
+            # For factual and general questions, use LLM to generate a trace
+            try:
+                prompt = f"Explain your reasoning for answering this question: {question}"
+                inputs = self.tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True).to(self.device)
+                outputs = self.model.generate(
+                    inputs["input_ids"],
+                    max_length=150,
+                    min_length=20,
+                    temperature=0.3,
+                    top_p=0.95,
+                    do_sample=True,
+                    num_return_sequences=1
+                )
+
+                trace = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+                return trace[:200]  # Limit trace length
+            except:
+                pass
+
+        # For other question types or if LLM fails, provide a minimal trace
+        return ""

     def _classify_question(self, question: str) -> str:
         """Determine the type of question for specialized handling."""
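With this hunk, __call__ now returns a JSON string in the GAIA response format (task_id, model_answer, reasoning_trace) instead of a bare answer, so callers need a json.loads step. A minimal usage sketch, assuming the file is importable as gaia_agent and that _classify_question routes this sample question to the calculation handler (the question and values shown are illustrative, not taken from the GAIA set):

import json
from gaia_agent import EnhancedGAIAAgent

agent = EnhancedGAIAAgent()
raw = agent("What is 12 plus 30?", task_id="task_001")

# The agent returns a JSON string, so decode it before using the answer
payload = json.loads(raw)
# Roughly: {"task_id": "task_001", "model_answer": "42",
#           "reasoning_trace": "To find the sum, I add the numbers: 12 + 30 = 42"}
print(payload["model_answer"])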
@@ -503,15 +567,25 @@ class EvaluationRunner:
                 continue

             try:
-
+                # Call agent with task_id to ensure proper formatting
+                json_response = agent(question_text, task_id)
+
+                # Parse the JSON response
+                response_obj = json.loads(json_response)
+
+                # Extract the model_answer for submission
+                submitted_answer = response_obj.get("model_answer", "")
+
                 answers_payload.append({
                     "task_id": task_id,
                     "submitted_answer": submitted_answer
                 })
+
                 results_log.append({
                     "Task ID": task_id,
                     "Question": question_text,
-                    "Submitted Answer": submitted_answer
+                    "Submitted Answer": submitted_answer,
+                    "Full Response": json_response
                 })
             except Exception as e:
                 print(f"Error running agent on task {task_id}: {e}")
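Note that the submission payload still carries only the bare model_answer; the full JSON response is kept in results_log only. If the agent ever returns something that is not valid JSON, json.loads raises and the whole task falls into the except branch and is skipped. A defensive variant, not part of this commit, could fall back to the raw string:

try:
    response_obj = json.loads(json_response)
    submitted_answer = response_obj.get("model_answer", "")
except json.JSONDecodeError:
    # Hypothetical fallback: treat the raw response as the submitted answer
    submitted_answer = json_response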
@@ -598,9 +672,21 @@ def test_agent():

     print("\n=== AGENT TEST RESULTS ===")
     for question in test_questions:
-
+        # Generate a mock task_id for testing
+        task_id = f"test_{hash(question) % 10000}"
+
+        # Get formatted JSON response
+        json_response = agent(question, task_id)
+
         print(f"\nQ: {question}")
-        print(f"
+        print(f"Response: {json_response}")
+
+        # Parse and print the model_answer for clarity
+        try:
+            response_obj = json.loads(json_response)
+            print(f"Model Answer: {response_obj.get('model_answer', '')}")
+        except:
+            print("Error parsing JSON response")

     return "Test completed successfully"
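One caveat on the mock IDs in test_agent: Python randomizes str hashes per interpreter process (PYTHONHASHSEED), so test_{hash(question) % 10000} changes between runs. If stable IDs matter for comparing test logs across runs, a deterministic variant (not in this commit) could use hashlib:

import hashlib

# Deterministic mock task_id: the same question always maps to the same id
task_id = f"test_{int(hashlib.md5(question.encode()).hexdigest(), 16) % 10000}"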