# --- agent.py ---
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from duckduckgo_search import DDGS
import torch

SYSTEM_PROMPT = """
You are a general AI assistant. I will ask you a question. Think step by step to find the best possible answer.
Then return only the answer without any explanation or formatting.
Do not say 'Final answer' or anything else. Just output the raw answer string.
"""

def web_search(query: str, max_results: int = 3) -> list[str]:
    """Return formatted DuckDuckGo result snippets for a query."""
    results = []
    try:
        with DDGS() as ddgs:
            for r in ddgs.text(query, max_results=max_results):
                snippet = f"{r['title']}: {r['body']} (URL: {r['href']})"
                results.append(snippet)
    except Exception as e:
        results.append(f"[Web search error: {e}]")
    return results


class GaiaAgent:
    """Seq2seq QA agent: optionally prepends web-search context to the
    question, then generates a bare answer string with FLAN-T5."""

    def __init__(self, model_id="google/flan-t5-base"):
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

    def __call__(self, question: str) -> tuple[str, str]:
        try:
            # Heuristic: run a web search when the question needs external facts
            search_required = any(keyword in question.lower() for keyword in [
                "wikipedia", "who", "when", "where", "youtube", "mp3", "video", "article", "name", "code", "city", "award", "nasa"
            ])

            if search_required:
                search_results = web_search(question)
                context = "\n".join(search_results)
                prompt = f"{SYSTEM_PROMPT}\n\nSearch context:\n{context}\n\nQuestion: {question}"
                trace = f"Search used:\n{context}"
            else:
                prompt = f"{SYSTEM_PROMPT}\n\n{question}"
                trace = "Search not used."

            # Truncate the prompt to the model's maximum input length, then
            # decode greedily (do_sample=False) for deterministic answers.
            inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True).to(self.device)
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=128,
                do_sample=False,
                pad_token_id=self.tokenizer.pad_token_id,
            )
            output_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            final = output_text.strip()
            return final, trace

        except Exception as e:
            return "ERROR", f"Agent failed: {e}"