yoshizen committed on
Commit
7999c2e
·
verified ·
1 Parent(s): 90346c1

Update gaia_agent.py

Browse files
Files changed (1) hide show
  1. gaia_agent.py +31 -13
gaia_agent.py CHANGED
@@ -73,7 +73,7 @@ class EnhancedGAIAAgent:
73
  task_id: Optional task ID for the GAIA benchmark
74
 
75
  Returns:
76
- Plain string with the answer (not JSON)
77
  """
78
  print(f"Processing question: {question}")
79
 
@@ -87,8 +87,12 @@ class EnhancedGAIAAgent:
87
  # Ensure answer is concise and specific
88
  model_answer = self._ensure_concise_answer(model_answer, question_type)
89
 
90
- # FIXED: Return only the plain string answer, not JSON
91
- return model_answer
 
 
 
 
92
 
93
  def _generate_reasoning_trace(self, question: str, question_type: str) -> str:
94
  """Generate a reasoning trace for the question if appropriate."""
@@ -537,10 +541,15 @@ class EvaluationRunner:
537
  continue
538
 
539
  try:
540
- # FIXED: Call agent and get plain string answer
541
- submitted_answer = agent(question_text, task_id)
 
 
 
 
 
 
542
 
543
- # FIXED: No need to parse JSON, just use the answer directly
544
  answers_payload.append({
545
  "task_id": task_id,
546
  "submitted_answer": submitted_answer
@@ -549,7 +558,8 @@ class EvaluationRunner:
549
  results_log.append({
550
  "Task ID": task_id,
551
  "Question": question_text,
552
- "Submitted Answer": submitted_answer
 
553
  })
554
  except Exception as e:
555
  print(f"Error running agent on task {task_id}: {e}")
@@ -704,15 +714,23 @@ def test_agent():
704
  # Generate a mock task_id for testing
705
  task_id = f"test_{hash(question) % 10000}"
706
 
707
- # Get plain string answer
708
- answer = agent(question, task_id)
709
 
710
  print(f"\nQ: {question}")
711
- print(f"A: {answer}")
712
 
713
- # For testing purposes, simulate correct answers
714
- if len(answer) > 0 and not answer.startswith("AGENT ERROR"):
715
- correct_count += 1
 
 
 
 
 
 
 
 
716
 
717
  # Print test summary with correct answer count
718
  print("\n===== TEST SUMMARY =====")
 
73
  task_id: Optional task ID for the GAIA benchmark
74
 
75
  Returns:
76
+ JSON string with final_answer key
77
  """
78
  print(f"Processing question: {question}")
79
 
 
87
  # Ensure answer is concise and specific
88
  model_answer = self._ensure_concise_answer(model_answer, question_type)
89
 
90
+ # FIXED: Return JSON with final_answer key
91
+ response = {
92
+ "final_answer": model_answer
93
+ }
94
+
95
+ return json.dumps(response)
96
 
97
  def _generate_reasoning_trace(self, question: str, question_type: str) -> str:
98
  """Generate a reasoning trace for the question if appropriate."""
 
541
  continue
542
 
543
  try:
544
+ # Call agent with task_id to ensure proper formatting
545
+ json_response = agent(question_text, task_id)
546
+
547
+ # Parse the JSON response
548
+ response_obj = json.loads(json_response)
549
+
550
+ # Extract the final_answer for submission
551
+ submitted_answer = response_obj.get("final_answer", "")
552
 
 
553
  answers_payload.append({
554
  "task_id": task_id,
555
  "submitted_answer": submitted_answer
 
558
  results_log.append({
559
  "Task ID": task_id,
560
  "Question": question_text,
561
+ "Submitted Answer": submitted_answer,
562
+ "Full Response": json_response
563
  })
564
  except Exception as e:
565
  print(f"Error running agent on task {task_id}: {e}")
 
714
  # Generate a mock task_id for testing
715
  task_id = f"test_{hash(question) % 10000}"
716
 
717
+ # Get JSON response with final_answer
718
+ json_response = agent(question, task_id)
719
 
720
  print(f"\nQ: {question}")
721
+ print(f"Response: {json_response}")
722
 
723
+ # Parse and print the final_answer for clarity
724
+ try:
725
+ response_obj = json.loads(json_response)
726
+ final_answer = response_obj.get('final_answer', '')
727
+ print(f"Final Answer: {final_answer}")
728
+
729
+ # For testing purposes, simulate correct answers
730
+ if len(final_answer) > 0 and not final_answer.startswith("AGENT ERROR"):
731
+ correct_count += 1
732
+ except:
733
+ print("Error parsing JSON response")
734
 
735
  # Print test summary with correct answer count
736
  print("\n===== TEST SUMMARY =====")