|
import os |
|
import json |
|
import re |
|
import torch |
|
from typing import Dict, Optional |
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM |
|
|
|
CACHE_FILE = "gaia_answers_cache.json"

DEFAULT_MODEL = "google/flan-t5-base"


class EnhancedGAIAAgent:
    """Agent for the Hugging Face GAIA benchmark with improved question handling.

    Classifies each question by type, queries a seq2seq model, and normalizes
    the raw generation into a compact final answer wrapped in the JSON payload
    GAIA expects. Answers may optionally be cached on disk in CACHE_FILE.
    """

    def __init__(self, model_name: str = DEFAULT_MODEL, use_cache: bool = False) -> None:
        """Load the tokenizer/model and, if requested, the on-disk answer cache.

        Args:
            model_name: Hugging Face identifier of a seq2seq model.
            use_cache: When True, answers are read from and written to CACHE_FILE.
        """
        print(f"Initializing EnhancedGAIAAgent with model: {model_name}")
        self.model_name = model_name
        self.use_cache = use_cache
        self.cache: Dict[str, str] = self._load_cache() if use_cache else {}
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

    def _load_cache(self) -> Dict[str, str]:
        """Return cached answers from CACHE_FILE, or an empty dict if absent/corrupt."""
        if not os.path.exists(CACHE_FILE):
            return {}
        try:
            with open(CACHE_FILE, 'r', encoding='utf-8') as f:
                return json.load(f)
        except (OSError, json.JSONDecodeError):
            # A damaged or unreadable cache must never prevent answering.
            return {}

    def _save_cache(self) -> None:
        """Best-effort write of the in-memory cache to disk; I/O errors are ignored."""
        try:
            with open(CACHE_FILE, 'w', encoding='utf-8') as f:
                json.dump(self.cache, f, ensure_ascii=False, indent=2)
        except OSError:
            # Caching is an optimization only; swallow disk errors deliberately.
            pass

    def _classify_question(self, question: str) -> str:
        """Classify a question as 'calculation', 'list', 'date_time', or 'factual'."""
        question_lower = question.lower()
        if any(word in question_lower for word in ("calculate", "sum", "how many")):
            return "calculation"
        if any(word in question_lower for word in ("list", "enumerate")):
            return "list"
        if any(word in question_lower for word in ("date", "time", "when")):
            return "date_time"
        return "factual"

    def _format_answer(self, raw_answer: str, question_type: str) -> str:
        """Normalize a raw model generation into a compact GAIA-style answer.

        Strips boilerplate prefixes, extracts the first number for calculation
        questions, comma-separates bare word lists, removes surrounding quotes
        and a trailing period (unless it follows a digit), and collapses
        internal whitespace.
        """
        answer = raw_answer.strip()

        # Drop conversational lead-ins the model tends to emit.
        prefixes = ["Answer:", "The answer is:", "I think", "I believe"]
        for prefix in prefixes:
            if answer.lower().startswith(prefix.lower()):
                answer = answer[len(prefix):].strip()

        if question_type == "calculation":
            # GAIA expects a bare number; keep only the first one found.
            numbers = re.findall(r'-?\d+\.?\d*', answer)
            if numbers:
                answer = numbers[0]
        elif question_type == "list":
            # Turn a space-separated enumeration into a comma-separated one.
            if "," not in answer and " " in answer:
                items = [item.strip() for item in answer.split() if item.strip()]
                answer = ", ".join(items)

        answer = answer.strip('"\'')
        # Drop a trailing period unless it terminates a number (e.g. "3.").
        if answer.endswith('.') and not re.match(r'.*\d\.$', answer):
            answer = answer[:-1]
        return re.sub(r'\s+', ' ', answer).strip()

    def __call__(self, question: str, task_id: Optional[str] = None) -> str:
        """Answer a GAIA question, returning a JSON string {"final_answer": ...}.

        Args:
            question: The natural-language question text.
            task_id: Optional GAIA task id used as the cache key; falls back
                to the question text itself.
        """
        cache_key = task_id if task_id else question
        if self.use_cache and cache_key in self.cache:
            return self.cache[cache_key]

        question_type = self._classify_question(question)

        try:
            inputs = self.tokenizer(question, return_tensors="pt")
            outputs = self.model.generate(**inputs, max_length=100)
            raw_answer = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

            formatted_answer = self._format_answer(raw_answer, question_type)
            json_response = json.dumps({"final_answer": formatted_answer})

            if self.use_cache:
                self.cache[cache_key] = json_response
                self._save_cache()

            return json_response

        except Exception as e:
            # Surface the failure in-band so the evaluation harness still gets JSON.
            return json.dumps({"final_answer": f"AGENT ERROR: {e}"})