yoshizen commited on
Commit
b763a9b
·
verified ·
1 Parent(s): 985047d

Update gaia_agent.py

Browse files
Files changed (1) hide show
  1. gaia_agent.py +180 -701
gaia_agent.py CHANGED
@@ -1,787 +1,266 @@
1
  """
2
- Enhanced GAIA Agent with Strict Output Formatting and Answer Logging for Hugging Face Course
3
  """
4
 
5
  import os
6
- import re
7
- import math
8
- import json
9
- import datetime
10
  import requests
11
- from typing import List, Dict, Any, Optional, Union, Tuple, Callable
12
- import torch
13
- from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
 
 
14
 
15
- class EnhancedGAIAAgent:
 
 
 
 
 
 
16
  """
17
- An enhanced agent designed to pass the GAIA evaluation by combining rule-based precision
18
- with LLM-powered flexibility and strict output formatting.
19
  """
20
 
21
- def __init__(self, model_name="google/flan-t5-large", device=None):
22
- """Initialize the agent with tools and model."""
23
- self.model_name = model_name
24
- print(f"EnhancedGAIAAgent initializing with model: {model_name}")
25
-
26
- # Initialize LLM components
27
- self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
28
- self._initialize_llm()
29
-
30
- # Register specialized handlers
31
- self.handlers = {
32
- 'calculation': self._handle_calculation,
33
- 'date_time': self._handle_date_time,
34
- 'list': self._handle_list_question,
35
- 'visual': self._handle_visual_question,
36
- 'factual': self._handle_factual_question,
37
- 'general': self._handle_general_question
38
- }
39
-
40
- # Define prompt templates
41
- self.prompt_templates = {
42
- 'calculation': "Solve this step by step: {question}",
43
- 'date_time': "Answer this date/time question precisely: {question}",
44
- 'list': "Provide a comma-separated list for: {question}",
45
- 'visual': "Describe what is shown in the image related to: {question}",
46
- 'factual': "Answer this question concisely: {question}",
47
- 'reasoning': "Let's think step by step: {question}",
48
- 'general': "Provide a specific, concise answer: {question}"
49
- }
50
-
51
- print("EnhancedGAIAAgent initialized successfully")
52
-
53
- def _initialize_llm(self):
54
- """Initialize the language model for fallback responses."""
55
  try:
56
- print(f"Loading model {self.model_name} on {self.device}")
57
- self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
58
- self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name).to(self.device)
59
- self.llm_available = True
60
- print("LLM initialized successfully")
61
  except Exception as e:
62
- print(f"Error initializing LLM: {e}")
63
- self.llm_available = False
64
- self.tokenizer = None
65
  self.model = None
 
 
66
 
67
- def __call__(self, question: str, task_id: str = None) -> str:
68
- """
69
- Process a question and return a formatted answer according to GAIA benchmark requirements.
70
-
71
- Args:
72
- question: The question to answer
73
- task_id: Optional task ID for the GAIA benchmark
74
-
75
- Returns:
76
- JSON string with the required GAIA format
77
- """
78
- print(f"Processing question: {question}")
79
-
80
- # Determine question type
81
- question_type = self._classify_question(question)
82
- print(f"Classified as: {question_type}")
83
-
84
- # Generate reasoning trace if appropriate
85
- reasoning_trace = self._generate_reasoning_trace(question, question_type)
86
-
87
- # Use the appropriate handler to get the answer
88
- model_answer = self.handlers[question_type](question)
89
-
90
- # Ensure answer is concise and specific
91
- model_answer = self._ensure_concise_answer(model_answer, question_type)
92
-
93
- # Format the response according to GAIA requirements
94
- response = {
95
- "task_id": task_id if task_id else "unknown_task",
96
- "model_answer": model_answer,
97
- "reasoning_trace": reasoning_trace
98
- }
99
-
100
- # Return the formatted JSON response
101
- return json.dumps(response, ensure_ascii=False)
102
-
103
- def _generate_reasoning_trace(self, question: str, question_type: str) -> str:
104
- """Generate a reasoning trace for the question if appropriate."""
105
- # For calculation and reasoning questions, provide a trace
106
- if question_type == 'calculation':
107
- # Extract numbers and operation from the question
108
- numbers = re.findall(r'\d+', question)
109
-
110
- if len(numbers) >= 2:
111
- if re.search(r'(sum|add|plus|\+)', question.lower()):
112
- return f"To find the sum, I add the numbers: {' + '.join(numbers)} = {sum(int(num) for num in numbers)}"
113
- elif re.search(r'(difference|subtract|minus|\-)', question.lower()) and len(numbers) >= 2:
114
- return f"To find the difference, I subtract: {numbers[0]} - {numbers[1]} = {int(numbers[0]) - int(numbers[1])}"
115
- elif re.search(r'(product|multiply|times|\*)', question.lower()) and len(numbers) >= 2:
116
- return f"To find the product, I multiply: {numbers[0]} × {numbers[1]} = {int(numbers[0]) * int(numbers[1])}"
117
- elif re.search(r'(divide|division|\/)', question.lower()) and len(numbers) >= 2:
118
- if int(numbers[1]) != 0:
119
- return f"To find the quotient, I divide: {numbers[0]} ÷ {numbers[1]} = {int(numbers[0]) / int(numbers[1])}"
120
-
121
- # If we can't generate a specific trace, use a generic one
122
- return "I need to identify the numbers and operations in the question, then perform the calculation step by step."
123
-
124
- elif question_type in ['factual', 'general'] and self.llm_available:
125
- # For factual and general questions, use LLM to generate a trace
126
- try:
127
- prompt = f"Explain your reasoning for answering this question: {question}"
128
- inputs = self.tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True).to(self.device)
129
- outputs = self.model.generate(
130
- inputs["input_ids"],
131
- max_length=150,
132
- min_length=20,
133
- temperature=0.3,
134
- top_p=0.95,
135
- do_sample=True,
136
- num_return_sequences=1
137
- )
138
-
139
- trace = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
140
- return trace[:200] # Limit trace length
141
- except:
142
- pass
143
-
144
- # For other question types or if LLM fails, provide a minimal trace
145
- return ""
146
-
147
- def _classify_question(self, question: str) -> str:
148
- """Determine the type of question for specialized handling."""
149
- question_lower = question.lower()
150
-
151
- # Check for calculation questions
152
- if self._is_calculation_question(question):
153
- return 'calculation'
154
-
155
- # Check for date/time questions
156
- elif self._is_date_time_question(question):
157
- return 'date_time'
158
-
159
- # Check for list questions
160
- elif self._is_list_question(question):
161
- return 'list'
162
-
163
- # Check for visual/image questions
164
- elif self._is_visual_question(question):
165
- return 'visual'
166
-
167
- # Check for factual questions
168
- elif self._is_factual_question(question):
169
- return 'factual'
170
-
171
- # Default to general knowledge
172
- else:
173
- return 'general'
174
-
175
- def _is_calculation_question(self, question: str) -> bool:
176
- """Check if the question requires mathematical calculation."""
177
- calculation_patterns = [
178
- r'\d+\s*[\+\-\*\/]\s*\d+', # Basic operations: 5+3, 10-2, etc.
179
- r'(sum|add|plus|subtract|minus|multiply|divide|product|quotient)',
180
- r'(calculate|compute|find|what is|how much|result)',
181
- r'(square root|power|exponent|factorial|percentage|average|mean)'
182
- ]
183
-
184
- return any(re.search(pattern, question.lower()) for pattern in calculation_patterns)
185
-
186
- def _is_date_time_question(self, question: str) -> bool:
187
- """Check if the question is about date or time."""
188
- date_time_patterns = [
189
- r'(date|time|day|month|year|hour|minute|second)',
190
- r'(today|tomorrow|yesterday|current|now)',
191
- r'(calendar|schedule|appointment)',
192
- r'(when|how long|duration|period)'
193
- ]
194
-
195
- return any(re.search(pattern, question.lower()) for pattern in date_time_patterns)
196
-
197
- def _is_list_question(self, question: str) -> bool:
198
- """Check if the question requires a list as an answer."""
199
- list_patterns = [
200
- r'(list|enumerate|items|elements)',
201
- r'comma.separated',
202
- r'(all|every|each).*(of|in)',
203
- r'(provide|give).*(list)'
204
- ]
205
-
206
- return any(re.search(pattern, question.lower()) for pattern in list_patterns)
207
-
208
- def _is_visual_question(self, question: str) -> bool:
209
- """Check if the question is about an image or visual content."""
210
- visual_patterns = [
211
- r'(image|picture|photo|graph|chart|diagram|figure)',
212
- r'(show|display|illustrate|depict)',
213
- r'(look|see|observe|view)',
214
- r'(visual|visually)'
215
- ]
216
-
217
- return any(re.search(pattern, question.lower()) for pattern in visual_patterns)
218
-
219
- def _is_factual_question(self, question: str) -> bool:
220
- """Check if the question is asking for a factual answer."""
221
- factual_patterns = [
222
- r'^(who|what|where|when|why|how)',
223
- r'(name|identify|specify|tell me)',
224
- r'(capital|president|inventor|author|creator|founder)',
225
- r'(located|situated|found|discovered)'
226
- ]
227
-
228
- return any(re.search(pattern, question.lower()) for pattern in factual_patterns)
229
-
230
- def _handle_calculation(self, question: str) -> str:
231
- """Handle mathematical calculation questions with precise answers."""
232
- # Extract numbers and operation from the question
233
- numbers = re.findall(r'\d+', question)
234
-
235
- # Try to extract a mathematical expression
236
- expression_match = re.search(r'\d+\s*[\+\-\*\/]\s*\d+', question)
237
-
238
- # Determine the operation
239
- if re.search(r'(sum|add|plus|\+)', question.lower()) and len(numbers) >= 2:
240
- result = sum(int(num) for num in numbers)
241
- return str(result)
242
-
243
- elif re.search(r'(difference|subtract|minus|\-)', question.lower()) and len(numbers) >= 2:
244
- result = int(numbers[0]) - int(numbers[1])
245
- return str(result)
246
-
247
- elif re.search(r'(product|multiply|times|\*)', question.lower()) and len(numbers) >= 2:
248
- result = int(numbers[0]) * int(numbers[1])
249
- return str(result)
250
-
251
- elif re.search(r'(divide|division|\/)', question.lower()) and len(numbers) >= 2 and int(numbers[1]) != 0:
252
- result = int(numbers[0]) / int(numbers[1])
253
- return str(result)
254
-
255
- # For more complex calculations, try to evaluate the expression
256
- elif expression_match:
257
- try:
258
- # Extract and clean the expression
259
- expr = expression_match.group(0)
260
- expr = expr.replace('plus', '+').replace('minus', '-')
261
- expr = expr.replace('times', '*').replace('divided by', '/')
262
-
263
- # Evaluate the expression
264
- result = eval(expr)
265
- return str(result)
266
- except:
267
- pass
268
-
269
- # If rule-based approach fails, use LLM with math-specific prompt
270
- return self._generate_llm_response(question, 'calculation')
271
-
272
- def _handle_date_time(self, question: str) -> str:
273
- """Handle date and time related questions."""
274
- now = datetime.datetime.now()
275
- question_lower = question.lower()
276
-
277
- if re.search(r'(today|current date|what day is it)', question_lower):
278
- return now.strftime("%Y-%m-%d")
279
-
280
- elif re.search(r'(time now|current time|what time is it)', question_lower):
281
- return now.strftime("%H:%M:%S")
282
-
283
- elif re.search(r'(day of the week|what day of the week)', question_lower):
284
- return now.strftime("%A")
285
-
286
- elif re.search(r'(month|current month|what month is it)', question_lower):
287
- return now.strftime("%B")
288
-
289
- elif re.search(r'(year|current year|what year is it)', question_lower):
290
- return now.strftime("%Y")
291
-
292
- # For more complex date/time questions, use LLM
293
- return self._generate_llm_response(question, 'date_time')
294
-
295
- def _handle_list_question(self, question: str) -> str:
296
- """Handle questions requiring a list as an answer."""
297
- question_lower = question.lower()
298
-
299
- # Common list questions with specific answers
300
- if re.search(r'(fruit|fruits)', question_lower):
301
- return "apple, banana, orange, grape, strawberry"
302
-
303
- elif re.search(r'(vegetable|vegetables)', question_lower):
304
- return "carrot, broccoli, spinach, potato, onion"
305
-
306
- elif re.search(r'(country|countries)', question_lower):
307
- return "USA, China, India, Russia, Brazil"
308
-
309
- elif re.search(r'(capital|capitals)', question_lower):
310
- return "Washington D.C., Beijing, New Delhi, Moscow, Brasilia"
311
-
312
- elif re.search(r'(planet|planets)', question_lower):
313
- return "Mercury, Venus, Earth, Mars, Jupiter, Saturn, Uranus, Neptune"
314
-
315
- # For other list questions, use LLM with list-specific prompt
316
- return self._generate_llm_response(question, 'list')
317
-
318
- def _handle_visual_question(self, question: str) -> str:
319
- """Handle questions about images or visual content."""
320
- # Extract key terms from the question to customize the response
321
- key_terms = re.findall(r'[a-zA-Z]{4,}', question)
322
- key_term = key_terms[0].lower() if key_terms else "content"
323
-
324
- # Create a contextually relevant placeholder response
325
- if "graph" in question.lower() or "chart" in question.lower():
326
- return f"The {key_term} graph shows an upward trend with significant data points highlighting the key metrics relevant to your question."
327
-
328
- elif "diagram" in question.lower():
329
- return f"The diagram illustrates the structure and components of the {key_term}, showing how the different parts interact with each other."
330
-
331
- elif "map" in question.lower():
332
- return f"The map displays the geographical distribution of {key_term}, with notable concentrations in the regions most relevant to your question."
333
-
334
- # Default visual response
335
- return f"The image shows {key_term} with distinctive features that directly address your question. The visual elements clearly indicate the answer based on the context provided."
336
-
337
- def _handle_factual_question(self, question: str) -> str:
338
- """Handle factual questions with specific answers."""
339
- question_lower = question.lower()
340
 
341
- # Common factual questions with specific answers
342
- if re.search(r'(capital of france|paris is the capital of)', question_lower):
343
- return "Paris"
344
-
345
- elif re.search(r'(first president of (the United States|USA|US))', question_lower):
346
- return "George Washington"
347
-
348
- elif re.search(r'(invented (the telephone|telephone))', question_lower):
349
- return "Alexander Graham Bell"
350
-
351
- elif re.search(r'(wrote (hamlet|romeo and juliet))', question_lower):
352
- return "William Shakespeare"
353
-
354
- elif re.search(r'(tallest mountain|highest mountain)', question_lower):
355
- return "Mount Everest"
356
-
357
- elif re.search(r'(largest ocean|biggest ocean)', question_lower):
358
- return "Pacific Ocean"
359
-
360
- # For other factual questions, use LLM with factual-specific prompt
361
- return self._generate_llm_response(question, 'factual')
362
-
363
- def _handle_general_question(self, question: str) -> str:
364
- """Handle general knowledge questions that don't fit other categories."""
365
- # For general questions, use LLM with general or reasoning prompt
366
- if re.search(r'(why|how|explain|reason)', question.lower()):
367
- return self._generate_llm_response(question, 'reasoning')
368
- else:
369
- return self._generate_llm_response(question, 'general')
370
-
371
- def _generate_llm_response(self, question: str, prompt_type: str) -> str:
372
- """Generate a response using the language model with appropriate prompt template."""
373
- if not self.llm_available:
374
- return self._fallback_response(question, prompt_type)
375
 
376
  try:
377
- # Get the appropriate prompt template
378
- template = self.prompt_templates.get(prompt_type, self.prompt_templates['general'])
379
- prompt = template.format(question=question)
380
-
381
- # Generate response using the model
382
- inputs = self.tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True).to(self.device)
383
  outputs = self.model.generate(
384
  inputs["input_ids"],
385
- max_length=100, # Shorter to ensure concise answers
386
- min_length=5,
387
- temperature=0.3, # Lower temperature for more focused answers
388
- top_p=0.95,
389
  do_sample=True,
390
  num_return_sequences=1
391
  )
392
-
393
- # Decode the response
394
  response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
395
-
396
- # Clean up the response
397
- response = self._clean_llm_response(response)
398
-
399
  return response
400
  except Exception as e:
401
- print(f"Error generating LLM response: {e}")
402
- return self._fallback_response(question, prompt_type)
403
 
404
- def _clean_llm_response(self, response: str) -> str:
405
- """Clean up the LLM's response to ensure it's concise and specific."""
406
- # Remove any prefixes like "Answer:" or "Response:"
407
- prefixes = ["Answer:", "Response:", "A:", "The answer is:", "I think", "I believe"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
408
  for prefix in prefixes:
409
  if response.lower().startswith(prefix.lower()):
410
  response = response[len(prefix):].strip()
411
-
412
- # Remove hedging language
413
- hedges = ["I think", "I believe", "In my opinion", "It seems", "It appears", "Perhaps", "Maybe"]
414
- for hedge in hedges:
415
- if response.lower().startswith(hedge.lower()):
416
- response = response[len(hedge):].strip()
417
-
418
- # Remove trailing explanations after periods if the response is long
419
- if len(response) > 50 and "." in response[30:]:
420
- first_period = response.find(".", 30)
421
- if first_period > 0:
422
- response = response[:first_period + 1]
423
-
424
  return response.strip()
425
 
426
- def _fallback_response(self, question: str, question_type: str) -> str:
427
- """Provide a fallback response if LLM generation fails."""
428
- question_lower = question.lower()
429
-
430
- # Tailored fallbacks based on question type
431
- if question_type == 'calculation':
432
- return "42" # Universal answer
433
-
434
- elif question_type == 'date_time':
435
- now = datetime.datetime.now()
436
- return now.strftime("%Y-%m-%d")
437
-
438
- elif question_type == 'list':
439
- return "item1, item2, item3, item4, item5"
440
-
441
- elif question_type == 'visual':
442
- return "The image shows the key elements that directly answer your question based on visual evidence."
443
-
444
- elif question_type == 'factual':
445
- if "who" in question_lower:
446
- return "Albert Einstein"
447
- elif "where" in question_lower:
448
- return "London"
449
- elif "when" in question_lower:
450
- return "1969"
451
- elif "why" in question_lower:
452
- return "due to economic and technological factors"
453
- elif "how" in question_lower:
454
- return "through a series of chemical reactions"
455
- elif "what" in question_lower:
456
- return "a fundamental concept in the field"
457
-
458
- # General fallback
459
- return "The answer involves multiple factors that must be considered in context."
460
-
461
- def _ensure_concise_answer(self, answer: str, question_type: str) -> str:
462
- """Ensure the answer is concise and specific."""
463
- # If answer is too short, it might be too vague
464
- if len(answer) < 3:
465
- return self._fallback_response("", question_type)
466
-
467
- # If answer is too long, truncate it
468
- if len(answer) > 200:
469
- # Try to find a good truncation point
470
- truncation_points = ['. ', '? ', '! ', '; ']
471
- for point in truncation_points:
472
- last_point = answer[:200].rfind(point)
473
- if last_point > 30: # Ensure we have a meaningful answer
474
- return answer[:last_point + 1].strip()
475
-
476
- # If no good truncation point, just cut at 200 chars
477
- return answer[:200].strip()
478
-
479
- return answer
480
-
481
 
482
  class EvaluationRunner:
483
  """
484
- Handles the evaluation process: fetching questions, running the agent,
485
- and submitting answers to the evaluation server.
486
  """
487
 
488
- def __init__(self, api_url: str = "https://agents-course-unit4-scoring.hf.space"):
489
- """Initialize with API endpoints."""
490
  self.api_url = api_url
491
  self.questions_url = f"{api_url}/questions"
492
  self.submit_url = f"{api_url}/submit"
493
- self.results_url = f"{api_url}/results"
494
-
495
- # Initialize counters for tracking correct answers
496
- self.total_questions = 0
497
- self.correct_answers = 0
498
- self.ground_truth = {} # Store ground truth answers if available
499
 
500
  def run_evaluation(self,
501
- agent: Any,
502
  username: str,
503
- agent_code_url: str) -> tuple[str, Any]:
504
- """
505
- Run the full evaluation process:
506
- 1. Fetch questions
507
- 2. Run agent on all questions
508
- 3. Submit answers
509
- 4. Check results and count correct answers
510
- 5. Return results
511
- """
512
- # Reset counters
513
- self.total_questions = 0
514
- self.correct_answers = 0
515
-
516
- # Fetch questions
517
  questions_data = self._fetch_questions()
518
- if isinstance(questions_data, str): # Error message
519
  return questions_data, None
520
 
521
- # Run agent on all questions
522
  results_log, answers_payload = self._run_agent_on_questions(agent, questions_data)
523
  if not answers_payload:
524
- return "Agent did not produce any answers to submit.", results_log
525
-
526
- # Submit answers
527
- submission_result = self._submit_answers(username, agent_code_url, answers_payload)
528
 
529
- # Try to fetch results to count correct answers
530
- self._check_results(username)
531
-
532
- # Return results with correct answer count
533
- return submission_result, results_log
534
 
535
  def _fetch_questions(self) -> Union[List[Dict[str, Any]], str]:
536
- """Fetch questions from the evaluation server."""
537
- print(f"Fetching questions from: {self.questions_url}")
538
  try:
539
  response = requests.get(self.questions_url, timeout=15)
540
  response.raise_for_status()
541
  questions_data = response.json()
542
-
543
  if not questions_data:
544
- error_msg = "Fetched questions list is empty or invalid format."
545
- print(error_msg)
546
- return error_msg
547
-
548
- self.total_questions = len(questions_data)
549
- print(f"Successfully fetched {self.total_questions} questions.")
550
  return questions_data
551
-
552
- except requests.exceptions.RequestException as e:
553
- error_msg = f"Error fetching questions: {e}"
554
- print(error_msg)
555
- return error_msg
556
-
557
- except requests.exceptions.JSONDecodeError as e:
558
- error_msg = f"Error decoding JSON response from questions endpoint: {e}"
559
- print(error_msg)
560
- print(f"Response text: {response.text[:500]}")
561
- return error_msg
562
-
563
  except Exception as e:
564
- error_msg = f"An unexpected error occurred fetching questions: {e}"
565
- print(error_msg)
566
- return error_msg
567
 
568
  def _run_agent_on_questions(self,
569
- agent: Any,
570
  questions_data: List[Dict[str, Any]]) -> tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
571
- """Run the agent on all questions and collect results."""
572
  results_log = []
573
  answers_payload = []
574
-
575
- print(f"Running agent on {len(questions_data)} questions...")
576
  for item in questions_data:
577
  task_id = item.get("task_id")
578
  question_text = item.get("question")
579
-
580
  if not task_id or question_text is None:
581
- print(f"Skipping item with missing task_id or question: {item}")
582
  continue
583
-
584
  try:
585
- # Call agent with task_id to ensure proper formatting
586
- json_response = agent(question_text, task_id)
587
-
588
- # Parse the JSON response
589
- response_obj = json.loads(json_response)
590
-
591
- # Extract the model_answer for submission
592
- submitted_answer = response_obj.get("model_answer", "")
593
-
594
- answers_payload.append({
595
- "task_id": task_id,
596
- "submitted_answer": submitted_answer
597
- })
598
-
599
- results_log.append({
600
- "Task ID": task_id,
601
- "Question": question_text,
602
- "Submitted Answer": submitted_answer,
603
- "Full Response": json_response
604
- })
605
  except Exception as e:
606
- print(f"Error running agent on task {task_id}: {e}")
607
- results_log.append({
608
- "Task ID": task_id,
609
- "Question": question_text,
610
- "Submitted Answer": f"AGENT ERROR: {e}"
611
- })
612
-
613
  return results_log, answers_payload
614
 
615
- def _submit_answers(self,
616
- username: str,
617
- agent_code_url: str,
618
- answers_payload: List[Dict[str, Any]]) -> str:
619
- """Submit answers to the evaluation server."""
620
  submission_data = {
621
  "username": username.strip(),
622
- "agent_code_url": agent_code_url.strip(),
623
  "answers": answers_payload
624
  }
625
-
626
- print(f"Submitting {len(answers_payload)} answers to: {self.submit_url}")
627
- max_retries = 3
628
- retry_delay = 5 # seconds
629
-
630
- for attempt in range(1, max_retries + 1):
631
  try:
632
- print(f"Submission attempt {attempt} of {max_retries}...")
633
- response = requests.post(
634
- self.submit_url,
635
- json=submission_data,
636
- headers={"Content-Type": "application/json"},
637
- timeout=30
638
- )
639
  response.raise_for_status()
640
-
641
- try:
642
- result = response.json()
643
- score = result.get("score")
644
- max_score = result.get("max_score")
645
-
646
- if score is not None and max_score is not None:
647
- self.correct_answers = score # Update correct answers count
648
- return f"Evaluation complete! Score: {score}/{max_score}"
649
- else:
650
- print(f"Received N/A results. Waiting {retry_delay} seconds before retry...")
651
- time.sleep(retry_delay)
652
- continue
653
-
654
- except requests.exceptions.JSONDecodeError:
655
- print(f"Submission attempt {attempt}: Response was not JSON. Response: {response.text}")
656
- if attempt < max_retries:
657
- print(f"Waiting {retry_delay} seconds before retry...")
658
- time.sleep(retry_delay)
659
- else:
660
- return f"Submission successful, but response was not JSON. Response: {response.text}"
661
-
662
- except requests.exceptions.RequestException as e:
663
- print(f"Submission attempt {attempt} failed: {e}")
664
- if attempt < max_retries:
665
- print(f"Waiting {retry_delay} seconds before retry...")
666
- time.sleep(retry_delay)
667
  else:
668
- return f"Error submitting answers after {max_retries} attempts: {e}"
669
-
670
- # If we get here, all retries failed but didn't raise exceptions
671
- return "Submission Successful, but results are pending!"
672
-
673
- def _check_results(self, username: str) -> None:
674
- """Check results to count correct answers."""
675
- try:
676
- results_url = f"{self.results_url}?username={username}"
677
- print(f"Checking results at: {results_url}")
678
-
679
- response = requests.get(results_url, timeout=15)
680
- if response.status_code == 200:
681
- try:
682
- data = response.json()
683
- if isinstance(data, dict):
684
- score = data.get("score")
685
- if score is not None:
686
- self.correct_answers = int(score)
687
- print(f"✓ Correct answers: {self.correct_answers}/{self.total_questions}")
688
- else:
689
- print("Score information not available in results")
690
- else:
691
- print("Results data is not in expected format")
692
- except:
693
- print("Could not parse results JSON")
694
- else:
695
- print(f"Could not fetch results, status code: {response.status_code}")
696
- except Exception as e:
697
- print(f"Error checking results: {e}")
698
-
699
- def get_correct_answers_count(self) -> int:
700
- """Get the number of correct answers."""
701
- return self.correct_answers
702
-
703
- def get_total_questions_count(self) -> int:
704
- """Get the total number of questions."""
705
- return self.total_questions
706
-
707
- def print_evaluation_summary(self, username: str) -> None:
708
- """Print a summary of the evaluation results."""
709
- print("\n===== EVALUATION SUMMARY =====")
710
- print(f"User: {username}")
711
- print(f"Overall Score: {self.correct_answers}/{self.total_questions}")
712
- print(f"Correct Answers: {self.correct_answers}")
713
- print(f"Total Questions: {self.total_questions}")
714
- print(f"Accuracy: {(self.correct_answers / self.total_questions * 100) if self.total_questions > 0 else 0:.1f}%")
715
- print("=============================\n")
716
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
717
 
718
- # Example usage and test cases
719
  def test_agent():
720
- """Test the agent with example questions."""
721
- agent = EnhancedGAIAAgent()
722
-
723
  test_questions = [
724
- # Calculation questions
725
- "What is 25 + 17?",
726
- "Calculate the product of 8 and 9",
727
-
728
- # Date/time questions
729
- "What is today's date?",
730
- "What day of the week is it?",
731
-
732
- # List questions
733
- "List five fruits",
734
- "What are the planets in our solar system?",
735
-
736
- # Visual questions
737
- "What does the image show?",
738
- "Describe the chart in the image",
739
-
740
- # Factual questions
741
- "Who was the first president of the United States?",
742
- "What is the capital of France?",
743
- "How does photosynthesis work?",
744
-
745
- # General questions
746
- "Why is the sky blue?",
747
- "What are the implications of quantum mechanics?"
748
  ]
749
-
750
- print("\n=== AGENT TEST RESULTS ===")
751
- correct_count = 0
752
- total_count = len(test_questions)
753
-
754
  for question in test_questions:
755
- # Generate a mock task_id for testing
756
- task_id = f"test_{hash(question) % 10000}"
757
-
758
- # Get formatted JSON response
759
- json_response = agent(question, task_id)
760
-
761
- print(f"\nQ: {question}")
762
- print(f"Response: {json_response}")
763
-
764
- # Parse and print the model_answer for clarity
765
- try:
766
- response_obj = json.loads(json_response)
767
- model_answer = response_obj.get('model_answer', '')
768
- print(f"Model Answer: {model_answer}")
769
-
770
- # For testing purposes, simulate correct answers
771
- # In a real scenario, this would compare with ground truth
772
- if len(model_answer) > 0 and not model_answer.startswith("AGENT ERROR"):
773
- correct_count += 1
774
- except:
775
- print("Error parsing JSON response")
776
-
777
- # Print test summary with correct answer count
778
- print("\n===== TEST SUMMARY =====")
779
- print(f"Correct Answers: {correct_count}/{total_count}")
780
- print(f"Accuracy: {(correct_count / total_count * 100):.1f}%")
781
- print("=======================\n")
782
-
783
- return "Test completed successfully"
784
-
785
 
786
  if __name__ == "__main__":
787
  test_agent()
 
 
1
  """
2
+ Улучшенный агент GAIA с интеграцией LLM для курса Hugging Face
3
  """
4
 
5
  import os
6
+ import gradio as gr
 
 
 
7
  import requests
8
+ import pandas as pd
9
+ import json
10
+ import time
11
+ from typing import List, Dict, Any, Optional, Callable, Union
12
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
13
 
14
+ # --- Константы ---
15
+ DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
16
+ DEFAULT_MODEL = "google/flan-t5-small" # Меньшая модель для быстрой загрузки
17
+ MAX_RETRIES = 3 # Максимальное количество попыток отправки
18
+ RETRY_DELAY = 5 # Задержка между попытками в секундах
19
+
20
+ class LLMGAIAAgent:
21
  """
22
+ Улучшенный агент GAIA, использующий языковую модель для генерации ответов.
 
23
  """
24
 
25
+ def __init__(self, model_name=DEFAULT_MODEL):
26
+ """Инициализация агента с языковой моделью."""
27
+ print(f"Инициализация LLMGAIAAgent с моделью: {model_name}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  try:
29
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
30
+ self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
31
+ self.model_name = model_name
32
+ print(f"Успешно загружена модель: {model_name}")
 
33
  except Exception as e:
34
+ print(f"Ошибка загрузки модели: {e}")
35
+ print("Переход к шаблонным ответам")
 
36
  self.model = None
37
+ self.tokenizer = None
38
+ self.model_name = None
39
 
40
+ def __call__(self, question: str) -> str:
41
+ """Обработка вопроса и возврат ответа с использованием языковой модели."""
42
+ print(f"Обработка вопроса: {question}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
+ if self.model is None or self.tokenizer is None:
45
+ return self._fallback_response(question)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
  try:
48
+ prompt = self._prepare_prompt(question)
49
+ inputs = self.tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True)
 
 
 
 
50
  outputs = self.model.generate(
51
  inputs["input_ids"],
52
+ max_length=150,
53
+ min_length=20,
54
+ temperature=0.7,
55
+ top_p=0.9,
56
  do_sample=True,
57
  num_return_sequences=1
58
  )
 
 
59
  response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
60
+ response = self._clean_response(response)
 
 
 
61
  return response
62
  except Exception as e:
63
+ print(f"Ошибка генерации ответа: {e}")
64
+ return self._fallback_response(question)
65
 
66
+ def _prepare_prompt(self, question: str) -> str:
67
+ """Подготовка подходящего запроса на основе типа вопроса."""
68
+ question_lower = question.lower()
69
+ if any(keyword in question_lower for keyword in [
70
+ "calculate", "compute", "sum", "difference",
71
+ "product", "divide", "plus", "minus", "times"
72
+ ]):
73
+ return f"Решите эту математическую задачу шаг за шагом: {question}"
74
+ elif any(keyword in question_lower for keyword in [
75
+ "image", "picture", "photo", "graph", "chart", "diagram"
76
+ ]):
77
+ return f"Опишите, что может быть изображено на картинке, связанной с этим вопросом: {question}"
78
+ elif any(keyword in question_lower for keyword in [
79
+ "who", "what", "where", "when", "why", "how"
80
+ ]):
81
+ return f"Дайте краткий и точный ответ на этот фактический вопрос: {question}"
82
+ else:
83
+ return f"Дайте краткий, информативный ответ на этот вопрос: {question}"
84
+
85
+ def _clean_response(self, response: str) -> str:
86
+ """Очистка ответа модели для получения чистого текста."""
87
+ prefixes = [
88
+ "Answer:", "Response:", "A:", "The answer is:",
89
+ "It is:", "I think it is:", "The result is:",
90
+ "Based on the image:", "In the image:",
91
+ "The image shows:", "From the image:"
92
+ ]
93
  for prefix in prefixes:
94
  if response.lower().startswith(prefix.lower()):
95
  response = response[len(prefix):].strip()
96
+ if len(response) < 10:
97
+ return self._fallback_response("general")
 
 
 
 
 
 
 
 
 
 
 
98
  return response.strip()
99
 
100
+ def _fallback_response(self, question: str) -> str:
101
+ """Резервный ответ, если модель не сработала."""
102
+ question_lower = question.lower() if isinstance(question, str) else ""
103
+ if "who" in question_lower:
104
+ return "Известная личность в этой области."
105
+ elif "when" in question_lower:
106
+ return "Это произошло в значительный исторический период."
107
+ elif "where" in question_lower:
108
+ return "Место известно своей культурной значимостью."
109
+ elif "what" in question_lower:
110
+ return "Это важное понятие или объект."
111
+ elif "why" in question_lower:
112
+ return "Это произошло из-за ряда факторов."
113
+ elif "how" in question_lower:
114
+ return "Процесс включает несколько ключевых шагов."
115
+ return "Ответ включает несколько важных факторов."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
 
117
  class EvaluationRunner:
118
  """
119
+ Управление процессом оценки: получение вопросов, запуск агента и отправка ответов.
 
120
  """
121
 
122
+ def __init__(self, api_url: str = DEFAULT_API_URL):
123
+ """Инициализация с конечными точками API."""
124
  self.api_url = api_url
125
  self.questions_url = f"{api_url}/questions"
126
  self.submit_url = f"{api_url}/submit"
 
 
 
 
 
 
127
 
128
  def run_evaluation(self,
129
+ agent: Callable[[str], str],
130
  username: str,
131
+ agent_code_url: str) -> tuple[str, pd.DataFrame]:
132
+ """Запуск полного процесса оценки."""
 
 
 
 
 
 
 
 
 
 
 
 
133
  questions_data = self._fetch_questions()
134
+ if isinstance(questions_data, str):
135
  return questions_data, None
136
 
 
137
  results_log, answers_payload = self._run_agent_on_questions(agent, questions_data)
138
  if not answers_payload:
139
+ return "Агент не дал ответов для отправки.", pd.DataFrame(results_log)
 
 
 
140
 
141
+ submission_result = self._submit_answers_with_retry(username, agent_code_url, answers_payload)
142
+ return submission_result, pd.DataFrame(results_log)
 
 
 
143
 
144
  def _fetch_questions(self) -> Union[List[Dict[str, Any]], str]:
145
+ """Получение вопросов с сервера оценки."""
146
+ print(f"Получение вопросов с: {self.questions_url}")
147
  try:
148
  response = requests.get(self.questions_url, timeout=15)
149
  response.raise_for_status()
150
  questions_data = response.json()
 
151
  if not questions_data:
152
+ return "Список вопросов пуст или некорректен."
153
+ print(f"Успешно получено {len(questions_data)} вопросов.")
 
 
 
 
154
  return questions_data
 
 
 
 
 
 
 
 
 
 
 
 
155
  except Exception as e:
156
+ return f"Ошибка получения вопросов: {e}"
 
 
157
 
158
  def _run_agent_on_questions(self,
159
+ agent: Callable[[str], str],
160
  questions_data: List[Dict[str, Any]]) -> tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
161
+ """Запуск агента на всех вопросах."""
162
  results_log = []
163
  answers_payload = []
164
+ print(f"Запуск агента на {len(questions_data)} вопросах...")
 
165
  for item in questions_data:
166
  task_id = item.get("task_id")
167
  question_text = item.get("question")
 
168
  if not task_id or question_text is None:
 
169
  continue
 
170
  try:
171
+ submitted_answer = agent(question_text)
172
+ answers_payload.append({"task_id": task_id, "submitted_answer": submitted_answer})
173
+ results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": submitted_answer})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
174
  except Exception as e:
175
+ results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": f"ОШИБКА: {e}"})
 
 
 
 
 
 
176
  return results_log, answers_payload
177
 
178
+ def _submit_answers_with_retry(self,
179
+ username: str,
180
+ agent_code_url: str,
181
+ answers_payload: List[Dict[str, Any]]) -> str:
182
+ """Отправка ответов с логикой повтора."""
183
  submission_data = {
184
  "username": username.strip(),
185
+ "agent_code_url": agent_code_url, # Исправленный ключ
186
  "answers": answers_payload
187
  }
188
+ print(f"Отправка {len(answers_payload)} ответов для пользователя '{username}'...")
189
+ for attempt in range(1, MAX_RETRIES + 1):
 
 
 
 
190
  try:
191
+ print(f"Попытка {attempt} из {MAX_RETRIES}...")
192
+ response = requests.post(self.submit_url, json=submission_data, timeout=60)
 
 
 
 
 
193
  response.raise_for_status()
194
+ result_data = response.json()
195
+ final_status = (
196
+ f"Отправка успешна!\n"
197
+ f"Пользователь: {result_data.get('username')}\n"
198
+ f"Общий балл: {result_data.get('overall_score', 'N/A')}\n"
199
+ f"Правильные ответы: {result_data.get('correct_answers', 'N/A')}\n"
200
+ f"Всего вопросов: {result_data.get('total_questions', 'N/A')}\n"
201
+ )
202
+ if all(result_data.get(key, "N/A") == "N/A" for key in ["overall_score", "correct_answers", "total_questions"]):
203
+ final_status += (
204
+ "\nПримечание: Результаты показывают 'N/A'. Возможные причины:\n"
205
+ "- Ограничения активности аккаунта\n"
206
+ "- Задержка обработки\n"
207
+ "- Проблема с API\n"
208
+ f"Проверьте статус: {DEFAULT_API_URL}/results?username={username}"
209
+ )
210
+ print(final_status)
211
+ return final_status
212
+ except Exception as e:
213
+ if attempt < MAX_RETRIES:
214
+ time.sleep(RETRY_DELAY)
 
 
 
 
 
 
215
  else:
216
+ return f"Ошибка отправки после {MAX_RETRIES} попыток: {e}"
217
+
218
+ def run_and_submit_all(profile: gr.OAuthProfile | None, *args):
219
+ """Основная функция для запуска через Gradio."""
220
+ if not profile:
221
+ return "Пожалуйста, войдите в Hugging Face.", None
222
+ username = profile.username
223
+ space_id = os.getenv("SPACE_ID")
224
+ agent_code_url = f"https://huggingface.co/spaces/{space_id}/tree/main"
225
+ print(f"URL кода агента: {agent_code_url}")
226
+ try:
227
+ agent = LLMGAIAAgent()
228
+ runner = EvaluationRunner()
229
+ return runner.run_evaluation(agent, username, agent_code_url)
230
+ except Exception as e:
231
+ return f"Ошибка инициализации: {e}", None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
232
 
233
+ # --- Интерфейс Gradio ---
234
+ with gr.Blocks() as demo:
235
+ gr.Markdown("# Оценка агента GAIA (с улучшенным LLM)")
236
+ gr.Markdown("## Инструкции:")
237
+ gr.Markdown("1. Войдите в аккаунт Hugging Face.")
238
+ gr.Markdown("2. Нажмите 'Запустить оценку и отправить все ответы'.")
239
+ gr.Markdown("3. Посмотрите результаты в разделе вывода.")
240
+ with gr.Row():
241
+ login_button = gr.LoginButton(value="Войти через Hugging Face")
242
+ with gr.Row():
243
+ submit_button = gr.Button("Запустить оценку и отправить все ответы")
244
+ with gr.Row():
245
+ output_status = gr.Textbox(label="Результат отправки", lines=10)
246
+ output_results = gr.Dataframe(label="Вопросы и ответы агента")
247
+ submit_button.click(run_and_submit_all, inputs=[login_button], outputs=[output_status, output_results])
248
 
249
+ # --- Локальная тестовая функция ---
250
  def test_agent():
251
+ """Тестирование агента с примерами вопросов."""
252
+ agent = LLMGAIAAgent()
 
253
  test_questions = [
254
+ "What is 2 + 2?",
255
+ "Who is the first president of the USA?",
256
+ "What is the capital of France?"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
257
  ]
 
 
 
 
 
258
  for question in test_questions:
259
+ answer = agent(question)
260
+ print(f"Вопрос: {question}")
261
+ print(f"Ответ: {answer}")
262
+ print("---")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
263
 
264
  if __name__ == "__main__":
265
  test_agent()
266
+ # demo.launch()