from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch

SYSTEM_PROMPT = """
You are a general AI assistant. I will ask you a question. Think step by step to find the best possible answer. Then return only the answer without any explanation or formatting. Do not say 'Final answer' or anything else. Just output the raw answer string.
"""


class GaiaAgent:
    def __init__(self, model_id="google/flan-t5-base"):
        # Load the tokenizer and seq2seq model, then move the model
        # to the GPU when one is available.
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

    def __call__(self, question: str) -> tuple[str, str]:
        """Answer a question; returns (final_answer, raw_model_output)."""
        try:
            # Prepend the system prompt and tokenize; truncation guards
            # against inputs longer than the model's context window.
            prompt = f"{SYSTEM_PROMPT}\n\n{question}"
            inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True).to(self.device)
            # Greedy decoding: with do_sample=False, temperature is ignored,
            # so the redundant temperature=0.0 is dropped (it only triggers
            # a transformers warning).
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=128,
                do_sample=False,
                pad_token_id=self.tokenizer.pad_token_id,
            )
            output_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            final = output_text.strip()
            return final, output_text
        except Exception as e:
            return "ERROR", f"Agent failed: {e}"
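

# A minimal usage sketch, not part of the original file: it shows how the
# agent is instantiated and called. The sample question is a hypothetical
# placeholder; a real GAIA run would supply its own questions.
if __name__ == "__main__":
    agent = GaiaAgent()  # defaults to google/flan-t5-base
    answer, raw = agent("What is the capital of France?")
    print("answer:", answer)
    print("raw output:", raw)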