import json
import os
import re
from typing import Dict, Optional

import torch  # noqa: F401  (not referenced directly; kept from the original file)
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

CACHE_FILE = "gaia_answers_cache.json"
DEFAULT_MODEL = "google/flan-t5-base"

# Question-type keywords, checked in order; the first match wins.
# Whole-word matching (with simple inflections) is used instead of plain
# substring search so that e.g. "summary", "specialist" or "update" do not
# trigger the "calculation", "list" or "date_time" handling.
_QUESTION_TYPE_PATTERNS = (
    ("calculation", re.compile(r"\b(?:calculate[sd]?|sum(?:s|med|ming)?|how many)\b")),
    ("list", re.compile(r"\b(?:list(?:s|ed|ing)?|enumerate[sd]?)\b")),
    ("date_time", re.compile(r"\b(?:date[sd]?|time[sd]?|when)\b")),
)

# Leading phrases removed from the raw model output (case-insensitive).
_ANSWER_PREFIXES = ("Answer:", "The answer is:", "I think", "I believe")

# Integer or decimal number. The fractional part requires at least one digit,
# so a sentence-ending period ("42.") is not captured as part of the number.
_NUMBER_RE = re.compile(r"-?\d+(?:\.\d+)?")


class EnhancedGAIAAgent:
    """Agent for the Hugging Face GAIA benchmark with improved question handling.

    The agent classifies the question by keyword, generates an answer with a
    seq2seq model, normalises the answer text and returns it as a JSON string
    of the form ``{"final_answer": ...}``. Answers can optionally be cached in
    ``CACHE_FILE``.
    """

    def __init__(self, model_name: str = DEFAULT_MODEL, use_cache: bool = False) -> None:
        """Load the tokenizer and model, and the answer cache if enabled.

        Args:
            model_name: Model identifier passed to ``from_pretrained``.
            use_cache: When true, answers are read from and written to
                ``CACHE_FILE``.
        """
        print(f"Initializing EnhancedGAIAAgent with model: {model_name}")
        self.model_name = model_name
        self.use_cache = use_cache
        self.cache = self._load_cache() if use_cache else {}
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

    def _load_cache(self) -> Dict[str, str]:
        """Return the cache stored in ``CACHE_FILE``, or ``{}`` if unusable.

        A missing, unreadable or malformed file, or a file whose top-level
        JSON value is not an object, yields an empty cache.
        """
        if not os.path.exists(CACHE_FILE):
            return {}
        try:
            with open(CACHE_FILE, "r", encoding="utf-8") as f:
                data = json.load(f)
        except (OSError, ValueError):
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            return {}
        # Anything other than a dict would break key lookups/assignments later.
        return data if isinstance(data, dict) else {}

    def _save_cache(self) -> None:
        """Write the in-memory cache to ``CACHE_FILE`` (best effort).

        Failures are reported with a printed warning and never raised.
        """
        try:
            # Serialise before opening the file so that a serialisation error
            # cannot truncate an existing cache file.
            payload = json.dumps(self.cache, ensure_ascii=False, indent=2)
            with open(CACHE_FILE, "w", encoding="utf-8") as f:
                f.write(payload)
        except (OSError, TypeError, ValueError) as exc:
            print(f"Warning: could not write cache file {CACHE_FILE}: {exc}")

    def _classify_question(self, question: str) -> str:
        """Return the question type inferred from keywords.

        Returns:
            One of ``"calculation"``, ``"list"``, ``"date_time"`` or
            ``"factual"`` (the fallback when no keyword matches).
        """
        question_lower = question.lower()
        for question_type, pattern in _QUESTION_TYPE_PATTERNS:
            if pattern.search(question_lower):
                return question_type
        return "factual"

    def _format_answer(self, raw_answer: str, question_type: str) -> str:
        """Normalise the raw model output according to the question type.

        Args:
            raw_answer: Text decoded from the model output.
            question_type: A value returned by ``_classify_question``.

        Returns:
            The cleaned answer with prefixes, surrounding quotes, a trailing
            period and repeated whitespace removed.
        """
        answer = raw_answer.strip()

        # Remove leading boilerplate prefixes.
        for prefix in _ANSWER_PREFIXES:
            if answer.lower().startswith(prefix.lower()):
                answer = answer[len(prefix):].strip()

        # Type-specific formatting.
        if question_type == "calculation":
            # Keep only the first number found in the answer.
            numbers = _NUMBER_RE.findall(answer)
            if numbers:
                answer = numbers[0]
        elif question_type == "list":
            # Space-separated items without commas become a comma-separated list.
            if "," not in answer and " " in answer:
                items = [item.strip() for item in answer.split() if item.strip()]
                answer = ", ".join(items)

        # Final cleanup: surrounding quotes, then a trailing period unless it
        # directly follows a digit.
        answer = answer.strip('"\'')
        if answer.endswith('.') and not re.match(r'.*\d\.$', answer):
            answer = answer[:-1]

        return re.sub(r'\s+', ' ', answer).strip()

    def __call__(self, question: str, task_id: Optional[str] = None) -> str:
        """Answer a question and return the result as a JSON string.

        Args:
            question: The question text given to the model.
            task_id: Optional identifier; used as the cache key when truthy,
                otherwise the question text is the key.

        Returns:
            A JSON string ``{"final_answer": ...}``. If anything raises while
            generating or formatting, the value is ``"AGENT ERROR: <message>"``
            and nothing is cached.
        """
        cache_key = task_id if task_id else question

        if self.use_cache and cache_key in self.cache:
            return self.cache[cache_key]

        question_type = self._classify_question(question)

        try:
            # Generate the answer.
            inputs = self.tokenizer(question, return_tensors="pt")
            outputs = self.model.generate(**inputs, max_length=100)
            raw_answer = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

            # Format the answer.
            formatted_answer = self._format_answer(raw_answer, question_type)

            # Build the JSON response.
            result = {"final_answer": formatted_answer}
            json_response = json.dumps(result)

            if self.use_cache:
                self.cache[cache_key] = json_response
                self._save_cache()

            return json_response

        except Exception as e:
            return json.dumps({"final_answer": f"AGENT ERROR: {e}"})