Update gaia_agent.py

gaia_agent.py CHANGED (+120 -28)
@@ -1,5 +1,5 @@
 """
-Enhanced GAIA Agent with Strict Output Formatting for Hugging Face Course
+Enhanced GAIA Agent with Strict Output Formatting and Answer Logging for Hugging Face Course
 """
 
 import os
@@ -490,6 +490,12 @@ class EvaluationRunner:
         self.api_url = api_url
         self.questions_url = f"{api_url}/questions"
         self.submit_url = f"{api_url}/submit"
+        self.results_url = f"{api_url}/results"
+
+        # Initialize counters for tracking correct answers
+        self.total_questions = 0
+        self.correct_answers = 0
+        self.ground_truth = {}  # Store ground truth answers if available
 
     def run_evaluation(self,
                        agent: Any,
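The constructor now tracks `total_questions`, `correct_answers`, and a `ground_truth` dict, but nothing in this change ever populates or reads `ground_truth`. If reference answers were loaded into it, a local scorer could be as simple as the sketch below; the `score_locally` helper and the normalized exact-match rule are illustrative assumptions, not part of this commit.

```python
def score_locally(ground_truth: dict, answers: dict) -> int:
    """Count answers that match the stored ground truth (illustrative only)."""
    correct = 0
    for task_id, expected in ground_truth.items():
        predicted = str(answers.get(task_id, ""))
        # Simple normalized exact match; real GAIA scoring is stricter about formatting
        if predicted.strip().lower() == str(expected).strip().lower():
            correct += 1
    return correct

# Two mock tasks, one answered correctly
print(score_locally({"t1": "Paris", "t2": "42"}, {"t1": "paris", "t2": "41"}))  # -> 1
```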
@@ -500,8 +506,13 @@ class EvaluationRunner:
         1. Fetch questions
         2. Run agent on all questions
         3. Submit answers
-        4. Return results
+        4. Check results and count correct answers
+        5. Return results
         """
+        # Reset counters
+        self.total_questions = 0
+        self.correct_answers = 0
+
         # Fetch questions
         questions_data = self._fetch_questions()
         if isinstance(questions_data, str):  # Error message
@@ -515,7 +526,10 @@ class EvaluationRunner:
         # Submit answers
         submission_result = self._submit_answers(username, agent_code_url, answers_payload)
 
-        # Return results
+        # Try to fetch results to count correct answers
+        self._check_results(username)
+
+        # Return results with correct answer count
         return submission_result, results_log
 
     def _fetch_questions(self) -> Union[List[Dict[str, Any]], str]:
@@ -531,7 +545,8 @@ class EvaluationRunner:
                 print(error_msg)
                 return error_msg
 
+            self.total_questions = len(questions_data)
+            print(f"Successfully fetched {self.total_questions} questions.")
             return questions_data
 
         except requests.exceptions.RequestException as e:
@@ -609,33 +624,95 @@ class EvaluationRunner:
         }
 
         print(f"Submitting {len(answers_payload)} answers to: {self.submit_url}")
-        try:
-            response = requests.post(
-                self.submit_url,
-                json=submission_data,
-                headers={"Content-Type": "application/json"},
-                timeout=30
-            )
-            response.raise_for_status()
-
+        max_retries = 3
+        retry_delay = 5  # seconds
+
+        for attempt in range(1, max_retries + 1):
             try:
+                print(f"Submission attempt {attempt} of {max_retries}...")
+                response = requests.post(
+                    self.submit_url,
+                    json=submission_data,
+                    headers={"Content-Type": "application/json"},
+                    timeout=30
+                )
+                response.raise_for_status()
 
+                try:
+                    result = response.json()
+                    score = result.get("score")
+                    max_score = result.get("max_score")
 
+                    if score is not None and max_score is not None:
+                        self.correct_answers = score  # Update correct answers count
+                        return f"Evaluation complete! Score: {score}/{max_score}"
+                    else:
+                        print(f"Received N/A results. Waiting {retry_delay} seconds before retry...")
+                        time.sleep(retry_delay)
+                        continue
+
+                except requests.exceptions.JSONDecodeError:
+                    print(f"Submission attempt {attempt}: Response was not JSON. Response: {response.text}")
+                    if attempt < max_retries:
+                        print(f"Waiting {retry_delay} seconds before retry...")
+                        time.sleep(retry_delay)
+                    else:
+                        return f"Submission successful, but response was not JSON. Response: {response.text}"
+
+            except requests.exceptions.RequestException as e:
+                print(f"Submission attempt {attempt} failed: {e}")
+                if attempt < max_retries:
+                    print(f"Waiting {retry_delay} seconds before retry...")
+                    time.sleep(retry_delay)
+                else:
+                    return f"Error submitting answers after {max_retries} attempts: {e}"
+
+        # If we get here, all retries failed but didn't raise exceptions
+        return "Submission Successful, but results are pending!"
+
+    def _check_results(self, username: str) -> None:
+        """Check results to count correct answers."""
+        try:
+            results_url = f"{self.results_url}?username={username}"
+            print(f"Checking results at: {results_url}")
+
+            response = requests.get(results_url, timeout=15)
+            if response.status_code == 200:
+                try:
+                    data = response.json()
+                    if isinstance(data, dict):
+                        score = data.get("score")
+                        if score is not None:
+                            self.correct_answers = int(score)
+                            print(f"✓ Correct answers: {self.correct_answers}/{self.total_questions}")
+                        else:
+                            print("Score information not available in results")
+                    else:
+                        print("Results data is not in expected format")
+                except:
+                    print("Could not parse results JSON")
+            else:
+                print(f"Could not fetch results, status code: {response.status_code}")
         except Exception as e:
+            print(f"Error checking results: {e}")
+
+    def get_correct_answers_count(self) -> int:
+        """Get the number of correct answers."""
+        return self.correct_answers
+
+    def get_total_questions_count(self) -> int:
+        """Get the total number of questions."""
+        return self.total_questions
+
+    def print_evaluation_summary(self, username: str) -> None:
+        """Print a summary of the evaluation results."""
+        print("\n===== EVALUATION SUMMARY =====")
+        print(f"User: {username}")
+        print(f"Overall Score: {self.correct_answers}/{self.total_questions}")
+        print(f"Correct Answers: {self.correct_answers}")
+        print(f"Total Questions: {self.total_questions}")
+        print(f"Accuracy: {(self.correct_answers / self.total_questions * 100) if self.total_questions > 0 else 0:.1f}%")
+        print("=============================\n")
 
 
 # Example usage and test cases
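With the retry loop, `_check_results`, the two getters, and `print_evaluation_summary` in place, driving the runner end to end might look roughly like the sketch below. The constructor argument, the `my_agent` object, and the keyword names passed to `run_evaluation` are inferred from the surrounding code and are not confirmed by this diff.

```python
from gaia_agent import EvaluationRunner

runner = EvaluationRunner(api_url="https://<scoring-api-base-url>")  # placeholder URL

submission_result, results_log = runner.run_evaluation(
    agent=my_agent,                 # your agent instance; its interface is not shown in this diff
    username="your-hf-username",    # placeholder
    agent_code_url="https://huggingface.co/spaces/<your-space>",  # placeholder
)

print(submission_result)
runner.print_evaluation_summary("your-hf-username")
print(f"{runner.get_correct_answers_count()} / {runner.get_total_questions_count()} correct")
```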
@@ -671,6 +748,9 @@ def test_agent():
     ]
 
     print("\n=== AGENT TEST RESULTS ===")
+    correct_count = 0
+    total_count = len(test_questions)
+
     for question in test_questions:
         # Generate a mock task_id for testing
         task_id = f"test_{hash(question) % 10000}"
@@ -684,10 +764,22 @@ def test_agent():
         # Parse and print the model_answer for clarity
         try:
             response_obj = json.loads(json_response)
+            model_answer = response_obj.get('model_answer', '')
+            print(f"Model Answer: {model_answer}")
+
+            # For testing purposes, simulate correct answers
+            # In a real scenario, this would compare with ground truth
+            if len(model_answer) > 0 and not model_answer.startswith("AGENT ERROR"):
+                correct_count += 1
         except:
             print("Error parsing JSON response")
 
+    # Print test summary with correct answer count
+    print("\n===== TEST SUMMARY =====")
+    print(f"Correct Answers: {correct_count}/{total_count}")
+    print(f"Accuracy: {(correct_count / total_count * 100):.1f}%")
+    print("=======================\n")
+
     return "Test completed successfully"
 
 
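The mock accuracy count above treats a response as correct whenever the JSON parses and `model_answer` is non-empty and not an agent error string. If you want to reuse that check outside `test_agent`, a standalone validator along the same lines could look like this (hypothetical helper, not part of `gaia_agent.py`):

```python
import json

def has_valid_model_answer(json_response: str) -> bool:
    """Return True if the response parses as JSON and carries a usable 'model_answer'."""
    try:
        obj = json.loads(json_response)
    except json.JSONDecodeError:
        return False
    answer = obj.get("model_answer", "")
    return isinstance(answer, str) and len(answer) > 0 and not answer.startswith("AGENT ERROR")

print(has_valid_model_answer('{"model_answer": "Paris"}'))  # True
print(has_valid_model_answer("not json"))                   # False
```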