yoshizen committed on
Commit
985047d
·
verified ·
1 Parent(s): 98c40a0

Update gaia_agent.py

Browse files
Files changed (1) hide show
  1. gaia_agent.py +120 -28
gaia_agent.py CHANGED
@@ -1,5 +1,5 @@
1
  """
2
- Enhanced GAIA Agent with Strict Output Formatting for Hugging Face Course
3
  """
4
 
5
  import os
@@ -490,6 +490,12 @@ class EvaluationRunner:
490
  self.api_url = api_url
491
  self.questions_url = f"{api_url}/questions"
492
  self.submit_url = f"{api_url}/submit"
 
 
 
 
 
 
493
 
494
  def run_evaluation(self,
495
  agent: Any,
@@ -500,8 +506,13 @@ class EvaluationRunner:
500
  1. Fetch questions
501
  2. Run agent on all questions
502
  3. Submit answers
503
- 4. Return results
 
504
  """
 
 
 
 
505
  # Fetch questions
506
  questions_data = self._fetch_questions()
507
  if isinstance(questions_data, str): # Error message
@@ -515,7 +526,10 @@ class EvaluationRunner:
515
  # Submit answers
516
  submission_result = self._submit_answers(username, agent_code_url, answers_payload)
517
 
518
- # Return results
 
 
 
519
  return submission_result, results_log
520
 
521
  def _fetch_questions(self) -> Union[List[Dict[str, Any]], str]:
@@ -531,7 +545,8 @@ class EvaluationRunner:
531
  print(error_msg)
532
  return error_msg
533
 
534
- print(f"Successfully fetched {len(questions_data)} questions.")
 
535
  return questions_data
536
 
537
  except requests.exceptions.RequestException as e:
@@ -609,33 +624,95 @@ class EvaluationRunner:
609
  }
610
 
611
  print(f"Submitting {len(answers_payload)} answers to: {self.submit_url}")
612
- try:
613
- response = requests.post(
614
- self.submit_url,
615
- json=submission_data,
616
- headers={"Content-Type": "application/json"},
617
- timeout=30
618
- )
619
- response.raise_for_status()
620
-
621
  try:
622
- result = response.json()
623
- score = result.get("score")
624
- max_score = result.get("max_score")
 
 
 
 
 
625
 
626
- if score is not None and max_score is not None:
627
- return f"Evaluation complete! Score: {score}/{max_score}"
628
- else:
629
- return f"Submission successful, but score not returned. Response: {response.text}"
630
 
631
- except requests.exceptions.JSONDecodeError:
632
- return f"Submission successful, but response was not JSON. Response: {response.text}"
633
-
634
- except requests.exceptions.RequestException as e:
635
- return f"Error submitting answers: {e}"
636
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
637
  except Exception as e:
638
- return f"An unexpected error occurred during submission: {e}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
639
 
640
 
641
  # Example usage and test cases
@@ -671,6 +748,9 @@ def test_agent():
671
  ]
672
 
673
  print("\n=== AGENT TEST RESULTS ===")
 
 
 
674
  for question in test_questions:
675
  # Generate a mock task_id for testing
676
  task_id = f"test_{hash(question) % 10000}"
@@ -684,10 +764,22 @@ def test_agent():
684
  # Parse and print the model_answer for clarity
685
  try:
686
  response_obj = json.loads(json_response)
687
- print(f"Model Answer: {response_obj.get('model_answer', '')}")
 
 
 
 
 
 
688
  except:
689
  print("Error parsing JSON response")
690
 
 
 
 
 
 
 
691
  return "Test completed successfully"
692
 
693
 
 
1
  """
2
+ Enhanced GAIA Agent with Strict Output Formatting and Answer Logging for Hugging Face Course
3
  """
4
 
5
  import os
 
490
  self.api_url = api_url
491
  self.questions_url = f"{api_url}/questions"
492
  self.submit_url = f"{api_url}/submit"
493
+ self.results_url = f"{api_url}/results"
494
+
495
+ # Initialize counters for tracking correct answers
496
+ self.total_questions = 0
497
+ self.correct_answers = 0
498
+ self.ground_truth = {} # Store ground truth answers if available
499
 
500
  def run_evaluation(self,
501
  agent: Any,
 
506
  1. Fetch questions
507
  2. Run agent on all questions
508
  3. Submit answers
509
+ 4. Check results and count correct answers
510
+ 5. Return results
511
  """
512
+ # Reset counters
513
+ self.total_questions = 0
514
+ self.correct_answers = 0
515
+
516
  # Fetch questions
517
  questions_data = self._fetch_questions()
518
  if isinstance(questions_data, str): # Error message
 
526
  # Submit answers
527
  submission_result = self._submit_answers(username, agent_code_url, answers_payload)
528
 
529
+ # Try to fetch results to count correct answers
530
+ self._check_results(username)
531
+
532
+ # Return results with correct answer count
533
  return submission_result, results_log
534
 
535
  def _fetch_questions(self) -> Union[List[Dict[str, Any]], str]:
 
545
  print(error_msg)
546
  return error_msg
547
 
548
+ self.total_questions = len(questions_data)
549
+ print(f"Successfully fetched {self.total_questions} questions.")
550
  return questions_data
551
 
552
  except requests.exceptions.RequestException as e:
 
624
  }
625
 
626
  print(f"Submitting {len(answers_payload)} answers to: {self.submit_url}")
627
+ max_retries = 3
628
+ retry_delay = 5 # seconds
629
+
630
+ for attempt in range(1, max_retries + 1):
 
 
 
 
 
631
  try:
632
+ print(f"Submission attempt {attempt} of {max_retries}...")
633
+ response = requests.post(
634
+ self.submit_url,
635
+ json=submission_data,
636
+ headers={"Content-Type": "application/json"},
637
+ timeout=30
638
+ )
639
+ response.raise_for_status()
640
 
641
+ try:
642
+ result = response.json()
643
+ score = result.get("score")
644
+ max_score = result.get("max_score")
645
 
646
+ if score is not None and max_score is not None:
647
+ self.correct_answers = score # Update correct answers count
648
+ return f"Evaluation complete! Score: {score}/{max_score}"
649
+ else:
650
+ print(f"Received N/A results. Waiting {retry_delay} seconds before retry...")
651
+ time.sleep(retry_delay)
652
+ continue
653
+
654
+ except requests.exceptions.JSONDecodeError:
655
+ print(f"Submission attempt {attempt}: Response was not JSON. Response: {response.text}")
656
+ if attempt < max_retries:
657
+ print(f"Waiting {retry_delay} seconds before retry...")
658
+ time.sleep(retry_delay)
659
+ else:
660
+ return f"Submission successful, but response was not JSON. Response: {response.text}"
661
+
662
+ except requests.exceptions.RequestException as e:
663
+ print(f"Submission attempt {attempt} failed: {e}")
664
+ if attempt < max_retries:
665
+ print(f"Waiting {retry_delay} seconds before retry...")
666
+ time.sleep(retry_delay)
667
+ else:
668
+ return f"Error submitting answers after {max_retries} attempts: {e}"
669
+
670
+ # If we get here, all retries failed but didn't raise exceptions
671
+ return "Submission Successful, but results are pending!"
672
+
673
def _check_results(self, username: str) -> None:
    """Fetch the user's results from the scoring API and update `self.correct_answers`.

    Best-effort: any failure (network error, non-200 status, malformed
    payload) is reported on stdout and swallowed so it never aborts an
    evaluation run.

    Args:
        username: The username whose results should be looked up.
    """
    try:
        results_url = f"{self.results_url}?username={username}"
        print(f"Checking results at: {results_url}")

        response = requests.get(results_url, timeout=15)
        if response.status_code == 200:
            try:
                data = response.json()
                if isinstance(data, dict):
                    score = data.get("score")
                    if score is not None:
                        self.correct_answers = int(score)
                        print(f"✓ Correct answers: {self.correct_answers}/{self.total_questions}")
                    else:
                        print("Score information not available in results")
                else:
                    print("Results data is not in expected format")
            # Narrowed from a bare `except:`: response.json() raises a
            # JSONDecodeError (a ValueError subclass) on non-JSON bodies,
            # and int(score) raises ValueError/TypeError on a non-numeric
            # score. A bare except would also hide KeyboardInterrupt etc.
            except (ValueError, TypeError):
                print("Could not parse results JSON")
        else:
            print(f"Could not fetch results, status code: {response.status_code}")
    except Exception as e:
        # Outer boundary: results checking is informational only, so log
        # and continue rather than failing the whole evaluation.
        print(f"Error checking results: {e}")
698
+
699
def get_correct_answers_count(self) -> int:
    """Return how many answers were scored as correct in the last run."""
    count = self.correct_answers
    return count
702
+
703
def get_total_questions_count(self) -> int:
    """Return how many questions were fetched for the last run."""
    total = self.total_questions
    return total
706
+
707
def print_evaluation_summary(self, username: str) -> None:
    """Print a formatted summary of the most recent evaluation run.

    Args:
        username: Name of the user the summary belongs to.
    """
    # Guard against division by zero when no questions were fetched.
    if self.total_questions > 0:
        accuracy = self.correct_answers / self.total_questions * 100
    else:
        accuracy = 0
    summary_lines = [
        "\n===== EVALUATION SUMMARY =====",
        f"User: {username}",
        f"Overall Score: {self.correct_answers}/{self.total_questions}",
        f"Correct Answers: {self.correct_answers}",
        f"Total Questions: {self.total_questions}",
        f"Accuracy: {accuracy:.1f}%",
        "=============================\n",
    ]
    # One joined print emits exactly the same byte sequence as the
    # original per-line prints.
    print("\n".join(summary_lines))
716
 
717
 
718
  # Example usage and test cases
 
748
  ]
749
 
750
  print("\n=== AGENT TEST RESULTS ===")
751
+ correct_count = 0
752
+ total_count = len(test_questions)
753
+
754
  for question in test_questions:
755
  # Generate a mock task_id for testing
756
  task_id = f"test_{hash(question) % 10000}"
 
764
  # Parse and print the model_answer for clarity
765
  try:
766
  response_obj = json.loads(json_response)
767
+ model_answer = response_obj.get('model_answer', '')
768
+ print(f"Model Answer: {model_answer}")
769
+
770
+ # For testing purposes, simulate correct answers
771
+ # In a real scenario, this would compare with ground truth
772
+ if len(model_answer) > 0 and not model_answer.startswith("AGENT ERROR"):
773
+ correct_count += 1
774
  except:
775
  print("Error parsing JSON response")
776
 
777
+ # Print test summary with correct answer count
778
+ print("\n===== TEST SUMMARY =====")
779
+ print(f"Correct Answers: {correct_count}/{total_count}")
780
+ print(f"Accuracy: {(correct_count / total_count * 100):.1f}%")
781
+ print("=======================\n")
782
+
783
  return "Test completed successfully"
784
 
785