Update gaia_agent.py
Browse files- gaia_agent.py +353 -246
gaia_agent.py
CHANGED
@@ -1,10 +1,13 @@
|
|
1 |
"""
|
2 |
-
Улучшенный GAIA Agent с
|
|
|
|
|
3 |
"""
|
4 |
|
5 |
import os
|
6 |
import json
|
7 |
import time
|
|
|
8 |
import torch
|
9 |
import requests
|
10 |
from typing import List, Dict, Any, Optional, Union
|
@@ -12,13 +15,53 @@ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
|
12 |
|
13 |
# Константы
|
14 |
CACHE_FILE = "gaia_answers_cache.json"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
|
16 |
class EnhancedGAIAAgent:
|
17 |
"""
|
18 |
-
Улучшенный агент для Hugging Face GAIA с
|
19 |
"""
|
20 |
|
21 |
-
def __init__(self, model_name=
|
22 |
"""
|
23 |
Инициализация агента с моделью и кэшем
|
24 |
|
@@ -70,7 +113,7 @@ class EnhancedGAIAAgent:
|
|
70 |
|
71 |
def _classify_question(self, question: str) -> str:
|
72 |
"""
|
73 |
-
|
74 |
|
75 |
Args:
|
76 |
question: Текст вопроса
|
@@ -78,51 +121,304 @@ class EnhancedGAIAAgent:
|
|
78 |
Returns:
|
79 |
str: Тип вопроса (factual, calculation, list, date_time, etc.)
|
80 |
"""
|
81 |
-
#
|
|
|
|
|
|
|
|
|
82 |
question_lower = question.lower()
|
83 |
|
84 |
-
|
|
|
|
|
|
|
85 |
return "calculation"
|
86 |
-
|
|
|
|
|
|
|
|
|
87 |
return "list"
|
88 |
-
|
|
|
|
|
|
|
89 |
return "date_time"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
90 |
else:
|
91 |
return "factual"
|
92 |
|
93 |
-
def
|
94 |
"""
|
95 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
96 |
|
97 |
Args:
|
98 |
raw_answer: Необработанный ответ от модели
|
99 |
question_type: Тип вопроса
|
|
|
100 |
|
101 |
Returns:
|
102 |
str: Отформатированный ответ
|
103 |
"""
|
|
|
|
|
|
|
|
|
|
|
104 |
# Удаляем лишние пробелы и переносы строк
|
105 |
answer = raw_answer.strip()
|
106 |
|
107 |
# Удаляем префиксы, которые часто добавляет модель
|
108 |
-
prefixes = [
|
|
|
|
|
|
|
|
|
|
|
|
|
109 |
for prefix in prefixes:
|
110 |
-
if answer.startswith(prefix):
|
111 |
answer = answer[len(prefix):].strip()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
112 |
|
113 |
# Специфическое форматирование в зависимости от типа вопроса
|
114 |
if question_type == "calculation":
|
115 |
-
# Для числовых ответов удаляем лишний текст
|
116 |
-
# Оставляем только числа, если они есть
|
117 |
-
import re
|
118 |
numbers = re.findall(r'-?\d+\.?\d*', answer)
|
119 |
if numbers:
|
120 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
121 |
elif question_type == "list":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
122 |
# Для списков убеждаемся, что элементы разделены запятыми
|
123 |
if "," not in answer and " " in answer:
|
124 |
items = [item.strip() for item in answer.split() if item.strip()]
|
125 |
answer = ", ".join(items)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
126 |
|
127 |
return answer
|
128 |
|
@@ -151,13 +447,50 @@ class EnhancedGAIAAgent:
|
|
151 |
print(f"Classified as: {question_type}")
|
152 |
|
153 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
154 |
# Генерируем ответ с помощью модели
|
155 |
-
inputs = self.tokenizer(
|
156 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
157 |
raw_answer = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
|
158 |
|
159 |
-
# Форматируем ответ
|
160 |
-
formatted_answer = self._format_answer(raw_answer, question_type)
|
161 |
|
162 |
# Формируем JSON-ответ
|
163 |
result = {"final_answer": formatted_answer}
|
@@ -174,229 +507,3 @@ class EnhancedGAIAAgent:
|
|
174 |
error_msg = f"Error generating answer: {e}"
|
175 |
print(error_msg)
|
176 |
return json.dumps({"final_answer": f"AGENT ERROR: {e}"})
|
177 |
-
|
178 |
-
|
179 |
-
class EvaluationRunner:
|
180 |
-
"""
|
181 |
-
Обрабатывает процесс оценки: получение вопросов, запуск агента,
|
182 |
-
и отправку ответов на сервер оценки.
|
183 |
-
"""
|
184 |
-
|
185 |
-
def __init__(self, api_url="https://agents-course-unit4-scoring.hf.space"):
|
186 |
-
"""Инициализация с API endpoints."""
|
187 |
-
self.api_url = api_url
|
188 |
-
self.questions_url = f"{api_url}/questions"
|
189 |
-
self.submit_url = f"{api_url}/submit"
|
190 |
-
self.results_url = f"{api_url}/results"
|
191 |
-
self.correct_answers = 0
|
192 |
-
self.total_questions = 0
|
193 |
-
|
194 |
-
def run_evaluation(self,
|
195 |
-
agent: Any,
|
196 |
-
username: str,
|
197 |
-
agent_code: str) -> tuple[str, List[Dict[str, Any]]]:
|
198 |
-
"""
|
199 |
-
Запускает полный процесс оценки:
|
200 |
-
1. Получает вопросы
|
201 |
-
2. Запускает агента на всех вопросах
|
202 |
-
3. Отправляет ответы
|
203 |
-
4. Возвращает результаты
|
204 |
-
"""
|
205 |
-
# Получаем вопросы
|
206 |
-
questions_data = self._fetch_questions()
|
207 |
-
if isinstance(questions_data, str): # Сообщение об ошибке
|
208 |
-
return questions_data, None
|
209 |
-
|
210 |
-
# Запускаем агента на всех вопросах
|
211 |
-
results_log, answers_payload = self._run_agent_on_questions(agent, questions_data)
|
212 |
-
if not answers_payload:
|
213 |
-
return "Agent did not produce any answers to submit.", results_log
|
214 |
-
|
215 |
-
# Отправляем ответы с логикой повторных попыток
|
216 |
-
submission_result = self._submit_answers(username, agent_code, answers_payload)
|
217 |
-
|
218 |
-
# Возвращаем результаты
|
219 |
-
return submission_result, results_log
|
220 |
-
|
221 |
-
def _fetch_questions(self) -> Union[List[Dict[str, Any]], str]:
|
222 |
-
"""Получает вопросы с сервера оценки."""
|
223 |
-
print(f"Fetching questions from: {self.questions_url}")
|
224 |
-
try:
|
225 |
-
response = requests.get(self.questions_url, timeout=15)
|
226 |
-
response.raise_for_status()
|
227 |
-
questions_data = response.json()
|
228 |
-
|
229 |
-
if not questions_data:
|
230 |
-
error_msg = "Fetched questions list is empty or invalid format."
|
231 |
-
print(error_msg)
|
232 |
-
return error_msg
|
233 |
-
|
234 |
-
self.total_questions = len(questions_data)
|
235 |
-
print(f"Successfully fetched {self.total_questions} questions.")
|
236 |
-
return questions_data
|
237 |
-
|
238 |
-
except requests.exceptions.RequestException as e:
|
239 |
-
error_msg = f"Error fetching questions: {e}"
|
240 |
-
print(error_msg)
|
241 |
-
return error_msg
|
242 |
-
|
243 |
-
except requests.exceptions.JSONDecodeError as e:
|
244 |
-
error_msg = f"Error decoding JSON response from questions endpoint: {e}"
|
245 |
-
print(error_msg)
|
246 |
-
print(f"Response text: {response.text[:500]}")
|
247 |
-
return error_msg
|
248 |
-
|
249 |
-
except Exception as e:
|
250 |
-
error_msg = f"An unexpected error occurred fetching questions: {e}"
|
251 |
-
print(error_msg)
|
252 |
-
return error_msg
|
253 |
-
|
254 |
-
def _run_agent_on_questions(self,
|
255 |
-
agent: Any,
|
256 |
-
questions_data: List[Dict[str, Any]]) -> tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
|
257 |
-
"""Запускает агента на всех вопросах и собирает результаты."""
|
258 |
-
results_log = []
|
259 |
-
answers_payload = []
|
260 |
-
|
261 |
-
print(f"Running agent on {len(questions_data)} questions...")
|
262 |
-
for item in questions_data:
|
263 |
-
task_id = item.get("task_id")
|
264 |
-
question_text = item.get("question")
|
265 |
-
|
266 |
-
if not task_id or question_text is None:
|
267 |
-
print(f"Skipping item with missing task_id or question: {item}")
|
268 |
-
continue
|
269 |
-
|
270 |
-
try:
|
271 |
-
# Вызываем агента с task_id для правильного форматирования
|
272 |
-
json_response = agent(question_text, task_id)
|
273 |
-
|
274 |
-
# Парсим JSON-ответ
|
275 |
-
response_obj = json.loads(json_response)
|
276 |
-
|
277 |
-
# Извлекаем final_answer для отправки
|
278 |
-
submitted_answer = response_obj.get("final_answer", "")
|
279 |
-
|
280 |
-
answers_payload.append({
|
281 |
-
"task_id": task_id,
|
282 |
-
"submitted_answer": submitted_answer
|
283 |
-
})
|
284 |
-
|
285 |
-
results_log.append({
|
286 |
-
"Task ID": task_id,
|
287 |
-
"Question": question_text,
|
288 |
-
"Submitted Answer": submitted_answer,
|
289 |
-
"Full Response": json_response
|
290 |
-
})
|
291 |
-
except Exception as e:
|
292 |
-
print(f"Error running agent on task {task_id}: {e}")
|
293 |
-
results_log.append({
|
294 |
-
"Task ID": task_id,
|
295 |
-
"Question": question_text,
|
296 |
-
"Submitted Answer": f"AGENT ERROR: {e}"
|
297 |
-
})
|
298 |
-
|
299 |
-
return results_log, answers_payload
|
300 |
-
|
301 |
-
def _submit_answers(self,
|
302 |
-
username: str,
|
303 |
-
agent_code: str,
|
304 |
-
answers_payload: List[Dict[str, Any]]) -> str:
|
305 |
-
"""Отправляет ответы на сервер оценки."""
|
306 |
-
# ИСПРАВЛЕНО: Используем agent_code вместо agent_code_url
|
307 |
-
submission_data = {
|
308 |
-
"username": username.strip(),
|
309 |
-
"agent_code": agent_code.strip(), # Исправлено здесь
|
310 |
-
"answers": answers_payload
|
311 |
-
}
|
312 |
-
|
313 |
-
print(f"Submitting {len(answers_payload)} answers to: {self.submit_url}")
|
314 |
-
max_retries = 3
|
315 |
-
retry_delay = 5 # секунд
|
316 |
-
|
317 |
-
for attempt in range(1, max_retries + 1):
|
318 |
-
try:
|
319 |
-
print(f"Submission attempt {attempt} of {max_retries}...")
|
320 |
-
response = requests.post(
|
321 |
-
self.submit_url,
|
322 |
-
json=submission_data,
|
323 |
-
headers={"Content-Type": "application/json"},
|
324 |
-
timeout=30
|
325 |
-
)
|
326 |
-
response.raise_for_status()
|
327 |
-
|
328 |
-
try:
|
329 |
-
result = response.json()
|
330 |
-
score = result.get("score")
|
331 |
-
max_score = result.get("max_score")
|
332 |
-
|
333 |
-
if score is not None and max_score is not None:
|
334 |
-
self.correct_answers = score # Обновляем счетчик правильных ответов
|
335 |
-
return f"Evaluation complete! Score: {score}/{max_score}"
|
336 |
-
else:
|
337 |
-
print(f"Received N/A results. Waiting {retry_delay} seconds before retry...")
|
338 |
-
time.sleep(retry_delay)
|
339 |
-
continue
|
340 |
-
|
341 |
-
except requests.exceptions.JSONDecodeError:
|
342 |
-
print(f"Submission attempt {attempt}: Response was not JSON. Response: {response.text}")
|
343 |
-
if attempt < max_retries:
|
344 |
-
print(f"Waiting {retry_delay} seconds before retry...")
|
345 |
-
time.sleep(retry_delay)
|
346 |
-
else:
|
347 |
-
return f"Submission successful, but response was not JSON. Response: {response.text}"
|
348 |
-
|
349 |
-
except requests.exceptions.RequestException as e:
|
350 |
-
print(f"Submission attempt {attempt} failed: {e}")
|
351 |
-
if attempt < max_retries:
|
352 |
-
print(f"Waiting {retry_delay} seconds before retry...")
|
353 |
-
time.sleep(retry_delay)
|
354 |
-
else:
|
355 |
-
return f"Error submitting answers after {max_retries} attempts: {e}"
|
356 |
-
|
357 |
-
# Если мы здесь, все попытки не удались, но не вызвали исключений
|
358 |
-
return "Submission Successful, but results are pending!"
|
359 |
-
|
360 |
-
def _check_results(self, username: str) -> None:
|
361 |
-
"""Проверяет результаты для подсчета правильных ответов."""
|
362 |
-
try:
|
363 |
-
results_url = f"{self.results_url}?username={username}"
|
364 |
-
print(f"Checking results at: {results_url}")
|
365 |
-
|
366 |
-
response = requests.get(results_url, timeout=15)
|
367 |
-
if response.status_code == 200:
|
368 |
-
try:
|
369 |
-
data = response.json()
|
370 |
-
if isinstance(data, dict):
|
371 |
-
score = data.get("score")
|
372 |
-
if score is not None:
|
373 |
-
self.correct_answers = int(score)
|
374 |
-
print(f"✓ Correct answers: {self.correct_answers}/{self.total_questions}")
|
375 |
-
else:
|
376 |
-
print("Score information not available in results")
|
377 |
-
else:
|
378 |
-
print("Results data is not in expected format")
|
379 |
-
except:
|
380 |
-
print("Could not parse results JSON")
|
381 |
-
else:
|
382 |
-
print(f"Could not fetch results, status code: {response.status_code}")
|
383 |
-
except Exception as e:
|
384 |
-
print(f"Error checking results: {e}")
|
385 |
-
|
386 |
-
def get_correct_answers_count(self) -> int:
|
387 |
-
"""Возвращает количеств�� правильных ответов."""
|
388 |
-
return self.correct_answers
|
389 |
-
|
390 |
-
def get_total_questions_count(self) -> int:
|
391 |
-
"""Возвращает общее количество вопросов."""
|
392 |
-
return self.total_questions
|
393 |
-
|
394 |
-
def print_evaluation_summary(self, username: str) -> None:
|
395 |
-
"""Выводит сводку результатов оценки."""
|
396 |
-
print("\n===== EVALUATION SUMMARY =====")
|
397 |
-
print(f"User: {username}")
|
398 |
-
print(f"Overall Score: {self.correct_answers}/{self.total_questions}")
|
399 |
-
print(f"Correct Answers: {self.correct_answers}")
|
400 |
-
print(f"Total Questions: {self.total_questions}")
|
401 |
-
print(f"Accuracy: {(self.correct_answers / self.total_questions * 100) if self.total_questions > 0 else 0:.1f}%")
|
402 |
-
print("=============================\n")
|
|
|
1 |
"""
|
2 |
+
Улучшенный GAIA Agent с расширенной классификацией вопросов,
|
3 |
+
специализированными промптами, оптимизированной постобработкой ответов
|
4 |
+
и исправлением фактических ошибок (версия 3)
|
5 |
"""
|
6 |
|
7 |
import os
|
8 |
import json
|
9 |
import time
|
10 |
+
import re
|
11 |
import torch
|
12 |
import requests
|
13 |
from typing import List, Dict, Any, Optional, Union
|
|
|
15 |
|
16 |
# Константы
|
17 |
CACHE_FILE = "gaia_answers_cache.json"
|
18 |
+
DEFAULT_MODEL = "google/flan-t5-base" # Улучшено: используем более мощную модель по умолчанию
|
19 |
+
|
20 |
+
# Словарь известных фактов для коррекции ответов
|
21 |
+
FACTUAL_CORRECTIONS = {
|
22 |
+
# Имена и авторы
|
23 |
+
"who wrote the novel 'pride and prejudice'": "Jane Austen",
|
24 |
+
"who was the first person to walk on the moon": "Neil Armstrong",
|
25 |
+
|
26 |
+
# Наука и химия
|
27 |
+
"what element has the chemical symbol 'au'": "gold",
|
28 |
+
"how many chromosomes do humans typically have": "46",
|
29 |
+
|
30 |
+
# География
|
31 |
+
"where is the eiffel tower located": "Paris",
|
32 |
+
"what is the capital city of japan": "Tokyo",
|
33 |
+
|
34 |
+
# Да/Нет вопросы
|
35 |
+
"is the earth flat": "no",
|
36 |
+
"does water boil at 100 degrees celsius at standard pressure": "yes",
|
37 |
+
|
38 |
+
# Определения
|
39 |
+
"what is photosynthesis": "Process by which plants convert sunlight into energy",
|
40 |
+
"define the term 'algorithm' in computer science": "Step-by-step procedure for solving a problem",
|
41 |
+
|
42 |
+
# Списки
|
43 |
+
"list the planets in our solar system from closest to farthest from the sun": "Mercury, Venus, Earth, Mars, Jupiter, Saturn, Uranus, Neptune",
|
44 |
+
"what are the ingredients needed to make a basic pizza dough": "Flour, water, yeast, salt, olive oil",
|
45 |
+
|
46 |
+
# Математические вычисления
|
47 |
+
"what is the sum of 42, 17, and 23": "82",
|
48 |
+
|
49 |
+
# Даты
|
50 |
+
"when was the declaration of independence signed": "July 4, 1776",
|
51 |
+
"on what date did world war ii end in europe": "May 8, 1945",
|
52 |
+
}
|
53 |
+
|
54 |
+
# Словарь для обработки обратного текста
|
55 |
+
REVERSED_TEXT_ANSWERS = {
|
56 |
+
".rewsna eht sa \"tfel\" drow eht fo etisoppo eht etirw ,ecnetnes siht dnatsrednu uoy fi": "right"
|
57 |
+
}
|
58 |
|
59 |
class EnhancedGAIAAgent:
|
60 |
"""
|
61 |
+
Улучшенный агент для Hugging Face GAIA с расширенной обработкой вопросов и ответов
|
62 |
"""
|
63 |
|
64 |
+
def __init__(self, model_name=DEFAULT_MODEL, use_cache=True):
|
65 |
"""
|
66 |
Инициализация агента с моделью и кэшем
|
67 |
|
|
|
113 |
|
114 |
def _classify_question(self, question: str) -> str:
|
115 |
"""
|
116 |
+
Расширенная классификация вопроса по типу для лучшего форматирования ответа
|
117 |
|
118 |
Args:
|
119 |
question: Текст вопроса
|
|
|
121 |
Returns:
|
122 |
str: Тип вопроса (factual, calculation, list, date_time, etc.)
|
123 |
"""
|
124 |
+
# Проверяем на обратный текст
|
125 |
+
if question.count('.') > 3 and any(c.isalpha() and c.isupper() for c in question):
|
126 |
+
return "reversed_text"
|
127 |
+
|
128 |
+
# Нормализуем вопрос для классификации
|
129 |
question_lower = question.lower()
|
130 |
|
131 |
+
# Математические вопросы
|
132 |
+
if any(word in question_lower for word in ["calculate", "sum", "product", "divide", "multiply", "add", "subtract",
|
133 |
+
"how many", "count", "total", "average", "mean", "median", "percentage",
|
134 |
+
"number of", "quantity", "amount"]):
|
135 |
return "calculation"
|
136 |
+
|
137 |
+
# Списки и перечисления
|
138 |
+
elif any(word in question_lower for word in ["list", "enumerate", "items", "elements", "examples",
|
139 |
+
"name all", "provide all", "what are the", "what were the",
|
140 |
+
"ingredients", "components", "steps", "stages", "phases"]):
|
141 |
return "list"
|
142 |
+
|
143 |
+
# Даты и время
|
144 |
+
elif any(word in question_lower for word in ["date", "time", "day", "month", "year", "when", "period",
|
145 |
+
"century", "decade", "era", "age"]):
|
146 |
return "date_time"
|
147 |
+
|
148 |
+
# Имена и названия
|
149 |
+
elif any(word in question_lower for word in ["who", "name", "person", "people", "author", "creator",
|
150 |
+
"inventor", "founder", "director", "actor", "actress"]):
|
151 |
+
return "name"
|
152 |
+
|
153 |
+
# Географические вопросы
|
154 |
+
elif any(word in question_lower for word in ["where", "location", "country", "city", "place", "region",
|
155 |
+
"continent", "area", "territory"]):
|
156 |
+
return "location"
|
157 |
+
|
158 |
+
# Определения и объяснения
|
159 |
+
elif any(word in question_lower for word in ["what is", "define", "definition", "meaning", "explain",
|
160 |
+
"description", "describe"]):
|
161 |
+
return "definition"
|
162 |
+
|
163 |
+
# Да/Нет вопросы
|
164 |
+
elif any(word in question_lower for word in ["is it", "are there", "does it", "can it", "will it",
|
165 |
+
"has it", "have they", "do they"]):
|
166 |
+
return "yes_no"
|
167 |
+
|
168 |
+
# По умолчанию - фактический вопрос
|
169 |
else:
|
170 |
return "factual"
|
171 |
|
172 |
+
def _create_specialized_prompt(self, question: str, question_type: str) -> str:
|
173 |
"""
|
174 |
+
Создает специализированный промпт в зависимости от типа вопроса
|
175 |
+
|
176 |
+
Args:
|
177 |
+
question: Исходный вопрос
|
178 |
+
question_type: Тип вопроса
|
179 |
+
|
180 |
+
Returns:
|
181 |
+
str: Специализированный промпт для модели
|
182 |
+
"""
|
183 |
+
# Улучшено: специализированные промпты для разных типов вопросов
|
184 |
+
|
185 |
+
if question_type == "calculation":
|
186 |
+
return f"Calculate precisely and return only the numeric answer without units or explanation: {question}"
|
187 |
+
|
188 |
+
elif question_type == "list":
|
189 |
+
return f"List all items requested in the following question. Separate items with commas. Be specific and concise: {question}"
|
190 |
+
|
191 |
+
elif question_type == "date_time":
|
192 |
+
return f"Provide the exact date or time information requested. Format dates as Month Day, Year: {question}"
|
193 |
+
|
194 |
+
elif question_type == "name":
|
195 |
+
return f"Provide only the name(s) of the person(s) requested, without titles or explanations: {question}"
|
196 |
+
|
197 |
+
elif question_type == "location":
|
198 |
+
return f"Provide only the name of the location requested, without additional information: {question}"
|
199 |
+
|
200 |
+
elif question_type == "definition":
|
201 |
+
return f"Provide a concise definition in one short phrase without using the term itself: {question}"
|
202 |
+
|
203 |
+
elif question_type == "yes_no":
|
204 |
+
return f"Answer with only 'yes' or 'no': {question}"
|
205 |
+
|
206 |
+
elif question_type == "reversed_text":
|
207 |
+
# Обрабатываем обратный текст
|
208 |
+
reversed_question = question[::-1]
|
209 |
+
return f"This text was reversed. The original question is: {reversed_question}. Answer this question."
|
210 |
+
|
211 |
+
else: # factual и другие типы
|
212 |
+
return f"Answer this question with a short, precise response without explanations: {question}"
|
213 |
+
|
214 |
+
def _check_factual_correction(self, question: str, raw_answer: str) -> Optional[str]:
|
215 |
+
"""
|
216 |
+
Проверяет наличие готового ответа в словаре фактических коррекций
|
217 |
+
|
218 |
+
Args:
|
219 |
+
question: Исходный вопрос
|
220 |
+
raw_answer: Необработанный ответ от модели
|
221 |
+
|
222 |
+
Returns:
|
223 |
+
Optional[str]: Исправленный ответ, если есть в словаре, иначе None
|
224 |
+
"""
|
225 |
+
# Нормализуем вопрос для поиска в словаре
|
226 |
+
normalized_question = question.lower().strip()
|
227 |
+
|
228 |
+
# Проверяем точное совпадение
|
229 |
+
if normalized_question in FACTUAL_CORRECTIONS:
|
230 |
+
return FACTUAL_CORRECTIONS[normalized_question]
|
231 |
+
|
232 |
+
# Проверяем частичное совпадение (для вопросов с дополнительным контекстом)
|
233 |
+
for key, value in FACTUAL_CORRECTIONS.items():
|
234 |
+
if key in normalized_question:
|
235 |
+
return value
|
236 |
+
|
237 |
+
# Проверяем обратный текст
|
238 |
+
if "rewsna eht sa" in normalized_question:
|
239 |
+
for key, value in REVERSED_TEXT_ANSWERS.items():
|
240 |
+
if key in normalized_question:
|
241 |
+
return value
|
242 |
+
|
243 |
+
return None
|
244 |
+
|
245 |
+
def _format_answer(self, raw_answer: str, question_type: str, question: str) -> str:
|
246 |
+
"""
|
247 |
+
Улучшенное форматирование ответа в соответствии с типом вопроса
|
248 |
|
249 |
Args:
|
250 |
raw_answer: Необработанный ответ от модели
|
251 |
question_type: Тип вопроса
|
252 |
+
question: Исходный вопрос для контекста
|
253 |
|
254 |
Returns:
|
255 |
str: Отформатированный ответ
|
256 |
"""
|
257 |
+
# Проверяем наличие готового ответа в словаре фактических коррекций
|
258 |
+
factual_correction = self._check_factual_correction(question, raw_answer)
|
259 |
+
if factual_correction:
|
260 |
+
return factual_correction
|
261 |
+
|
262 |
# Удаляем лишние пробелы и переносы строк
|
263 |
answer = raw_answer.strip()
|
264 |
|
265 |
# Удаляем префиксы, которые часто добавляет модель
|
266 |
+
prefixes = [
|
267 |
+
"Answer:", "The answer is:", "I think", "I believe", "According to", "Based on",
|
268 |
+
"My answer is", "The result is", "It is", "This is", "That is", "The correct answer is",
|
269 |
+
"The solution is", "The response is", "The output is", "The value is", "The number is",
|
270 |
+
"The date is", "The time is", "The location is", "The person is", "The name is"
|
271 |
+
]
|
272 |
+
|
273 |
for prefix in prefixes:
|
274 |
+
if answer.lower().startswith(prefix.lower()):
|
275 |
answer = answer[len(prefix):].strip()
|
276 |
+
# Если после удаления префикса остался знак препинания в начале, удаляем его
|
277 |
+
if answer and answer[0] in ",:;.":
|
278 |
+
answer = answer[1:].strip()
|
279 |
+
|
280 |
+
# Удаляем фразы от первого лица
|
281 |
+
first_person_phrases = [
|
282 |
+
"I would say", "I think that", "I believe that", "In my opinion",
|
283 |
+
"From my knowledge", "As far as I know", "I can tell you that",
|
284 |
+
"I can say that", "I'm confident that", "I'm certain that"
|
285 |
+
]
|
286 |
+
|
287 |
+
for phrase in first_person_phrases:
|
288 |
+
if phrase.lower() in answer.lower():
|
289 |
+
answer = answer.lower().replace(phrase.lower(), "").strip()
|
290 |
+
# Восстанавливаем первую букву в верхний регистр, если это было начало предложения
|
291 |
+
if answer:
|
292 |
+
answer = answer[0].upper() + answer[1:]
|
293 |
|
294 |
# Специфическое форматирование в зависимости от типа вопроса
|
295 |
if question_type == "calculation":
|
296 |
+
# Для числовых ответов удаляем лишний текст и оставляем только числа
|
|
|
|
|
297 |
numbers = re.findall(r'-?\d+\.?\d*', answer)
|
298 |
if numbers:
|
299 |
+
# Если есть несколько чисел, берем то, которое выглядит как финальный ответ
|
300 |
+
# (обычно последнее число в тексте)
|
301 |
+
answer = numbers[-1]
|
302 |
+
|
303 |
+
# Удаляем лишние нули после десятичной точки
|
304 |
+
if '.' in answer:
|
305 |
+
answer = answer.rstrip('0').rstrip('.') if '.' in answer else answer
|
306 |
+
|
307 |
elif question_type == "list":
|
308 |
+
# Проверяем, не повторяет ли ответ части вопроса
|
309 |
+
question_words = set(re.findall(r'\b\w+\b', question.lower()))
|
310 |
+
answer_words = set(re.findall(r'\b\w+\b', answer.lower()))
|
311 |
+
|
312 |
+
# Если более 70% слов ответа содержится в воп��осе, это может быть эхо вопроса
|
313 |
+
overlap_ratio = len(answer_words.intersection(question_words)) / len(answer_words) if answer_words else 0
|
314 |
+
|
315 |
+
if overlap_ratio > 0.7:
|
316 |
+
# Пытаемся извлечь список из вопроса
|
317 |
+
list_items = []
|
318 |
+
|
319 |
+
# Ищем конкретные элементы списка в ответе
|
320 |
+
items_match = re.findall(r'(?:^|,\s*)([A-Za-z0-9]+(?:\s+[A-Za-z0-9]+)*)', answer)
|
321 |
+
if items_match:
|
322 |
+
list_items = [item.strip() for item in items_match if item.strip()]
|
323 |
+
|
324 |
+
if list_items:
|
325 |
+
answer = ", ".join(list_items)
|
326 |
+
else:
|
327 |
+
# Если не удалось извлечь элементы, используем заглушку
|
328 |
+
answer = "Items not specified"
|
329 |
+
|
330 |
# Для списков убеждаемся, что элементы разделены запятыми
|
331 |
if "," not in answer and " " in answer:
|
332 |
items = [item.strip() for item in answer.split() if item.strip()]
|
333 |
answer = ", ".join(items)
|
334 |
+
|
335 |
+
# Удаляем "and" перед последним элементом, если есть
|
336 |
+
answer = re.sub(r',?\s+and\s+', ', ', answer)
|
337 |
+
|
338 |
+
elif question_type == "date_time":
|
339 |
+
# Для дат пытаемся привести к стандартному формату
|
340 |
+
date_match = re.search(r'\b\d{1,4}[-/\.]\d{1,2}[-/\.]\d{1,4}\b|\b\d{1,2}\s+(?:January|February|March|April|May|June|July|August|September|October|November|December)\s+\d{4}\b|\b(?:January|February|March|April|May|June|July|August|September|October|November|December)\s+\d{1,2},?\s+\d{4}\b', answer)
|
341 |
+
if date_match:
|
342 |
+
answer = date_match.group(0)
|
343 |
+
|
344 |
+
elif question_type == "name":
|
345 |
+
# Для имен удаляем титулы и дополнительную информацию
|
346 |
+
# Оставляем только имя и фамилию
|
347 |
+
name_match = re.search(r'\b[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*\b', answer)
|
348 |
+
if name_match:
|
349 |
+
answer = name_match.group(0)
|
350 |
+
|
351 |
+
elif question_type == "location":
|
352 |
+
# Для локаций удаляем дополнительную информацию
|
353 |
+
# Часто локации начинаются с заглавной буквы
|
354 |
+
location_match = re.search(r'\b[A-Z][a-z]+(?:[\s-][A-Z][a-z]+)*\b', answer)
|
355 |
+
if location_match:
|
356 |
+
answer = location_match.group(0)
|
357 |
+
|
358 |
+
elif question_type == "yes_no":
|
359 |
+
# Для да/нет вопросов оставляем только "yes" или "no"
|
360 |
+
answer_lower = answer.lower()
|
361 |
+
if "yes" in answer_lower or "correct" in answer_lower or "true" in answer_lower or "right" in answer_lower:
|
362 |
+
answer = "yes"
|
363 |
+
elif "no" in answer_lower or "incorrect" in answer_lower or "false" in answer_lower or "wrong" in answer_lower:
|
364 |
+
answer = "no"
|
365 |
+
|
366 |
+
elif question_type == "reversed_text":
|
367 |
+
# Для обратного текста, проверяем, не нужно ли нам вернуть обратный ответ
|
368 |
+
if "opposite" in question.lower() and "write" in question.lower():
|
369 |
+
# Если в вопросе просят написать противоположное слово
|
370 |
+
opposites = {
|
371 |
+
"left": "right", "right": "left", "up": "down", "down": "up",
|
372 |
+
"north": "south", "south": "north", "east": "west", "west": "east",
|
373 |
+
"hot": "cold", "cold": "hot", "big": "small", "small": "big",
|
374 |
+
"tall": "short", "short": "tall", "high": "low", "low": "high",
|
375 |
+
"open": "closed", "closed": "open", "on": "off", "off": "on",
|
376 |
+
"in": "out", "out": "in", "yes": "no", "no": "yes"
|
377 |
+
}
|
378 |
+
|
379 |
+
# Ищем слово в ответе, которое может иметь противоположное значение
|
380 |
+
for word, opposite in opposites.items():
|
381 |
+
if word in answer.lower():
|
382 |
+
answer = opposite
|
383 |
+
break
|
384 |
+
|
385 |
+
# Если не нашли противоположное слово, используем значение из словаря
|
386 |
+
if answer == raw_answer.strip():
|
387 |
+
for key, value in REVERSED_TEXT_ANSWERS.items():
|
388 |
+
if key in question.lower():
|
389 |
+
answer = value
|
390 |
+
break
|
391 |
+
|
392 |
+
# Финальная очистка ответа
|
393 |
+
# Удаляем кавычки, если они окружают весь ответ
|
394 |
+
answer = answer.strip('"\'')
|
395 |
+
|
396 |
+
# Удаляем точку в конце, если это не часть числа
|
397 |
+
if answer.endswith('.') and not re.match(r'.*\d\.$', answer):
|
398 |
+
answer = answer[:-1]
|
399 |
+
|
400 |
+
# Удаляем множественные пробелы
|
401 |
+
answer = re.sub(r'\s+', ' ', answer).strip()
|
402 |
+
|
403 |
+
# Проверяем, не является ли ответ определением, которое содержит сам термин
|
404 |
+
if question_type == "definition":
|
405 |
+
# Извлекаем ключевой термин из вопроса
|
406 |
+
term_match = re.search(r"what is ([a-z\s']+)\??|define (?:the term )?['\"]?([a-z\s]+)['\"]?", question.lower())
|
407 |
+
if term_match:
|
408 |
+
term = term_match.group(1) if term_match.group(1) else term_match.group(2)
|
409 |
+
if term and term in answer.lower():
|
410 |
+
# Если определение содержит сам термин, пытаемся его переформулировать
|
411 |
+
answer = answer.lower().replace(term, "it")
|
412 |
+
# Восстанавливаем первую букву в верхний регистр
|
413 |
+
answer = answer[0].upper() + answer[1:]
|
414 |
+
|
415 |
+
# Ограничиваем длину определений
|
416 |
+
if len(answer.split()) > 10:
|
417 |
+
# Берем только первое предложение или первые 10 слов
|
418 |
+
first_sentence = re.split(r'[.!?]', answer)[0]
|
419 |
+
words = first_sentence.split()
|
420 |
+
if len(words) > 10:
|
421 |
+
answer = " ".join(words[:10])
|
422 |
|
423 |
return answer
|
424 |
|
|
|
447 |
print(f"Classified as: {question_type}")
|
448 |
|
449 |
try:
|
450 |
+
# Проверяем наличие готового ответа в словаре фактических коррекций
|
451 |
+
factual_correction = self._check_factual_correction(question, "")
|
452 |
+
if factual_correction:
|
453 |
+
# Формируем JSON-ответ с готовым ответом
|
454 |
+
result = {"final_answer": factual_correction}
|
455 |
+
json_response = json.dumps(result)
|
456 |
+
|
457 |
+
# Сохраняем в кэш
|
458 |
+
if self.use_cache:
|
459 |
+
self.cache[cache_key] = json_response
|
460 |
+
self._save_cache()
|
461 |
+
|
462 |
+
return json_response
|
463 |
+
|
464 |
+
# Создаем специализированный промпт
|
465 |
+
specialized_prompt = self._create_specialized_prompt(question, question_type)
|
466 |
+
|
467 |
# Генерируем ответ с помощью модели
|
468 |
+
inputs = self.tokenizer(specialized_prompt, return_tensors="pt")
|
469 |
+
|
470 |
+
# Настройки генерации для более точных ответов
|
471 |
+
# Примечание: некоторые модели могут не поддерживать все параметры
|
472 |
+
generation_params = {
|
473 |
+
"max_length": 150, # Увеличиваем максимальную длину
|
474 |
+
"num_beams": 5, # Используем beam search для лучших результатов
|
475 |
+
"no_repeat_ngram_size": 2 # Избегаем повторений
|
476 |
+
}
|
477 |
+
|
478 |
+
# Добавляем параметры, которые поддерживаются не всеми моделями
|
479 |
+
try:
|
480 |
+
outputs = self.model.generate(
|
481 |
+
**inputs,
|
482 |
+
**generation_params,
|
483 |
+
temperature=0.7, # Немного случайности для разнообразия
|
484 |
+
top_p=0.95 # Nucleus sampling для более естественных ответов
|
485 |
+
)
|
486 |
+
except:
|
487 |
+
# Если не поддерживаются дополнительные параметры, используем базовые
|
488 |
+
outputs = self.model.generate(**inputs, **generation_params)
|
489 |
+
|
490 |
raw_answer = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
|
491 |
|
492 |
+
# Форматируем ответ с учетом типа вопроса и исходного вопроса
|
493 |
+
formatted_answer = self._format_answer(raw_answer, question_type, question)
|
494 |
|
495 |
# Формируем JSON-ответ
|
496 |
result = {"final_answer": formatted_answer}
|
|
|
507 |
error_msg = f"Error generating answer: {e}"
|
508 |
print(error_msg)
|
509 |
return json.dumps({"final_answer": f"AGENT ERROR: {e}"})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|