Freddolin committed
Commit 9666d9f · verified · 1 parent: 059c116

Update agent.py

Files changed (1)
  1. agent.py +20 -16
agent.py CHANGED
@@ -1,4 +1,5 @@
-from transformers import pipeline
+from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
+import torch
 
 SYSTEM_PROMPT = """
 You are a general AI assistant. I will ask you a question. Report your thoughts, and finish your answer with the following template:
@@ -9,28 +10,31 @@ If you are asked for a string, don't use articles, neither abbreviations (e.g. f
 If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
 """
 
-
 class GaiaAgent:
-    def __init__(self, model_id="google/flan-t5-base"):
-        self.generator = pipeline(
-            "text2text-generation",
-            model=model_id,
-            tokenizer=model_id,
-            max_new_tokens=512,
-            device="cpu"
-        )
+    def __init__(self, model_id="tiiuae/falcon-rw-1b"):
+        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
+        self.model = AutoModelForCausalLM.from_pretrained(model_id)
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.model.to(self.device)
 
     def __call__(self, question: str) -> tuple[str, str]:
         try:
             prompt = f"{SYSTEM_PROMPT}\n\n{question}\nFINAL ANSWER:"
-            output = self.generator(prompt)[0]["generated_text"]
+            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
+            outputs = self.model.generate(
+                **inputs,
+                max_new_tokens=256,
+                do_sample=True,
+                temperature=0.7,
+                pad_token_id=self.tokenizer.eos_token_id
+            )
+            output_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
 
-            if "FINAL ANSWER:" in output:
-                final = output.split("FINAL ANSWER:")[-1].strip().split("\n")[0].strip()
+            if "FINAL ANSWER:" in output_text:
+                final = output_text.split("FINAL ANSWER:")[-1].strip().split("\n")[0].strip()
             else:
-                final = output.strip()
-            return final, output
+                final = output_text.strip()
+            return final, output_text
         except Exception as e:
             return "ERROR", f"Agent failed: {e}"
-
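
For reference, a minimal usage sketch of the updated agent (the example question is hypothetical, and agent.py is assumed to be importable from the working directory):

    from agent import GaiaAgent

    # Instantiate the agent; this downloads tiiuae/falcon-rw-1b on first run.
    agent = GaiaAgent()

    # Hypothetical question; any GAIA-style question string works here.
    final_answer, full_output = agent("What is the capital of France?")

    print(final_answer)   # text parsed after the last "FINAL ANSWER:" marker
    print(full_output)    # raw decoded output, returned for inspection

Note that generate() on a causal LM returns the prompt tokens followed by the completion, so the decoded text always contains the prompt's trailing "FINAL ANSWER:"; splitting on that marker and taking the last segment with [-1] is what isolates the model's own answer rather than the prompt's.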