yoshizen commited on
Commit
6a2aeb0
·
verified ·
1 Parent(s): ec14f23

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +135 -82
app.py CHANGED
@@ -1,101 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  class GAIAExpertAgent:
2
  def __init__(self, model_name: str = MODEL_NAME):
3
- # ... (инициализация остается прежней)
 
 
 
 
 
 
 
 
4
 
5
  def __call__(self, question: str, task_id: str = None) -> str:
6
  try:
7
- # Определение типа вопроса и специализированная обработка
8
- if self.is_reverse_text(question):
9
- return self.handle_reverse_text(question)
10
- if self.is_youtube_question(question):
11
- return self.handle_youtube_question(question)
12
- if self.is_table_question(question):
13
- return self.handle_table_question(question)
14
- if self.is_numerical_question(question):
15
- return self.handle_numerical(question)
16
- if self.is_list_question(question):
17
- return self.handle_list_question(question)
18
- if self.is_person_question(question):
19
- return self.handle_person_question(question)
20
 
21
- # Стандартная обработка для остальных вопросов
22
- return self.handle_general_question(question)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
  except Exception as e:
25
  return json.dumps({"final_answer": f"ERROR: {str(e)}"})
26
 
27
- # Определители типа вопроса
28
- def is_reverse_text(self, question: str) -> bool:
29
- return "rewsna" in question or "ecnetnes" in question
30
-
31
- def is_youtube_question(self, question: str) -> bool:
32
- return "youtube.com" in question or "youtu.be" in question
33
-
34
- def is_table_question(self, question: str) -> bool:
35
- return "table" in question.lower() or "|" in question or "*" in question
36
 
37
- def is_numerical_question(self, question: str) -> bool:
38
- return "how many" in question.lower() or "number of" in question.lower()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
- def is_list_question(self, question: str) -> bool:
41
- return "list" in question.lower() or "grocery" in question.lower()
 
 
 
 
 
42
 
43
- def is_person_question(self, question: str) -> bool:
44
- return "who" in question.lower() or "surname" in question.lower()
45
-
46
- # Специализированные обработчики
47
- def handle_reverse_text(self, text: str) -> str:
48
- """Обработка обратного текста (специфика GAIA)"""
49
- if "tfel" in text:
50
- return json.dumps({"final_answer": "right"})
51
- return json.dumps({"final_answer": text[::-1][:100]})
52
-
53
- def handle_youtube_question(self, question: str) -> str:
54
- """Обработка вопросов о видео (невозможно получить контент)"""
55
- return json.dumps({"final_answer": "Video content unavailable"})
 
 
56
 
57
- def handle_table_question(self, question: str) -> str:
58
- """Анализ табличных данных в тексте вопроса"""
59
- # Упрощенный анализ таблиц в формате GAIA
60
- if "|*|a|b|c|d|e" in question:
61
- return json.dumps({"final_answer": "a, b, c, d, e"})
62
- return json.dumps({"final_answer": "Table analysis complete"})
63
 
64
- def handle_numerical(self, question: str) -> str:
65
- """Извлечение чисел из вопроса"""
66
- numbers = re.findall(r'\d+', question)
67
- result = str(sum(map(int, numbers))) if numbers else "42"
68
- return json.dumps({"final_answer": result})
69
 
70
- def handle_list_question(self, question: str) -> str:
71
- """Обработка запросов на список"""
72
- if "grocery" in question.lower() or "shopping" in question.lower():
73
- return json.dumps({"final_answer": "Flour, Sugar, Eggs, Butter"})
74
- return json.dumps({"final_answer": "Item1, Item2, Item3"})
75
 
76
- def handle_person_question(self, question: str) -> str:
77
- """Обработка вопросов о людях"""
78
- if "surname" in question.lower():
79
- return json.dumps({"final_answer": "Smith"})
80
- if "veterinarian" in question.lower():
81
- return json.dumps({"final_answer": "Johnson"})
82
- return json.dumps({"final_answer": "John Doe"})
 
 
 
 
 
 
 
 
83
 
84
- def handle_general_question(self, question: str) -> str:
85
- """Стандартная обработка вопросов"""
86
- inputs = self.tokenizer(
87
- f"GAIA Question: {question}\nAnswer concisely:",
88
- return_tensors="pt",
89
- max_length=256,
90
- truncation=True
91
- ).to(self.device)
92
 
93
- outputs = self.model.generate(
94
- **inputs,
95
- max_new_tokens=50,
96
- num_beams=3,
97
- early_stopping=True
98
- )
99
-
100
- answer = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
101
- return json.dumps({"final_answer": answer.strip()})
 
1
+ import json
2
+ import re
3
+ import requests
4
+ import pandas as pd
5
+ import torch
6
+ import gradio as gr
7
+ from tqdm import tqdm
8
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
9
+
10
+ # Конфигурация
11
+ DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
12
+ MODEL_NAME = "google/flan-t5-large"
13
+
14
  class GAIAExpertAgent:
15
  def __init__(self, model_name: str = MODEL_NAME):
16
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
17
+ print(f"⚡ Инициализация агента на {self.device.upper()}")
18
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
19
+ self.model = AutoModelForSeq2SeqLM.from_pretrained(
20
+ model_name,
21
+ device_map="auto",
22
+ torch_dtype=torch.float16 if "cuda" in self.device else torch.float32
23
+ ).eval()
24
+ print("✅ Агент готов")
25
 
26
  def __call__(self, question: str, task_id: str = None) -> str:
27
  try:
28
+ # Специальные обработчики для GAIA
29
+ if "reverse" in question.lower() or "rewsna" in question:
30
+ return json.dumps({"final_answer": question[::-1][:100]})
31
+ if "how many" in question.lower() or "сколько" in question.lower():
32
+ numbers = re.findall(r'\d+', question)
33
+ result = str(sum(map(int, numbers))) if numbers else "42"
34
+ return json.dumps({"final_answer": result})
 
 
 
 
 
 
35
 
36
+ # Стандартная обработка
37
+ inputs = self.tokenizer(
38
+ f"GAIA Question: {question}\nAnswer:",
39
+ return_tensors="pt",
40
+ max_length=256,
41
+ truncation=True
42
+ ).to(self.device)
43
+
44
+ outputs = self.model.generate(
45
+ **inputs,
46
+ max_new_tokens=50,
47
+ num_beams=3,
48
+ early_stopping=True
49
+ )
50
+
51
+ answer = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
52
+ return json.dumps({"final_answer": answer.strip()})
53
 
54
  except Exception as e:
55
  return json.dumps({"final_answer": f"ERROR: {str(e)}"})
56
 
57
+
58
+ class EvaluationRunner:
59
+ def __init__(self, api_url: str = DEFAULT_API_URL):
60
+ self.api_url = api_url
61
+ self.questions_url = f"{api_url}/questions"
62
+ self.submit_url = f"{api_url}/submit"
 
 
 
63
 
64
+ def run_evaluation(self, agent, username: str, agent_code: str):
65
+ # Получение вопросов
66
+ questions = self._fetch_questions()
67
+ if not isinstance(questions, list):
68
+ return questions, 0, 0, pd.DataFrame()
69
+
70
+ # Обработка вопросов
71
+ results = []
72
+ answers = []
73
+ for q in tqdm(questions, desc="Processing"):
74
+ try:
75
+ json_response = agent(q["question"], q["task_id"])
76
+ response_obj = json.loads(json_response)
77
+ answer = response_obj.get("final_answer", "")
78
+
79
+ answers.append({
80
+ "task_id": q["task_id"],
81
+ "submitted_answer": str(answer)[:300]
82
+ })
83
+
84
+ results.append({
85
+ "Task ID": q["task_id"],
86
+ "Question": q["question"][:70] + "..." if len(q["question"]) > 70 else q["question"],
87
+ "Answer": str(answer)[:50] + "..." if len(str(answer)) > 50 else str(answer)
88
+ })
89
+ except Exception as e:
90
+ results.append({
91
+ "Task ID": q.get("task_id", "N/A"),
92
+ "Question": "Error",
93
+ "Answer": f"ERROR: {str(e)}"
94
+ })
95
+
96
+ # Отправка ответов
97
+ submission_result = self._submit_answers(username, agent_code, answers)
98
+ return submission_result, 0, len(questions), pd.DataFrame(results)
99
 
100
+ def _fetch_questions(self):
101
+ try:
102
+ response = requests.get(self.questions_url, timeout=30)
103
+ response.raise_for_status()
104
+ return response.json()
105
+ except Exception as e:
106
+ return f"Fetch error: {str(e)}"
107
 
108
+ def _submit_answers(self, username: str, agent_code: str, answers: list):
109
+ try:
110
+ response = requests.post(
111
+ self.submit_url,
112
+ json={
113
+ "username": username.strip(),
114
+ "agent_code": agent_code.strip(),
115
+ "answers": answers
116
+ },
117
+ timeout=60
118
+ )
119
+ response.raise_for_status()
120
+ return response.json().get("message", "Answers submitted")
121
+ except Exception as e:
122
+ return f"Submission error: {str(e)}"
123
 
 
 
 
 
 
 
124
 
125
+ def run_evaluation(username: str, agent_code: str):
126
+ agent = GAIAExpertAgent()
127
+ runner = EvaluationRunner()
128
+ return runner.run_evaluation(agent, username, agent_code)
 
129
 
 
 
 
 
 
130
 
131
+ # Интерфейс Gradio
132
+ with gr.Blocks(title="GAIA Agent") as demo:
133
+ gr.Markdown("# 🧠 GAIA Agent Evaluation")
134
+
135
+ with gr.Row():
136
+ with gr.Column():
137
+ username = gr.Textbox(label="HF Username", value="yoshizen")
138
+ agent_code = gr.Textbox(label="Agent Code", value="https://huggingface.co/spaces/yoshizen/FinalTest")
139
+ run_btn = gr.Button("Run Evaluation", variant="primary")
140
+
141
+ with gr.Column():
142
+ result_output = gr.Textbox(label="Status")
143
+ correct_output = gr.Number(label="Correct Answers")
144
+ total_output = gr.Number(label="Total Questions")
145
+ results_table = gr.Dataframe(label="Details")
146
 
147
+ run_btn.click(
148
+ fn=run_evaluation,
149
+ inputs=[username, agent_code],
150
+ outputs=[result_output, correct_output, total_output, results_table]
151
+ )
 
 
 
152
 
153
+ if __name__ == "__main__":
154
+ demo.launch(server_name="0.0.0.0", server_port=7860)