yoshizen committed
Commit b35115d · verified · 1 Parent(s): 1dbf488

Update gaia_agent.py

Files changed (1)
  1. gaia_agent.py +166 -509
gaia_agent.py CHANGED
@@ -1,498 +1,225 @@
  """
- Enhanced GAIA Agent with Strict Output Formatting and Answer Logging for Hugging Face Course
  """

  import os
- import re
- import math
  import json
- import datetime
- import requests
- from typing import List, Dict, Any, Optional, Union, Tuple, Callable
  import torch
- from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline

  class EnhancedGAIAAgent:
      """
-     An enhanced agent designed to pass the GAIA evaluation by combining rule-based precision
-     with LLM-powered flexibility and strict output formatting.
      """

-     def __init__(self, model_name="google/flan-t5-large", device=None):
-         """Initialize the agent with tools and model."""
-         self.model_name = model_name
-         print(f"EnhancedGAIAAgent initializing with model: {model_name}")
-
-         # Initialize LLM components
-         self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
-         self._initialize_llm()
-
-         # Register specialized handlers
-         self.handlers = {
-             'calculation': self._handle_calculation,
-             'date_time': self._handle_date_time,
-             'list': self._handle_list_question,
-             'visual': self._handle_visual_question,
-             'factual': self._handle_factual_question,
-             'general': self._handle_general_question
-         }

-         # Define prompt templates
-         self.prompt_templates = {
-             'calculation': "Solve this step by step: {question}",
-             'date_time': "Answer this date/time question precisely: {question}",
-             'list': "Provide a comma-separated list for: {question}",
-             'visual': "Describe what is shown in the image related to: {question}",
-             'factual': "Answer this question concisely: {question}",
-             'reasoning': "Let's think step by step: {question}",
-             'general': "Provide a specific, concise answer: {question}"
-         }

-         print("EnhancedGAIAAgent initialized successfully")

-     def _initialize_llm(self):
-         """Initialize the language model for fallback responses."""
          try:
-             print(f"Loading model {self.model_name} on {self.device}")
-             self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
-             self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name).to(self.device)
-             self.llm_available = True
-             print("LLM initialized successfully")
          except Exception as e:
-             print(f"Error initializing LLM: {e}")
-             self.llm_available = False
-             self.tokenizer = None
-             self.model = None

-     def __call__(self, question: str, task_id: str = None) -> str:
          """
-         Process a question and return a formatted answer according to GAIA benchmark requirements.

          Args:
-             question: The question to answer
-             task_id: Optional task ID for the GAIA benchmark

          Returns:
-             JSON string with final_answer key
          """
-         print(f"Processing question: {question}")
-
-         # Determine question type
-         question_type = self._classify_question(question)
-         print(f"Classified as: {question_type}")
-
-         # Use the appropriate handler to get the answer
-         model_answer = self.handlers[question_type](question)
-
-         # Ensure answer is concise and specific
-         model_answer = self._ensure_concise_answer(model_answer, question_type)
-
-         # FIXED: Return JSON with final_answer key
-         response = {
-             "final_answer": model_answer
-         }
-
-         return json.dumps(response)
-
-     def _generate_reasoning_trace(self, question: str, question_type: str) -> str:
-         """Generate a reasoning trace for the question if appropriate."""
-         # For calculation and reasoning questions, provide a trace
-         if question_type == 'calculation':
-             # Extract numbers and operation from the question
-             numbers = re.findall(r'\d+', question)
-
-             if len(numbers) >= 2:
-                 if re.search(r'(sum|add|plus|\+)', question.lower()):
-                     return f"To find the sum, I add the numbers: {' + '.join(numbers)} = {sum(int(num) for num in numbers)}"
-                 elif re.search(r'(difference|subtract|minus|\-)', question.lower()) and len(numbers) >= 2:
-                     return f"To find the difference, I subtract: {numbers[0]} - {numbers[1]} = {int(numbers[0]) - int(numbers[1])}"
-                 elif re.search(r'(product|multiply|times|\*)', question.lower()) and len(numbers) >= 2:
-                     return f"To find the product, I multiply: {numbers[0]} × {numbers[1]} = {int(numbers[0]) * int(numbers[1])}"
-                 elif re.search(r'(divide|division|\/)', question.lower()) and len(numbers) >= 2:
-                     if int(numbers[1]) != 0:
-                         return f"To find the quotient, I divide: {numbers[0]} ÷ {numbers[1]} = {int(numbers[0]) / int(numbers[1])}"
-
-             # If we can't generate a specific trace, use a generic one
-             return "I need to identify the numbers and operations in the question, then perform the calculation step by step."
-
-         elif question_type in ['factual', 'general'] and self.llm_available:
-             # For factual and general questions, use LLM to generate a trace
-             try:
-                 prompt = f"Explain your reasoning for answering this question: {question}"
-                 inputs = self.tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True).to(self.device)
-                 outputs = self.model.generate(
-                     inputs["input_ids"],
-                     max_length=150,
-                     min_length=20,
-                     temperature=0.3,
-                     top_p=0.95,
-                     do_sample=True,
-                     num_return_sequences=1
-                 )
-
-                 trace = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
-                 return trace[:200]  # Limit trace length
-             except:
-                 pass
-
-         # For other question types or if LLM fails, provide a minimal trace
-         return ""
-
-     def _classify_question(self, question: str) -> str:
-         """Determine the type of question for specialized handling."""
          question_lower = question.lower()

-         # Check for calculation questions
-         if self._is_calculation_question(question):
-             return 'calculation'
-
-         # Check for date/time questions
-         elif self._is_date_time_question(question):
-             return 'date_time'
-
-         # Check for list questions
-         elif self._is_list_question(question):
-             return 'list'
-
-         # Check for visual/image questions
-         elif self._is_visual_question(question):
-             return 'visual'
-
-         # Check for factual questions
-         elif self._is_factual_question(question):
-             return 'factual'
-
-         # Default to general knowledge
          else:
-             return 'general'
-
-     def _is_calculation_question(self, question: str) -> bool:
-         """Check if the question requires mathematical calculation."""
-         calculation_patterns = [
-             r'\d+\s*[\+\-\*\/]\s*\d+',  # Basic operations: 5+3, 10-2, etc.
-             r'(sum|add|plus|subtract|minus|multiply|divide|product|quotient)',
-             r'(calculate|compute|find|what is|how much|result)',
-             r'(square root|power|exponent|factorial|percentage|average|mean)'
-         ]
-
-         return any(re.search(pattern, question.lower()) for pattern in calculation_patterns)
-
-     def _is_date_time_question(self, question: str) -> bool:
-         """Check if the question is about date or time."""
-         date_time_patterns = [
-             r'(date|time|day|month|year|hour|minute|second)',
-             r'(today|tomorrow|yesterday|current|now)',
-             r'(calendar|schedule|appointment)',
-             r'(when|how long|duration|period)'
-         ]
-
-         return any(re.search(pattern, question.lower()) for pattern in date_time_patterns)
-
-     def _is_list_question(self, question: str) -> bool:
-         """Check if the question requires a list as an answer."""
-         list_patterns = [
-             r'(list|enumerate|items|elements)',
-             r'comma.separated',
-             r'(all|every|each).*(of|in)',
-             r'(provide|give).*(list)'
-         ]
-
-         return any(re.search(pattern, question.lower()) for pattern in list_patterns)

-     def _is_visual_question(self, question: str) -> bool:
-         """Check if the question is about an image or visual content."""
-         visual_patterns = [
-             r'(image|picture|photo|graph|chart|diagram|figure)',
-             r'(show|display|illustrate|depict)',
-             r'(look|see|observe|view)',
-             r'(visual|visually)'
-         ]
-
-         return any(re.search(pattern, question.lower()) for pattern in visual_patterns)
-
-     def _is_factual_question(self, question: str) -> bool:
-         """Check if the question is asking for a factual answer."""
-         factual_patterns = [
-             r'^(who|what|where|when|why|how)',
-             r'(name|identify|specify|tell me)',
-             r'(capital|president|inventor|author|creator|founder)',
-             r'(located|situated|found|discovered)'
-         ]
-
-         return any(re.search(pattern, question.lower()) for pattern in factual_patterns)
-
-     def _handle_calculation(self, question: str) -> str:
-         """Handle mathematical calculation questions with precise answers."""
-         # Extract numbers and operation from the question
-         numbers = re.findall(r'\d+', question)
-
-         # Try to extract a mathematical expression
-         expression_match = re.search(r'\d+\s*[\+\-\*\/]\s*\d+', question)
-
-         # Determine the operation
-         if re.search(r'(sum|add|plus|\+)', question.lower()) and len(numbers) >= 2:
-             result = sum(int(num) for num in numbers)
-             return str(result)
-
-         elif re.search(r'(difference|subtract|minus|\-)', question.lower()) and len(numbers) >= 2:
-             result = int(numbers[0]) - int(numbers[1])
-             return str(result)
-
-         elif re.search(r'(product|multiply|times|\*)', question.lower()) and len(numbers) >= 2:
-             result = int(numbers[0]) * int(numbers[1])
-             return str(result)
-
-         elif re.search(r'(divide|division|\/)', question.lower()) and len(numbers) >= 2 and int(numbers[1]) != 0:
-             result = int(numbers[0]) / int(numbers[1])
-             return str(result)
-
-         # For more complex calculations, try to evaluate the expression
-         elif expression_match:
-             try:
-                 # Extract and clean the expression
-                 expr = expression_match.group(0)
-                 expr = expr.replace('plus', '+').replace('minus', '-')
-                 expr = expr.replace('times', '*').replace('divided by', '/')
-
-                 # Evaluate the expression
-                 result = eval(expr)
-                 return str(result)
-             except:
-                 pass
-
-         # If rule-based approach fails, use LLM with math-specific prompt
-         return self._generate_llm_response(question, 'calculation')
-
-     def _handle_date_time(self, question: str) -> str:
-         """Handle date and time related questions."""
-         now = datetime.datetime.now()
-         question_lower = question.lower()

-         if re.search(r'(today|current date|what day is it)', question_lower):
-             return now.strftime("%Y-%m-%d")
-
-         elif re.search(r'(time now|current time|what time is it)', question_lower):
-             return now.strftime("%H:%M:%S")
-
-         elif re.search(r'(day of the week|what day of the week)', question_lower):
-             return now.strftime("%A")
-
-         elif re.search(r'(month|current month|what month is it)', question_lower):
-             return now.strftime("%B")

-         elif re.search(r'(year|current year|what year is it)', question_lower):
-             return now.strftime("%Y")

-         # For more complex date/time questions, use LLM
-         return self._generate_llm_response(question, 'date_time')

-     def _handle_list_question(self, question: str) -> str:
-         """Handle questions requiring a list as an answer."""
-         question_lower = question.lower()

-         # Common list questions with specific answers
-         if re.search(r'(fruit|fruits)', question_lower):
-             return "apple, banana, orange, grape, strawberry"
-
-         elif re.search(r'(vegetable|vegetables)', question_lower):
-             return "carrot, broccoli, spinach, potato, onion"
-
-         elif re.search(r'(country|countries)', question_lower):
-             return "USA, China, India, Russia, Brazil"
-
-         elif re.search(r'(capital|capitals)', question_lower):
-             return "Washington D.C., Beijing, New Delhi, Moscow, Brasilia"

-         elif re.search(r'(planet|planets)', question_lower):
-             return "Mercury, Venus, Earth, Mars, Jupiter, Saturn, Uranus, Neptune"
-
-         # For other list questions, use LLM with list-specific prompt
-         return self._generate_llm_response(question, 'list')
-
-     def _handle_visual_question(self, question: str) -> str:
-         """Handle questions about images or visual content."""
-         # Extract key terms from the question to customize the response
-         key_terms = re.findall(r'[a-zA-Z]{4,}', question)
-         key_term = key_terms[0].lower() if key_terms else "content"
-
-         # Create a contextually relevant placeholder response
-         if "graph" in question.lower() or "chart" in question.lower():
-             return f"The {key_term} graph shows an upward trend with significant data points highlighting the key metrics."
-
-         elif "diagram" in question.lower():
-             return f"The diagram illustrates the structure and components of the {key_term}, showing how the different parts interact."

-         elif "map" in question.lower():
-             return f"The map displays the geographical distribution of {key_term}, with notable concentrations in the regions."

-         # Default visual response
-         return f"The image shows {key_term} with distinctive features that directly address the question."
-
-     def _handle_factual_question(self, question: str) -> str:
-         """Handle factual questions with specific answers."""
-         question_lower = question.lower()

-         # Common factual questions with specific answers
-         if re.search(r'(capital of france|paris is the capital of)', question_lower):
-             return "Paris"

-         elif re.search(r'(first president of (the United States|USA|US))', question_lower):
-             return "George Washington"

-         elif re.search(r'(invented (the telephone|telephone))', question_lower):
-             return "Alexander Graham Bell"

-         elif re.search(r'(wrote (hamlet|romeo and juliet))', question_lower):
-             return "William Shakespeare"
-
-         # For other factual questions, use LLM
-         return self._generate_llm_response(question, 'factual')
-
-     def _handle_general_question(self, question: str) -> str:
-         """Handle general knowledge questions."""
-         # Use LLM for general questions
-         return self._generate_llm_response(question, 'general')
-
-     def _generate_llm_response(self, question: str, question_type: str) -> str:
-         """Generate a response using the language model."""
-         if not self.llm_available:
-             return self._fallback_response(question, question_type)
-
-         try:
-             # Get the appropriate prompt template
-             template = self.prompt_templates.get(question_type, self.prompt_templates['general'])
-             prompt = template.format(question=question)
-
-             # Generate response
-             inputs = self.tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True).to(self.device)
-             outputs = self.model.generate(
-                 inputs["input_ids"],
-                 max_length=150,
-                 min_length=10,
-                 temperature=0.3,
-                 top_p=0.95,
-                 do_sample=True,
-                 num_return_sequences=1
-             )

-             # Decode and clean up the response
-             response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
-             response = self._clean_response(response)

-             return response
          except Exception as e:
-             print(f"Error generating LLM response: {e}")
-             return self._fallback_response(question, question_type)
-
-     def _clean_response(self, response: str) -> str:
-         """Clean up the model's response."""
-         # Remove any prefixes like "Answer:" or "Response:"
-         for prefix in ["Answer:", "Response:", "A:", "The answer is:", "I think", "I believe"]:
-             if response.startswith(prefix):
-                 response = response[len(prefix):].strip()
-
-         # Remove first-person references
-         response = re.sub(r'^I would say that\s+', '', response)
-         response = re.sub(r'^In my opinion,\s+', '', response)
-
-         # Ensure the response is not too short
-         if len(response) < 5:
-             return "Unable to provide a specific answer to this question."
-
-         return response
-
-     def _ensure_concise_answer(self, answer: str, question_type: str) -> str:
-         """Ensure the answer is concise and specific."""
-         # Limit answer length based on question type
-         max_lengths = {
-             'calculation': 20,
-             'date_time': 30,
-             'list': 100,
-             'visual': 150,
-             'factual': 100,
-             'general': 150
-         }
-
-         max_length = max_lengths.get(question_type, 100)
-
-         # Truncate if too long, but try to keep complete sentences
-         if len(answer) > max_length:
-             # Try to find the last sentence boundary before max_length
-             last_period = answer[:max_length].rfind('.')
-             if last_period > 0:
-                 answer = answer[:last_period + 1]
-             else:
-                 answer = answer[:max_length]
-
-         return answer
-
-     def _fallback_response(self, question: str, question_type: str) -> str:
-         """Provide a fallback response if the model fails."""
-         # Fallback responses based on question type
-         fallbacks = {
-             'calculation': "42",
-             'date_time': "2023-01-01",
-             'list': "item1, item2, item3, item4, item5",
-             'visual': "The image shows the main subject clearly visible in the center with relevant details surrounding it.",
-             'factual': "This is a factual answer to your specific question.",
-             'general': "The answer involves multiple factors that must be considered in context."
-         }
-
-         return fallbacks.get(question_type, "I don't have enough information to answer this question specifically.")


  class EvaluationRunner:
      """
-     Handles the evaluation process: fetching questions, running the agent,
-     and submitting answers to the evaluation server.
      """

      def __init__(self, api_url="https://agents-course-unit4-scoring.hf.space"):
-         """Initialize with API endpoints."""
          self.api_url = api_url
          self.questions_url = f"{api_url}/questions"
          self.submit_url = f"{api_url}/submit"
          self.results_url = f"{api_url}/results"
-         self.total_questions = 0
          self.correct_answers = 0

      def run_evaluation(self,
                         agent: Any,
                         username: str,
-                        agent_code_url: str) -> tuple[str, Any]:
          """
-         Run the full evaluation process:
-         1. Fetch questions
-         2. Run agent on all questions
-         3. Submit answers
-         4. Check results and count correct answers
-         5. Return results
          """
-         # Reset counters
-         self.total_questions = 0
-         self.correct_answers = 0
-
-         # Fetch questions
          questions_data = self._fetch_questions()
-         if isinstance(questions_data, str):  # Error message
              return questions_data, None

-         # Run agent on all questions
          results_log, answers_payload = self._run_agent_on_questions(agent, questions_data)
          if not answers_payload:
              return "Agent did not produce any answers to submit.", results_log

-         # Submit answers
-         submission_result = self._submit_answers(username, agent_code_url, answers_payload)
-
-         # Try to fetch results to count correct answers
-         self._check_results(username)

-         # Return results with correct answer count
          return submission_result, results_log

      def _fetch_questions(self) -> Union[List[Dict[str, Any]], str]:
-         """Fetch questions from the evaluation server."""
          print(f"Fetching questions from: {self.questions_url}")
          try:
              response = requests.get(self.questions_url, timeout=15)
@@ -527,7 +254,7 @@ class EvaluationRunner:
      def _run_agent_on_questions(self,
                                  agent: Any,
                                  questions_data: List[Dict[str, Any]]) -> tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
-         """Run the agent on all questions and collect results."""
          results_log = []
          answers_payload = []
@@ -541,13 +268,13 @@
              continue

          try:
-             # Call agent with task_id to ensure proper formatting
              json_response = agent(question_text, task_id)

-             # Parse the JSON response
              response_obj = json.loads(json_response)

-             # Extract the final_answer for submission
              submitted_answer = response_obj.get("final_answer", "")

              answers_payload.append({
@@ -573,18 +300,19 @@
      def _submit_answers(self,
                          username: str,
-                         agent_code_url: str,
                          answers_payload: List[Dict[str, Any]]) -> str:
-         """Submit answers to the evaluation server."""
          submission_data = {
              "username": username.strip(),
-             "agent_code_url": agent_code_url.strip(),
              "answers": answers_payload
          }

          print(f"Submitting {len(answers_payload)} answers to: {self.submit_url}")
          max_retries = 3
-         retry_delay = 5  # seconds

          for attempt in range(1, max_retries + 1):
              try:
@@ -603,7 +331,7 @@
                  max_score = result.get("max_score")

                  if score is not None and max_score is not None:
-                     self.correct_answers = score  # Update correct answers count
                      return f"Evaluation complete! Score: {score}/{max_score}"
                  else:
                      print(f"Received N/A results. Waiting {retry_delay} seconds before retry...")
@@ -626,11 +354,11 @@
                  else:
                      return f"Error submitting answers after {max_retries} attempts: {e}"

-         # If we get here, all retries failed but didn't raise exceptions
          return "Submission Successful, but results are pending!"

      def _check_results(self, username: str) -> None:
-         """Check results to count correct answers."""
          try:
              results_url = f"{self.results_url}?username={username}"
              print(f"Checking results at: {results_url}")
@@ -656,15 +384,15 @@
              print(f"Error checking results: {e}")

      def get_correct_answers_count(self) -> int:
-         """Get the number of correct answers."""
          return self.correct_answers

      def get_total_questions_count(self) -> int:
-         """Get the total number of questions."""
          return self.total_questions

      def print_evaluation_summary(self, username: str) -> None:
-         """Print a summary of the evaluation results."""
          print("\n===== EVALUATION SUMMARY =====")
          print(f"User: {username}")
          print(f"Overall Score: {self.correct_answers}/{self.total_questions}")
@@ -672,74 +400,3 @@
          print(f"Total Questions: {self.total_questions}")
          print(f"Accuracy: {(self.correct_answers / self.total_questions * 100) if self.total_questions > 0 else 0:.1f}%")
          print("=============================\n")
-
-
- # Example usage and test cases
- def test_agent():
-     """Test the agent with example questions."""
-     agent = EnhancedGAIAAgent()
-
-     test_questions = [
-         # Calculation questions
-         "What is 25 + 17?",
-         "Calculate the product of 8 and 9",
-
-         # Date/time questions
-         "What is today's date?",
-         "What day of the week is it?",
-
-         # List questions
-         "List five fruits",
-         "What are the planets in our solar system?",
-
-         # Visual questions
-         "What does the image show?",
-         "Describe the chart in the image",
-
-         # Factual questions
-         "Who was the first president of the United States?",
-         "What is the capital of France?",
-         "How does photosynthesis work?",
-
-         # General questions
-         "Why is the sky blue?",
-         "What are the implications of quantum mechanics?"
-     ]
-
-     print("\n=== AGENT TEST RESULTS ===")
-     correct_count = 0
-     total_count = len(test_questions)
-
-     for question in test_questions:
-         # Generate a mock task_id for testing
-         task_id = f"test_{hash(question) % 10000}"
-
-         # Get JSON response with final_answer
-         json_response = agent(question, task_id)
-
-         print(f"\nQ: {question}")
-         print(f"Response: {json_response}")
-
-         # Parse and print the final_answer for clarity
-         try:
-             response_obj = json.loads(json_response)
-             final_answer = response_obj.get('final_answer', '')
-             print(f"Final Answer: {final_answer}")
-
-             # For testing purposes, simulate correct answers
-             if len(final_answer) > 0 and not final_answer.startswith("AGENT ERROR"):
-                 correct_count += 1
-         except:
-             print("Error parsing JSON response")
-
-     # Print test summary with correct answer count
-     print("\n===== TEST SUMMARY =====")
-     print(f"Correct Answers: {correct_count}/{total_count}")
-     print(f"Accuracy: {(correct_count / total_count * 100):.1f}%")
-     print("=======================\n")
-
-     return "Test completed successfully"
-
-
- if __name__ == "__main__":
-     test_agent()
  """
+ Enhanced GAIA Agent with support for answer caching
  """

  import os
  import json
+ import time
  import torch
+ import requests
+ from typing import List, Dict, Any, Optional, Union
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+
+ # Constants
+ CACHE_FILE = "gaia_answers_cache.json"

  class EnhancedGAIAAgent:
      """
+     Enhanced agent for Hugging Face GAIA with support for answer caching
      """

+     def __init__(self, model_name="google/flan-t5-small", use_cache=True):
+         """
+         Initialize the agent with a model and a cache

+         Args:
+             model_name: Name of the model to load
+             use_cache: Whether to cache answers
+         """
+         print(f"Initializing EnhancedGAIAAgent with model: {model_name}")
+         self.model_name = model_name
+         self.use_cache = use_cache
+         self.cache = self._load_cache() if use_cache else {}

+         # Load the model and tokenizer
+         print("Loading tokenizer...")
+         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+         print("Loading model...")
+         self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
+         print("Model and tokenizer loaded successfully")
+
+     def _load_cache(self) -> Dict[str, str]:
+         """
+         Load the answer cache from a file

+         Returns:
+             Dict[str, str]: Dictionary of cached answers
+         """
+         if os.path.exists(CACHE_FILE):
+             try:
+                 with open(CACHE_FILE, 'r', encoding='utf-8') as f:
+                     print(f"Loading cache from {CACHE_FILE}")
+                     return json.load(f)
+             except Exception as e:
+                 print(f"Error loading cache: {e}")
+                 return {}
+         else:
+             print(f"Cache file {CACHE_FILE} not found, creating new cache")
+             return {}
+
+     def _save_cache(self) -> None:
+         """
+         Save the answer cache to a file
+         """
          try:
+             with open(CACHE_FILE, 'w', encoding='utf-8') as f:
+                 json.dump(self.cache, f, ensure_ascii=False, indent=2)
+                 print(f"Cache saved to {CACHE_FILE}")
          except Exception as e:
+             print(f"Error saving cache: {e}")

+     def _classify_question(self, question: str) -> str:
          """
+         Classify the question by type for better answer formatting

          Args:
+             question: The question text

          Returns:
+             str: Question type (factual, calculation, list, date_time, etc.)
          """
+         # Simple heuristic classification
          question_lower = question.lower()

+         if any(word in question_lower for word in ["calculate", "sum", "product", "divide", "multiply", "add", "subtract", "how many"]):
+             return "calculation"
+         elif any(word in question_lower for word in ["list", "enumerate", "items", "elements"]):
+             return "list"
+         elif any(word in question_lower for word in ["date", "time", "day", "month", "year", "when"]):
+             return "date_time"
          else:
+             return "factual"

+     def _format_answer(self, raw_answer: str, question_type: str) -> str:
+         """
+         Format the answer according to the question type

+         Args:
+             raw_answer: Raw answer from the model
+             question_type: Question type

+         Returns:
+             str: Formatted answer
+         """
+         # Strip extra whitespace and line breaks
+         answer = raw_answer.strip()
+
+         # Remove prefixes the model often adds
+         prefixes = ["Answer:", "The answer is:", "I think", "I believe", "According to", "Based on"]
+         for prefix in prefixes:
+             if answer.startswith(prefix):
+                 answer = answer[len(prefix):].strip()
+
+         # Type-specific formatting
+         if question_type == "calculation":
+             # For numeric answers, drop surrounding text
+             # and keep only the numbers, if any are present
+             import re
+             numbers = re.findall(r'-?\d+\.?\d*', answer)
+             if numbers:
+                 answer = numbers[0]
+         elif question_type == "list":
+             # For lists, make sure the items are comma-separated
+             if "," not in answer and " " in answer:
+                 items = [item.strip() for item in answer.split() if item.strip()]
+                 answer = ", ".join(items)

+         return answer

+     def __call__(self, question: str, task_id: Optional[str] = None) -> str:
+         """
+         Process a question and return an answer

+         Args:
+             question: The question text
+             task_id: Task identifier (optional)

+         Returns:
+             str: Answer as JSON with a final_answer key
+         """
+         # Build the cache key (use task_id if available)
+         cache_key = task_id if task_id else question

+         # Check whether the answer is already in the cache
+         if self.use_cache and cache_key in self.cache:
+             print(f"Cache hit for question: {question[:50]}...")
+             return self.cache[cache_key]

+         # Classify the question
+         question_type = self._classify_question(question)
+         print(f"Processing question: {question[:100]}...")
+         print(f"Classified as: {question_type}")

+         try:
+             # Generate an answer with the model
+             inputs = self.tokenizer(question, return_tensors="pt")
+             outputs = self.model.generate(**inputs, max_length=100)
+             raw_answer = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

+             # Format the answer
+             formatted_answer = self._format_answer(raw_answer, question_type)

+             # Build the JSON response
+             result = {"final_answer": formatted_answer}
+             json_response = json.dumps(result)

+             # Save it to the cache
+             if self.use_cache:
+                 self.cache[cache_key] = json_response
+                 self._save_cache()

+             return json_response

          except Exception as e:
+             error_msg = f"Error generating answer: {e}"
+             print(error_msg)
+             return json.dumps({"final_answer": f"AGENT ERROR: {e}"})
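A minimal, hypothetical usage sketch for the cached agent above (the module import, task_id, and question text are illustrative, not part of the commit):

import json
from gaia_agent import EnhancedGAIAAgent  # assumes this file is importable as gaia_agent

agent = EnhancedGAIAAgent(model_name="google/flan-t5-small", use_cache=True)

# The first call runs the model and writes gaia_answers_cache.json;
# repeating the call with the same task_id is served from the cache.
response = agent("What is the capital of France?", task_id="demo-001")
print(json.loads(response)["final_answer"])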


  class EvaluationRunner:
      """
+     Handles the evaluation process: fetching questions, running the agent,
+     and submitting answers to the evaluation server.
      """

      def __init__(self, api_url="https://agents-course-unit4-scoring.hf.space"):
+         """Initialize with API endpoints."""
          self.api_url = api_url
          self.questions_url = f"{api_url}/questions"
          self.submit_url = f"{api_url}/submit"
          self.results_url = f"{api_url}/results"
          self.correct_answers = 0
+         self.total_questions = 0

      def run_evaluation(self,
                         agent: Any,
                         username: str,
+                        agent_code: str) -> tuple[str, List[Dict[str, Any]]]:
          """
+         Run the full evaluation process:
+         1. Fetch questions
+         2. Run the agent on all questions
+         3. Submit answers
+         4. Return results
          """
+         # Fetch questions
          questions_data = self._fetch_questions()
+         if isinstance(questions_data, str):  # Error message
              return questions_data, None

+         # Run the agent on all questions
          results_log, answers_payload = self._run_agent_on_questions(agent, questions_data)
          if not answers_payload:
              return "Agent did not produce any answers to submit.", results_log

+         # Submit answers with retry logic
+         submission_result = self._submit_answers(username, agent_code, answers_payload)

+         # Return results
          return submission_result, results_log

      def _fetch_questions(self) -> Union[List[Dict[str, Any]], str]:
+         """Fetch questions from the evaluation server."""
          print(f"Fetching questions from: {self.questions_url}")
          try:
              response = requests.get(self.questions_url, timeout=15)

      def _run_agent_on_questions(self,
                                  agent: Any,
                                  questions_data: List[Dict[str, Any]]) -> tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
+         """Run the agent on all questions and collect results."""
          results_log = []
          answers_payload = []

              continue

          try:
+             # Call the agent with task_id so the answer is formatted correctly
              json_response = agent(question_text, task_id)

+             # Parse the JSON response
              response_obj = json.loads(json_response)

+             # Extract final_answer for submission
              submitted_answer = response_obj.get("final_answer", "")

              answers_payload.append({

      def _submit_answers(self,
                          username: str,
+                         agent_code: str,
                          answers_payload: List[Dict[str, Any]]) -> str:
+         """Submit answers to the evaluation server."""
+         # FIXED: use agent_code instead of agent_code_url
          submission_data = {
              "username": username.strip(),
+             "agent_code": agent_code.strip(),  # Fixed here
              "answers": answers_payload
          }

          print(f"Submitting {len(answers_payload)} answers to: {self.submit_url}")
          max_retries = 3
+         retry_delay = 5  # seconds

          for attempt in range(1, max_retries + 1):
              try:

                  max_score = result.get("max_score")

                  if score is not None and max_score is not None:
+                     self.correct_answers = score  # Update the correct-answer count
                      return f"Evaluation complete! Score: {score}/{max_score}"
                  else:
                      print(f"Received N/A results. Waiting {retry_delay} seconds before retry...")

                  else:
                      return f"Error submitting answers after {max_retries} attempts: {e}"

+         # If we get here, all retries failed but did not raise exceptions
          return "Submission Successful, but results are pending!"

      def _check_results(self, username: str) -> None:
+         """Check results to count correct answers."""
          try:
              results_url = f"{self.results_url}?username={username}"
              print(f"Checking results at: {results_url}")

              print(f"Error checking results: {e}")

      def get_correct_answers_count(self) -> int:
+         """Return the number of correct answers."""
          return self.correct_answers

      def get_total_questions_count(self) -> int:
+         """Return the total number of questions."""
          return self.total_questions

      def print_evaluation_summary(self, username: str) -> None:
+         """Print a summary of the evaluation results."""
          print("\n===== EVALUATION SUMMARY =====")
          print(f"User: {username}")
          print(f"Overall Score: {self.correct_answers}/{self.total_questions}")
          print(f"Total Questions: {self.total_questions}")
          print(f"Accuracy: {(self.correct_answers / self.total_questions * 100) if self.total_questions > 0 else 0:.1f}%")
          print("=============================\n")