Freddolin committed on
Commit 9bf47dc · verified · 1 Parent(s): e6bb0db

Update agent.py

Files changed (1)
  1. agent.py +33 -4
agent.py CHANGED
@@ -1,10 +1,26 @@
+# --- agent.py ---
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+from duckduckgo_search import DDGS
 import torch
 
 SYSTEM_PROMPT = """
-You are a general AI assistant. I will ask you a question. Think step by step to find the best possible answer. Then return only the answer without any explanation or formatting. Do not say 'Final answer' or anything else. Just output the raw answer string.
+You are a general AI assistant. I will ask you a question. Think step by step to find the best possible answer.
+Then return only the answer without any explanation or formatting.
+Do not say 'Final answer' or anything else. Just output the raw answer string.
 """
 
+def web_search(query: str, max_results: int = 3) -> list[str]:
+    results = []
+    try:
+        with DDGS() as ddgs:
+            for r in ddgs.text(query, max_results=max_results):
+                snippet = f"{r['title']}: {r['body']} (URL: {r['href']})"
+                results.append(snippet)
+    except Exception as e:
+        results.append(f"[Web search error: {e}]")
+    return results
+
+
 class GaiaAgent:
     def __init__(self, model_id="google/flan-t5-base"):
         self.tokenizer = AutoTokenizer.from_pretrained(model_id)
@@ -14,18 +30,31 @@ class GaiaAgent:
 
     def __call__(self, question: str) -> tuple[str, str]:
         try:
-            prompt = f"{SYSTEM_PROMPT}\n\n{question}"
+            # Heuristic: do a web search if the question requires external facts
+            search_required = any(keyword in question.lower() for keyword in [
+                "wikipedia", "who", "when", "where", "youtube", "mp3", "video", "article", "name", "code", "city", "award", "nasa"
+            ])
+
+            if search_required:
+                search_results = web_search(question)
+                context = "\n".join(search_results)
+                prompt = f"{SYSTEM_PROMPT}\n\nSearch context:\n{context}\n\nQuestion: {question}"
+                trace = f"Search used:\n{context}"
+            else:
+                prompt = f"{SYSTEM_PROMPT}\n\n{question}"
+                trace = "Search not used."
+
             inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True).to(self.device)
             outputs = self.model.generate(
                 **inputs,
                 max_new_tokens=128,
                 do_sample=False,
-                temperature=0.0,
                 pad_token_id=self.tokenizer.pad_token_id
             )
             output_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
             final = output_text.strip()
-            return final, output_text
+            return final, trace
+
         except Exception as e:
             return "ERROR", f"Agent failed: {e}"
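For context, a minimal usage sketch of the committed agent (hypothetical, not part of this commit): it assumes agent.py is on the import path, that transformers, torch, and duckduckgo_search are installed, and the question is an illustrative example that trips the keyword heuristic.

# usage_example.py -- hypothetical sketch, not part of the commit
from agent import GaiaAgent

# Instantiate with the default model; google/flan-t5-base is downloaded on first use.
agent = GaiaAgent()

# "who" matches the keyword heuristic, so web_search() runs first and its
# snippets are prepended to the prompt; the second return value is the trace.
answer, trace = agent("Who was the first person to walk on the Moon?")
print("Answer:", answer)
print("Trace:", trace)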