FinalTest / agent.py
yoshizen's picture
Create agent.py
d55317d verified
raw
history blame
3.99 kB
import os
import json
import re
import torch
from typing import Dict, Optional
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
CACHE_FILE = "gaia_answers_cache.json"
DEFAULT_MODEL = "google/flan-t5-base"


class EnhancedGAIAAgent:
    """Agent for the Hugging Face GAIA benchmark with improved question handling.

    A question is classified with keyword rules, answered by a seq2seq model,
    post-processed according to its category, and returned as a JSON string
    ``{"final_answer": "..."}``. When ``use_cache`` is true, responses are
    persisted to ``CACHE_FILE`` keyed by task id (or by the question text).
    """

    # Category rules for _classify_question, checked in order (first match
    # wins). The \b word boundaries stop short keywords from firing inside
    # unrelated words, e.g. "sum" in "summarize", "list" in "specialist",
    # "date" in "candidate", "time" in "lifetime".
    _QUESTION_PATTERNS = (
        ("calculation", re.compile(r"\b(?:calculat\w*|sums?|how\s+many)\b")),
        ("list", re.compile(r"\b(?:list(?:s|ed|ing)?|enumerat\w*)\b")),
        ("date_time", re.compile(r"\b(?:dat(?:e|es|ed)|times?|when)\b")),
    )

    # First number in an answer: optional sign, optional 3-digit thousands
    # groups ("1,234"), optional fractional part. A trailing "." is NOT
    # captured, so "42." yields "42".
    _NUMBER_RE = re.compile(r"-?\d+(?:,\d{3}(?!\d))*(?:\.\d+)?")

    # Conversational lead-ins stripped (case-insensitively) from model output.
    _ANSWER_PREFIXES = ("Answer:", "The answer is:", "I think", "I believe")

    def __init__(self, model_name=DEFAULT_MODEL, use_cache=False):
        """Load the tokenizer and model and, optionally, the on-disk answer cache.

        Args:
            model_name: Hugging Face model id passed to ``from_pretrained``.
            use_cache: If true, previous answers are read from ``CACHE_FILE``
                and new ones are written back after each successful call.
        """
        print(f"Initializing EnhancedGAIAAgent with model: {model_name}")
        self.model_name = model_name
        self.use_cache = use_cache
        self.cache = self._load_cache() if use_cache else {}
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

    def _load_cache(self) -> Dict[str, str]:
        """Return the cached ``{key: json_response}`` mapping from ``CACHE_FILE``.

        A missing, unreadable, malformed, or non-object cache file yields an
        empty dict, so the agent still starts.
        """
        if not os.path.exists(CACHE_FILE):
            return {}
        try:
            with open(CACHE_FILE, "r", encoding="utf-8") as f:
                data = json.load(f)
        except (OSError, ValueError) as e:  # ValueError covers JSON/Unicode decode errors
            print(f"Warning: could not read cache file {CACHE_FILE}: {e}")
            return {}
        # Only a JSON object can serve as a key -> answer cache; anything else
        # would make later `self.cache[key] = ...` assignments fail.
        return data if isinstance(data, dict) else {}

    def _save_cache(self) -> None:
        """Write ``self.cache`` to ``CACHE_FILE`` (best effort, never raises).

        The data goes to a temporary file first and is then moved into place
        with ``os.replace``, so an interrupted write cannot corrupt an existing
        cache.
        """
        tmp_path = f"{CACHE_FILE}.tmp"
        try:
            with open(tmp_path, "w", encoding="utf-8") as f:
                json.dump(self.cache, f, ensure_ascii=False, indent=2)
            os.replace(tmp_path, CACHE_FILE)
        except (OSError, TypeError, ValueError) as e:
            # A failed cache write must not break answering; report it instead
            # of silently swallowing it.
            print(f"Warning: could not write cache file {CACHE_FILE}: {e}")

    def _classify_question(self, question: str) -> str:
        """Classify a question as "calculation", "list", "date_time" or "factual".

        Rules in ``_QUESTION_PATTERNS`` are tried in order on the lower-cased
        question; "factual" is the fallback when none match.
        """
        question_lower = question.lower()
        for question_type, pattern in self._QUESTION_PATTERNS:
            if pattern.search(question_lower):
                return question_type
        return "factual"

    def _format_answer(self, raw_answer: str, question_type: str) -> str:
        """Normalize raw model output into a short final answer.

        Args:
            raw_answer: Decoded model output.
            question_type: Category returned by ``_classify_question``.

        Returns:
            The answer with lead-in phrases removed, reduced to the first number
            (commas dropped) for "calculation", comma-joined for space-separated
            "list" answers, and stripped of surrounding quotes, one trailing
            period and redundant whitespace.
        """
        answer = raw_answer.strip()

        # Remove conversational prefixes plus any separator left behind
        # ("I think: Paris" -> "Paris").
        for prefix in self._ANSWER_PREFIXES:
            if answer.lower().startswith(prefix.lower()):
                answer = answer[len(prefix):].lstrip(" :,")

        # Category-specific formatting.
        if question_type == "calculation":
            match = self._NUMBER_RE.search(answer)
            if match:
                answer = match.group().replace(",", "")
        elif question_type == "list":
            if "," not in answer and " " in answer:
                answer = ", ".join(answer.split())

        # Final cleanup: quotes, one sentence-ending period (also after digits,
        # e.g. "1999." -> "1999"), then quotes that the period was hiding.
        answer = answer.strip("\"'")
        if answer.endswith("."):
            answer = answer[:-1]
        answer = answer.strip("\"'")
        return re.sub(r"\s+", " ", answer).strip()

    def __call__(self, question: str, task_id: Optional[str] = None) -> str:
        """Answer ``question`` and return a JSON string ``{"final_answer": ...}``.

        Args:
            question: The question text fed to the model.
            task_id: Optional GAIA task id, used as the cache key when given
                (otherwise the question text is the key).

        Returns:
            The JSON response. On model or tokenizer failure, the final answer is
            ``"AGENT ERROR: <message>"``; such responses are not cached.
        """
        cache_key = task_id or question
        if self.use_cache and cache_key in self.cache:
            return self.cache[cache_key]

        question_type = self._classify_question(question)
        try:
            inputs = self.tokenizer(question, return_tensors="pt")
            outputs = self.model.generate(**inputs, max_length=100)
            raw_answer = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            formatted_answer = self._format_answer(raw_answer, question_type)
        except Exception as e:
            # Top-level boundary: report the failure as the answer instead of
            # aborting the whole evaluation run.
            return json.dumps({"final_answer": f"AGENT ERROR: {e}"})

        json_response = json.dumps({"final_answer": formatted_answer})
        if self.use_cache:
            self.cache[cache_key] = json_response
            self._save_cache()
        return json_response