yoshizen committed on
Commit 98c40a0 · verified · 1 Parent(s): aade89a

Update gaia_agent.py

Files changed (1)
  1. gaia_agent.py +98 -12
gaia_agent.py CHANGED
@@ -1,5 +1,5 @@
 """
-Enhanced GAIA Agent with Hybrid Rule-LLM Architecture for Hugging Face Course
+Enhanced GAIA Agent with Strict Output Formatting for Hugging Face Course
 """

 import os
@@ -15,7 +15,7 @@ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
 class EnhancedGAIAAgent:
     """
     An enhanced agent designed to pass the GAIA evaluation by combining rule-based precision
-    with LLM-powered flexibility for general knowledge and reasoning.
+    with LLM-powered flexibility and strict output formatting.
     """

     def __init__(self, model_name="google/flan-t5-large", device=None):
@@ -64,21 +64,85 @@ class EnhancedGAIAAgent:
         self.tokenizer = None
         self.model = None

-    def __call__(self, question: str) -> str:
-        """Process a question and return a specific, concise answer."""
+    def __call__(self, question: str, task_id: str = None) -> str:
+        """
+        Process a question and return a formatted answer according to GAIA benchmark requirements.
+
+        Args:
+            question: The question to answer
+            task_id: Optional task ID for the GAIA benchmark
+
+        Returns:
+            JSON string with the required GAIA format
+        """
         print(f"Processing question: {question}")

         # Determine question type
         question_type = self._classify_question(question)
         print(f"Classified as: {question_type}")

-        # Use the appropriate handler
-        answer = self.handlers[question_type](question)
+        # Generate reasoning trace if appropriate
+        reasoning_trace = self._generate_reasoning_trace(question, question_type)
+
+        # Use the appropriate handler to get the answer
+        model_answer = self.handlers[question_type](question)

         # Ensure answer is concise and specific
-        answer = self._ensure_concise_answer(answer, question_type)
+        model_answer = self._ensure_concise_answer(model_answer, question_type)

-        return answer
+        # Format the response according to GAIA requirements
+        response = {
+            "task_id": task_id if task_id else "unknown_task",
+            "model_answer": model_answer,
+            "reasoning_trace": reasoning_trace
+        }
+
+        # Return the formatted JSON response
+        return json.dumps(response, ensure_ascii=False)
+
+    def _generate_reasoning_trace(self, question: str, question_type: str) -> str:
+        """Generate a reasoning trace for the question if appropriate."""
+        # For calculation and reasoning questions, provide a trace
+        if question_type == 'calculation':
+            # Extract numbers and operation from the question
+            numbers = re.findall(r'\d+', question)
+
+            if len(numbers) >= 2:
+                if re.search(r'(sum|add|plus|\+)', question.lower()):
+                    return f"To find the sum, I add the numbers: {' + '.join(numbers)} = {sum(int(num) for num in numbers)}"
+                elif re.search(r'(difference|subtract|minus|\-)', question.lower()) and len(numbers) >= 2:
+                    return f"To find the difference, I subtract: {numbers[0]} - {numbers[1]} = {int(numbers[0]) - int(numbers[1])}"
+                elif re.search(r'(product|multiply|times|\*)', question.lower()) and len(numbers) >= 2:
+                    return f"To find the product, I multiply: {numbers[0]} × {numbers[1]} = {int(numbers[0]) * int(numbers[1])}"
+                elif re.search(r'(divide|division|\/)', question.lower()) and len(numbers) >= 2:
+                    if int(numbers[1]) != 0:
+                        return f"To find the quotient, I divide: {numbers[0]} ÷ {numbers[1]} = {int(numbers[0]) / int(numbers[1])}"
+
+            # If we can't generate a specific trace, use a generic one
+            return "I need to identify the numbers and operations in the question, then perform the calculation step by step."
+
+        elif question_type in ['factual', 'general'] and self.llm_available:
+            # For factual and general questions, use LLM to generate a trace
+            try:
+                prompt = f"Explain your reasoning for answering this question: {question}"
+                inputs = self.tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True).to(self.device)
+                outputs = self.model.generate(
+                    inputs["input_ids"],
+                    max_length=150,
+                    min_length=20,
+                    temperature=0.3,
+                    top_p=0.95,
+                    do_sample=True,
+                    num_return_sequences=1
+                )
+
+                trace = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+                return trace[:200] # Limit trace length
+            except:
+                pass
+
+        # For other question types or if LLM fails, provide a minimal trace
+        return ""

     def _classify_question(self, question: str) -> str:
         """Determine the type of question for specialized handling."""
@@ -503,15 +567,25 @@ class EvaluationRunner:
                 continue

             try:
-                submitted_answer = agent(question_text)
+                # Call agent with task_id to ensure proper formatting
+                json_response = agent(question_text, task_id)
+
+                # Parse the JSON response
+                response_obj = json.loads(json_response)
+
+                # Extract the model_answer for submission
+                submitted_answer = response_obj.get("model_answer", "")
+
                 answers_payload.append({
                     "task_id": task_id,
                     "submitted_answer": submitted_answer
                 })
+
                 results_log.append({
                     "Task ID": task_id,
                     "Question": question_text,
-                    "Submitted Answer": submitted_answer
+                    "Submitted Answer": submitted_answer,
+                    "Full Response": json_response
                 })
             except Exception as e:
                 print(f"Error running agent on task {task_id}: {e}")
@@ -598,9 +672,21 @@ def test_agent():

     print("\n=== AGENT TEST RESULTS ===")
     for question in test_questions:
-        answer = agent(question)
+        # Generate a mock task_id for testing
+        task_id = f"test_{hash(question) % 10000}"
+
+        # Get formatted JSON response
+        json_response = agent(question, task_id)
+
         print(f"\nQ: {question}")
-        print(f"A: {answer}")
+        print(f"Response: {json_response}")
+
+        # Parse and print the model_answer for clarity
+        try:
+            response_obj = json.loads(json_response)
+            print(f"Model Answer: {response_obj.get('model_answer', '')}")
+        except:
+            print("Error parsing JSON response")

     return "Test completed successfully"
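
With this change, EnhancedGAIAAgent.__call__ returns a JSON string rather than a bare answer, so any caller outside EvaluationRunner must parse it. Below is a minimal consumer sketch, assuming gaia_agent.py is importable and that the module already imports json and re near its top (those imports fall outside the hunks shown); the sample question and the printed values are illustrative only.

    import json

    from gaia_agent import EnhancedGAIAAgent

    agent = EnhancedGAIAAgent()  # defaults to google/flan-t5-large

    # __call__ now returns a JSON string carrying the three GAIA fields.
    raw = agent("What is 15 plus 27?", task_id="demo_0001")
    response = json.loads(raw)

    print(response["task_id"])          # "demo_0001"
    print(response["model_answer"])     # the concise answer, e.g. "42"
    print(response["reasoning_trace"])  # e.g. "To find the sum, I add the numbers: 15 + 27 = 42"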