Update agent.py
Browse files
agent.py
CHANGED
@@ -1,40 +1,32 @@
|
|
1 |
-
from transformers import AutoTokenizer,
|
2 |
import torch
|
3 |
|
4 |
SYSTEM_PROMPT = """
|
5 |
-
You are a general AI assistant. I will ask you a question.
|
6 |
-
FINAL ANSWER: [YOUR FINAL ANSWER].
|
7 |
-
YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings.
|
8 |
-
If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise.
|
9 |
-
If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise.
|
10 |
-
If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
|
11 |
"""
|
12 |
|
13 |
class GaiaAgent:
|
14 |
-
def __init__(self, model_id="
|
15 |
self.tokenizer = AutoTokenizer.from_pretrained(model_id)
|
16 |
-
self.model =
|
17 |
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
18 |
self.model.to(self.device)
|
19 |
|
20 |
def __call__(self, question: str) -> tuple[str, str]:
|
21 |
try:
|
22 |
-
prompt = f"{SYSTEM_PROMPT}\n\n{question}
|
23 |
-
inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
|
24 |
outputs = self.model.generate(
|
25 |
**inputs,
|
26 |
-
max_new_tokens=
|
27 |
-
do_sample=
|
28 |
-
temperature=0.
|
29 |
-
pad_token_id=self.tokenizer.
|
30 |
)
|
31 |
output_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
|
32 |
-
|
33 |
-
if "FINAL ANSWER:" in output_text:
|
34 |
-
final = output_text.split("FINAL ANSWER:")[-1].strip().split("\n")[0].strip()
|
35 |
-
else:
|
36 |
-
final = output_text.strip()
|
37 |
return final, output_text
|
38 |
except Exception as e:
|
39 |
return "ERROR", f"Agent failed: {e}"
|
|
|
40 |
|
|
|
1 |
+
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
2 |
import torch
|
3 |
|
4 |
SYSTEM_PROMPT = """
|
5 |
+
You are a general AI assistant. I will ask you a question. Think step by step to find the best possible answer. Then return only the answer without any explanation or formatting. Do not say 'Final answer' or anything else. Just output the raw answer string.
|
|
|
|
|
|
|
|
|
|
|
6 |
"""
|
7 |
|
8 |
class GaiaAgent:
|
9 |
+
def __init__(self, model_id="google/flan-t5-base"):
|
10 |
self.tokenizer = AutoTokenizer.from_pretrained(model_id)
|
11 |
+
self.model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
|
12 |
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
13 |
self.model.to(self.device)
|
14 |
|
15 |
def __call__(self, question: str) -> tuple[str, str]:
|
16 |
try:
|
17 |
+
prompt = f"{SYSTEM_PROMPT}\n\n{question}"
|
18 |
+
inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True).to(self.device)
|
19 |
outputs = self.model.generate(
|
20 |
**inputs,
|
21 |
+
max_new_tokens=128,
|
22 |
+
do_sample=False,
|
23 |
+
temperature=0.0,
|
24 |
+
pad_token_id=self.tokenizer.pad_token_id
|
25 |
)
|
26 |
output_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
|
27 |
+
final = output_text.strip()
|
|
|
|
|
|
|
|
|
28 |
return final, output_text
|
29 |
except Exception as e:
|
30 |
return "ERROR", f"Agent failed: {e}"
|
31 |
+
|
32 |
|