yoshizen committed on
Commit d55317d · verified · 1 Parent(s): eed83d0

Create agent.py

Files changed (1)
  1. agent.py +101 -0
agent.py ADDED
@@ -0,0 +1,101 @@
+ import os
+ import json
+ import re
+ import torch
+ from typing import Dict, Optional
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+
+ CACHE_FILE = "gaia_answers_cache.json"
+ DEFAULT_MODEL = "google/flan-t5-base"
+
+ class EnhancedGAIAAgent:
+     """Agent for Hugging Face GAIA with improved question handling."""
+
+     def __init__(self, model_name=DEFAULT_MODEL, use_cache=False):
+         print(f"Initializing EnhancedGAIAAgent with model: {model_name}")
+         self.model_name = model_name
+         self.use_cache = use_cache
+         self.cache = self._load_cache() if use_cache else {}
+         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+         self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
+
+     def _load_cache(self) -> Dict[str, str]:
+         if os.path.exists(CACHE_FILE):
+             try:
+                 with open(CACHE_FILE, 'r', encoding='utf-8') as f:
+                     return json.load(f)
+             except (OSError, json.JSONDecodeError):
+                 return {}
+         return {}
+
+     def _save_cache(self) -> None:
+         try:
+             with open(CACHE_FILE, 'w', encoding='utf-8') as f:
+                 json.dump(self.cache, f, ensure_ascii=False, indent=2)
+         except OSError:
+             pass
+
+     def _classify_question(self, question: str) -> str:
+         question_lower = question.lower()
+
+         if any(word in question_lower for word in ["calculate", "sum", "how many"]):
+             return "calculation"
+         elif any(word in question_lower for word in ["list", "enumerate"]):
+             return "list"
+         elif any(word in question_lower for word in ["date", "time", "when"]):
+             return "date_time"
+         return "factual"
+
+     def _format_answer(self, raw_answer: str, question_type: str) -> str:
+         answer = raw_answer.strip()
+
+         # Strip common answer prefixes
+         prefixes = ["Answer:", "The answer is:", "I think", "I believe"]
+         for prefix in prefixes:
+             if answer.lower().startswith(prefix.lower()):
+                 answer = answer[len(prefix):].strip()
+
+         # Type-specific formatting
+         if question_type == "calculation":
+             numbers = re.findall(r'-?\d+\.?\d*', answer)
+             if numbers:
+                 answer = numbers[0]
+         elif question_type == "list":
+             if "," not in answer and " " in answer:
+                 items = [item.strip() for item in answer.split() if item.strip()]
+                 answer = ", ".join(items)
+
+         # Final cleanup
+         answer = answer.strip('"\'')
+         if answer.endswith('.') and not re.match(r'.*\d\.$', answer):
+             answer = answer[:-1]
+         return re.sub(r'\s+', ' ', answer).strip()
+
+     def __call__(self, question: str, task_id: Optional[str] = None) -> str:
+         cache_key = task_id if task_id else question
+         if self.use_cache and cache_key in self.cache:
+             return self.cache[cache_key]
+
+         question_type = self._classify_question(question)
+
+         try:
+             # Generate the answer
+             inputs = self.tokenizer(question, return_tensors="pt")
+             outputs = self.model.generate(**inputs, max_length=100)
+             raw_answer = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+             # Format the answer
+             formatted_answer = self._format_answer(raw_answer, question_type)
+
+             # Build the JSON response
+             result = {"final_answer": formatted_answer}
+             json_response = json.dumps(result)
+
+             if self.use_cache:
+                 self.cache[cache_key] = json_response
+                 self._save_cache()
+
+             return json_response
+
+         except Exception as e:
+             return json.dumps({"final_answer": f"AGENT ERROR: {e}"})
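For context, a minimal usage sketch of the class added in this commit. The sample question and task_id below are placeholders (not part of the commit), and the first call downloads google/flan-t5-base from the Hub.

# Hypothetical usage of EnhancedGAIAAgent; inputs below are placeholders.
from agent import EnhancedGAIAAgent

agent = EnhancedGAIAAgent(use_cache=True)  # loads google/flan-t5-base and gaia_answers_cache.json
response = agent("How many continents are there?", task_id="demo-task-001")
print(response)  # a JSON string such as {"final_answer": "..."}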