File size: 2,050 Bytes
9666d9f
 
230477c
 
 
 
 
 
 
 
 
 
 
9666d9f
 
 
 
 
58c4724
230477c
 
b5d03d2
9666d9f
 
 
 
 
 
 
 
 
230477c
9666d9f
 
230477c
9666d9f
 
230477c
 
b5d03d2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch

# Instruction preamble that GaiaAgent.__call__ places in front of every
# question. It asks the model to finish with a "FINAL ANSWER: ..." line, which
# is the marker GaiaAgent.__call__ searches for when extracting the answer.
# NOTE: this text is sent to the model verbatim (including the leading/trailing
# newlines and the trailing spaces on some lines); editing it changes the
# prompt the model sees.
SYSTEM_PROMPT = """
You are a general AI assistant. I will ask you a question. Report your thoughts, and finish your answer with the following template: 
FINAL ANSWER: [YOUR FINAL ANSWER]. 
YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. 
If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. 
If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. 
If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
"""

class GaiaAgent:
    """Callable question-answering agent backed by a causal language model.

    An instance is called with a question string and returns a
    ``(final_answer, trace)`` pair. The prompt is the module-level
    ``SYSTEM_PROMPT``, the question, and a trailing ``FINAL ANSWER:`` marker.
    """

    # Marker that ends every prompt and introduces the answer in the output.
    _ANSWER_MARKER = "FINAL ANSWER:"

    def __init__(
        self,
        model_id: str = "tiiuae/falcon-rw-1b",
        *,
        max_new_tokens: int = 256,
        do_sample: bool = True,
        temperature: float = 0.7,
    ) -> None:
        """Load the tokenizer and model and move the model to the best device.

        Args:
            model_id: Identifier passed to ``from_pretrained`` for both the
                tokenizer and the model.
            max_new_tokens: Upper bound on tokens generated per question.
            do_sample: Whether ``generate`` samples (True) or decodes
                deterministically (False).
            temperature: Sampling temperature; only forwarded to ``generate``
                when ``do_sample`` is true.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.model = AutoModelForCausalLM.from_pretrained(model_id)
        # Use the GPU when one is visible to torch, otherwise stay on the CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.max_new_tokens = max_new_tokens
        self.do_sample = do_sample
        self.temperature = temperature

    @classmethod
    def _extract_final_answer(cls, text: str) -> str:
        """Return the first line following the last answer marker in *text*.

        When the marker is absent the whole text, stripped, is returned.
        """
        if cls._ANSWER_MARKER not in text:
            return text.strip()
        # rpartition keeps only what follows the LAST marker, so a marker
        # emitted by the model wins over the one that ends the prompt.
        after_marker = text.rpartition(cls._ANSWER_MARKER)[2]
        return after_marker.strip().split("\n")[0].strip()

    def __call__(self, question: str) -> tuple[str, str]:
        """Answer *question* with the loaded model.

        Args:
            question: The question text appended to the system prompt.

        Returns:
            ``(final_answer, trace)`` where ``trace`` is the full decoded
            model output and ``final_answer`` is the line extracted from it.
            If anything raises, returns ``("ERROR", "Agent failed: <error>")``
            instead of propagating the exception.
        """
        try:
            prompt = f"{SYSTEM_PROMPT}\n\n{question}\n{self._ANSWER_MARKER}"
            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)

            generation_kwargs = {
                "max_new_tokens": self.max_new_tokens,
                "do_sample": self.do_sample,
                # Reuse EOS as the padding id for generation.
                "pad_token_id": self.tokenizer.eos_token_id,
            }
            if self.do_sample:
                # Temperature only matters when sampling; leaving it out of
                # greedy runs avoids passing an unused generation flag.
                generation_kwargs["temperature"] = self.temperature

            outputs = self.model.generate(**inputs, **generation_kwargs)
            # NOTE(review): decoder-only models return the prompt tokens
            # followed by the new ones, so the decoded trace normally begins
            # with the prompt and therefore always contains the marker; the
            # helper's no-marker fallback is then rarely reached - confirm.
            output_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            return self._extract_final_answer(output_text), output_text
        except Exception as e:
            # Deliberate best-effort boundary: callers receive a sentinel
            # answer plus the error text rather than an exception.
            return "ERROR", f"Agent failed: {e}"