"""
Enhanced GAIA Agent with answer caching and a fixed agent_code field.
"""

import json
import os
import re
import time
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import gradio as gr
import pandas as pd
import requests
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

CACHE_FILE = "gaia_answers_cache.json"
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
MAX_RETRIES = 3
RETRY_DELAY = 5  # seconds to wait between submission retries

class EnhancedGAIAAgent:
    """
    Enhanced agent for the Hugging Face GAIA benchmark with answer caching.
    """

    def __init__(self, model_name="google/flan-t5-base", use_cache=True):
        """
        Initialize the agent with a model and a cache.

        Args:
            model_name: Name of the model to load
            use_cache: Whether to cache answers
        """
        print(f"Initializing EnhancedGAIAAgent with model: {model_name}")
        self.model_name = model_name
        self.use_cache = use_cache
        self.cache = self._load_cache() if use_cache else {}

        print("Loading tokenizer...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        print("Loading model...")
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        print("Model and tokenizer loaded successfully")

    def _load_cache(self) -> Dict[str, str]:
        """
        Load the answer cache from disk.

        Returns:
            Dict[str, str]: Dictionary of cached answers
        """
        if os.path.exists(CACHE_FILE):
            try:
                with open(CACHE_FILE, 'r', encoding='utf-8') as f:
                    print(f"Loading cache from {CACHE_FILE}")
                    return json.load(f)
            except Exception as e:
                print(f"Error loading cache: {e}")
                return {}
        else:
            print(f"Cache file {CACHE_FILE} not found, creating new cache")
            return {}

    def _save_cache(self) -> None:
        """
        Save the answer cache to disk.
        """
        try:
            with open(CACHE_FILE, 'w', encoding='utf-8') as f:
                json.dump(self.cache, f, ensure_ascii=False, indent=2)
            print(f"Cache saved to {CACHE_FILE}")
        except Exception as e:
            print(f"Error saving cache: {e}")
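
    # Cache file layout (illustrative; the key below is a made-up placeholder,
    # not a real GAIA task id): each entry maps the task_id -- or the question
    # text, when no task_id was given -- to the agent's JSON response string:
    #   {"demo-task-001": "{\"final_answer\": \"Paris\"}"}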

    def _classify_question(self, question: str) -> str:
        """
        Classify the question by type so the answer can be formatted better.

        Args:
            question: Question text

        Returns:
            str: Question type (factual, calculation, list, date_time, etc.)
        """
        question_lower = question.lower()

        if any(word in question_lower for word in ["calculate", "sum", "product", "divide", "multiply", "add", "subtract", "how many"]):
            return "calculation"
        elif any(word in question_lower for word in ["list", "enumerate", "items", "elements"]):
            return "list"
        elif any(word in question_lower for word in ["date", "time", "day", "month", "year", "when"]):
            return "date_time"
        else:
            return "factual"

    def _format_answer(self, raw_answer: str, question_type: str) -> str:
        """
        Format the answer according to the question type.

        Args:
            raw_answer: Raw answer from the model
            question_type: Question type

        Returns:
            str: Formatted answer
        """
        answer = raw_answer.strip()

        # Strip common conversational prefixes so only the answer remains.
        prefixes = ["Answer:", "The answer is:", "I think", "I believe", "According to", "Based on"]
        for prefix in prefixes:
            if answer.startswith(prefix):
                answer = answer[len(prefix):].strip()

        if question_type == "calculation":
            # Keep only the first number found in the answer.
            numbers = re.findall(r'-?\d+\.?\d*', answer)
            if numbers:
                answer = numbers[0]
        elif question_type == "list":
            # Turn a space-separated answer into a comma-separated list.
            if "," not in answer and " " in answer:
                items = [item.strip() for item in answer.split() if item.strip()]
                answer = ", ".join(items)

        return answer

    def __call__(self, question: str, task_id: Optional[str] = None) -> str:
        """
        Process a question and return an answer.

        Args:
            question: Question text
            task_id: Task identifier (optional)

        Returns:
            str: JSON-formatted answer with a final_answer key
        """
        cache_key = task_id if task_id else question

        if self.use_cache and cache_key in self.cache:
            print(f"Cache hit for question: {question[:50]}...")
            return self.cache[cache_key]

        question_type = self._classify_question(question)
        print(f"Processing question: {question[:100]}...")
        print(f"Classified as: {question_type}")

        try:
            inputs = self.tokenizer(question, return_tensors="pt")
            outputs = self.model.generate(**inputs, max_length=100)
            raw_answer = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

            formatted_answer = self._format_answer(raw_answer, question_type)

            result = {"final_answer": formatted_answer}
            json_response = json.dumps(result)

            if self.use_cache:
                self.cache[cache_key] = json_response
                self._save_cache()

            return json_response

        except Exception as e:
            error_msg = f"Error generating answer: {e}"
            print(error_msg)
            return json.dumps({"final_answer": f"AGENT ERROR: {e}"})
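

# Minimal usage sketch (defined but never called): invoking the agent directly,
# outside the evaluation flow. The question and task_id are made-up placeholders.
def _demo_agent_call() -> None:
    agent = EnhancedGAIAAgent(model_name="google/flan-t5-small", use_cache=False)
    response = agent("What is the capital of France?", task_id="demo-task-001")
    print(response)  # expected shape: {"final_answer": "..."}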


class EvaluationRunner:
    """
    Drives the evaluation process: fetching questions, running the agent,
    and submitting answers to the scoring server.
    """

    def __init__(self, api_url=DEFAULT_API_URL):
        """Initialize with the API endpoints."""
        self.api_url = api_url
        self.questions_url = f"{api_url}/questions"
        self.submit_url = f"{api_url}/submit"
        self.results_url = f"{api_url}/results"
        self.correct_answers = 0
        self.total_questions = 0

    def run_evaluation(self,
                       agent: Callable[[str, Optional[str]], str],
                       username: str,
                       agent_code_url: str) -> Tuple[str, Optional[pd.DataFrame]]:
        """
        Run the full evaluation process:
        1. Fetch questions
        2. Run the agent on every question
        3. Submit the answers
        4. Return the results
        """
        questions_data = self._fetch_questions()
        if isinstance(questions_data, str):
            return questions_data, None

        results_log, answers_payload = self._run_agent_on_questions(agent, questions_data)
        if not answers_payload:
            return "Agent did not produce any answers to submit.", pd.DataFrame(results_log)

        submission_result = self._submit_answers(username, agent_code_url, answers_payload)

        return submission_result, pd.DataFrame(results_log)

    def _fetch_questions(self) -> Union[List[Dict[str, Any]], str]:
        """Fetch the questions from the scoring server."""
        print(f"Fetching questions from: {self.questions_url}")
        try:
            response = requests.get(self.questions_url, timeout=15)
            response.raise_for_status()
            questions_data = response.json()

            if not questions_data:
                error_msg = "Fetched questions list is empty or invalid format."
                print(error_msg)
                return error_msg

            self.total_questions = len(questions_data)
            print(f"Successfully fetched {self.total_questions} questions.")
            return questions_data

        # JSONDecodeError subclasses RequestException in requests, so it must
        # be caught first or this handler would never run.
        except requests.exceptions.JSONDecodeError as e:
            error_msg = f"Error decoding JSON response from questions endpoint: {e}"
            print(error_msg)
            print(f"Response text: {response.text[:500]}")
            return error_msg

        except requests.exceptions.RequestException as e:
            error_msg = f"Error fetching questions: {e}"
            print(error_msg)
            return error_msg

        except Exception as e:
            error_msg = f"An unexpected error occurred fetching questions: {e}"
            print(error_msg)
            return error_msg

    def _run_agent_on_questions(self,
                                agent: Callable[[str, Optional[str]], str],
                                questions_data: List[Dict[str, Any]]) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
        """Run the agent on every question and collect the results."""
        results_log = []
        answers_payload = []

        print(f"Running agent on {len(questions_data)} questions...")
        for item in questions_data:
            task_id = item.get("task_id")
            question_text = item.get("question")

            if not task_id or question_text is None:
                print(f"Skipping item with missing task_id or question: {item}")
                continue

            try:
                json_response = agent(question_text, task_id)

                response_obj = json.loads(json_response)

                submitted_answer = response_obj.get("final_answer", "")

                answers_payload.append({
                    "task_id": task_id,
                    "submitted_answer": submitted_answer
                })

                results_log.append({
                    "Task ID": task_id,
                    "Question": question_text,
                    "Submitted Answer": submitted_answer,
                    "Full Response": json_response
                })
            except Exception as e:
                print(f"Error running agent on task {task_id}: {e}")
                results_log.append({
                    "Task ID": task_id,
                    "Question": question_text,
                    "Submitted Answer": f"AGENT ERROR: {e}"
                })

        return results_log, answers_payload

    def _submit_answers(self,
                        username: str,
                        agent_code_url: str,
                        answers_payload: List[Dict[str, Any]]) -> str:
        """Submit the answers to the scoring server."""
        submission_data = {
            "username": username.strip(),
            "agent_code": agent_code_url.strip(),
            "answers": answers_payload
        }

        print(f"Submitting {len(answers_payload)} answers to: {self.submit_url}")
        max_retries = MAX_RETRIES
        retry_delay = RETRY_DELAY

        for attempt in range(1, max_retries + 1):
            try:
                print(f"Submission attempt {attempt} of {max_retries}...")
                response = requests.post(
                    self.submit_url,
                    json=submission_data,
                    headers={"Content-Type": "application/json"},
                    timeout=30
                )
                response.raise_for_status()

                try:
                    result = response.json()
                    score = result.get("score")
                    max_score = result.get("max_score")

                    if score is not None and max_score is not None:
                        self.correct_answers = score
                        return f"Evaluation complete! Score: {score}/{max_score}"
                    else:
                        print(f"Received N/A results. Waiting {retry_delay} seconds before retry...")
                        time.sleep(retry_delay)
                        continue

                except requests.exceptions.JSONDecodeError:
                    print(f"Submission attempt {attempt}: Response was not JSON. Response: {response.text}")
                    if attempt < max_retries:
                        print(f"Waiting {retry_delay} seconds before retry...")
                        time.sleep(retry_delay)
                    else:
                        return f"Submission successful, but response was not JSON. Response: {response.text}"

            except requests.exceptions.RequestException as e:
                print(f"Submission attempt {attempt} failed: {e}")
                if attempt < max_retries:
                    print(f"Waiting {retry_delay} seconds before retry...")
                    time.sleep(retry_delay)
                else:
                    return f"Error submitting answers after {max_retries} attempts: {e}"

        return "Submission successful, but results are pending!"

    def _check_results(self, username: str) -> None:
        """Check the results endpoint to count correct answers."""
        try:
            results_url = f"{self.results_url}?username={username}"
            print(f"Checking results at: {results_url}")

            response = requests.get(results_url, timeout=15)
            if response.status_code == 200:
                try:
                    data = response.json()
                    if isinstance(data, dict):
                        score = data.get("score")
                        if score is not None:
                            self.correct_answers = int(score)
                            print(f"✓ Correct answers: {self.correct_answers}/{self.total_questions}")
                        else:
                            print("Score information not available in results")
                    else:
                        print("Results data is not in expected format")
                except Exception:
                    print("Could not parse results JSON")
            else:
                print(f"Could not fetch results, status code: {response.status_code}")
        except Exception as e:
            print(f"Error checking results: {e}")

    def get_correct_answers_count(self) -> int:
        """Return the number of correct answers."""
        return self.correct_answers

    def get_total_questions_count(self) -> int:
        """Return the total number of questions."""
        return self.total_questions

    def print_evaluation_summary(self, username: str) -> None:
        """Print a summary of the evaluation results."""
        print("\n===== EVALUATION SUMMARY =====")
        print(f"User: {username}")
        print(f"Overall Score: {self.correct_answers}/{self.total_questions}")
        print(f"Correct Answers: {self.correct_answers}")
        print(f"Total Questions: {self.total_questions}")
        print(f"Accuracy: {(self.correct_answers / self.total_questions * 100) if self.total_questions > 0 else 0:.1f}%")
        print("=============================\n")


def run_evaluation(username: str,
                   agent_code_url: str,
                   model_name: str = "google/flan-t5-small",
                   use_cache: bool = True) -> Tuple[str, int, int, str, str, str]:
    """
    Run the full evaluation process with caching support.

    Args:
        username: Hugging Face username
        agent_code_url: URL of the agent code (or the code itself)
        model_name: Name of the model to use
        use_cache: Whether to cache answers

    Returns:
        Tuple[str, int, int, str, str, str]: A tuple of six values:
            - result_text: Textual evaluation result
            - correct_answers: Number of correct answers
            - total_questions: Total number of questions
            - elapsed_time: Execution time
            - results_url: URL for checking the results
            - cache_status: Caching status
    """
    start_time = time.time()

    agent = EnhancedGAIAAgent(model_name=model_name, use_cache=use_cache)

    runner = EvaluationRunner(api_url=DEFAULT_API_URL)

    result, results_log = runner.run_evaluation(agent, username, agent_code_url)

    runner._check_results(username)

    runner.print_evaluation_summary(username)

    elapsed_time = time.time() - start_time
    elapsed_time_str = f"{elapsed_time:.2f} seconds"

    results_url = f"{DEFAULT_API_URL}/results?username={username}"

    cache_status = "Cache enabled and used" if use_cache else "Cache disabled"

    return (
        result,
        runner.get_correct_answers_count(),
        runner.get_total_questions_count(),
        elapsed_time_str,
        results_url,
        cache_status
    )
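

# Minimal usage sketch (defined but never called): running the full evaluation
# programmatically, without the Gradio UI. The username and URL are placeholders.
def _demo_run_evaluation() -> None:
    result_text, correct, total, elapsed, url, cache_status = run_evaluation(
        username="your-hf-username",  # placeholder
        agent_code_url="https://example.com/your-agent-code",  # placeholder
        model_name="google/flan-t5-small",
        use_cache=True,
    )
    print(result_text)
    print(f"Score: {correct}/{total} in {elapsed}; details: {url} ({cache_status})")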


def create_gradio_interface():
    """
    Create the Gradio interface for running the evaluation.
    """
    with gr.Blocks(title="GAIA Agent Evaluation") as demo:
        gr.Markdown("# GAIA Agent Evaluation with Caching")

        with gr.Row():
            with gr.Column():
                username = gr.Textbox(label="Hugging Face Username")
                agent_code_url = gr.Textbox(label="Agent Code URL or Code", lines=10)
                model_name = gr.Dropdown(
                    label="Model",
                    choices=["google/flan-t5-small", "google/flan-t5-base", "google/flan-t5-large"],
                    value="google/flan-t5-small"
                )
                use_cache = gr.Checkbox(label="Use Answer Cache", value=True)

                run_button = gr.Button("Run Evaluation & Submit All Answers")

            with gr.Column():
                result_text = gr.Textbox(label="Result", lines=2)
                correct_answers = gr.Number(label="Correct Answers")
                total_questions = gr.Number(label="Total Questions")
                elapsed_time = gr.Textbox(label="Elapsed Time")
                results_url = gr.Textbox(label="Results URL")
                cache_status = gr.Textbox(label="Cache Status")

        run_button.click(
            fn=run_evaluation,
            inputs=[username, agent_code_url, model_name, use_cache],
            outputs=[
                result_text,
                correct_answers,
                total_questions,
                elapsed_time,
                results_url,
                cache_status
            ]
        )

    return demo


if __name__ == "__main__":
    demo = create_gradio_interface()
    demo.launch(share=True)