File size: 1,586 Bytes
230477c
 
 
 
 
 
 
 
 
 
 
58c4724
230477c
059c116
58c4724
059c116
58c4724
 
 
059c116
58c4724
 
230477c
 
b5d03d2
230477c
 
 
 
 
 
 
 
 
b5d03d2
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
from transformers import pipeline

# Instruction text that GaiaAgent.__call__ places in front of every question.
# It asks the model to finish with a "FINAL ANSWER: ..." line, which is the
# marker __call__ searches for when pulling the short answer out of the
# generated text. The literal begins and ends with a newline.
SYSTEM_PROMPT = """
You are a general AI assistant. I will ask you a question. Report your thoughts, and finish your answer with the following template: 
FINAL ANSWER: [YOUR FINAL ANSWER]. 
YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. 
If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. 
If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. 
If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
"""


class GaiaAgent:
    """Callable question-answering agent backed by a text2text pipeline.

    An instance is called with a question string and returns a pair
    ``(final_answer, raw_output)``.
    """

    def __init__(self, model_id="google/flan-t5-base"):
        """Build the ``text2text-generation`` pipeline for *model_id*.

        The same identifier is used for both model and tokenizer, generation
        is capped at 512 new tokens, and the pipeline is placed on the CPU.
        """
        pipeline_options = {
            "model": model_id,
            "tokenizer": model_id,
            "max_new_tokens": 512,
            "device": "cpu",
        }
        self.generator = pipeline("text2text-generation", **pipeline_options)

    @staticmethod
    def _pick_answer(raw_text):
        """Return the short answer contained in *raw_text*.

        Without a "FINAL ANSWER:" marker the whole text (stripped) is the
        answer. Otherwise the answer is the first line following the last
        marker, with surrounding whitespace removed.
        """
        marker = "FINAL ANSWER:"
        if marker not in raw_text:
            return raw_text.strip()
        after_marker = raw_text.rsplit(marker, 1)[-1]
        first_line, _, _ = after_marker.strip().partition("\n")
        return first_line.strip()

    def __call__(self, question: str) -> tuple[str, str]:
        """Answer *question* and return ``(final_answer, raw_output)``.

        The prompt is SYSTEM_PROMPT, a blank line, the question, and a
        trailing "FINAL ANSWER:" cue. Any exception raised while building
        the prompt, generating, or parsing is caught and reported as
        ``("ERROR", "Agent failed: <exception text>")`` instead of
        propagating.
        """
        try:
            prompt = f"{SYSTEM_PROMPT}\n\n{question}\nFINAL ANSWER:"
            results = self.generator(prompt)
            raw_text = results[0]["generated_text"]
            return self._pick_answer(raw_text), raw_text
        except Exception as exc:
            return "ERROR", f"Agent failed: {exc}"