|
|
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM |
|
from duckduckgo_search import DDGS |
|
import torch |
|
|
|
# Instruction preamble that GaiaAgent places at the start of every prompt:
# reason step by step, then emit only the bare answer string.
SYSTEM_PROMPT = """
You are a general AI assistant. I will ask you a question. Think step by step to find the best possible answer.
Then return only the answer without any explanation or formatting.
Do not say 'Final answer' or anything else. Just output the raw answer string.
"""
|
|
|
def web_search(query: str, max_results: int = 3) -> list[str]:
    """Run a DuckDuckGo text search and return one summary string per hit.

    Args:
        query: Search terms to submit.
        max_results: Upper bound on the number of hits requested.

    Returns:
        Strings of the form ``"<title>: <body> (URL: <href>)"`` in the order
        the search yields them. A blank ``query`` or a non-positive
        ``max_results`` returns an empty list without searching. If the
        search raises, the hits gathered so far are kept and a single
        ``"[Web search error: ...]"`` entry is appended instead of the
        exception propagating.
    """
    # Nothing useful can come back for these inputs, so skip the network call
    # rather than surfacing a backend error string as if it were a result.
    if not query or not query.strip() or max_results <= 0:
        return []

    results: list[str] = []
    try:
        with DDGS() as ddgs:
            for hit in ddgs.text(query, max_results=max_results):
                results.append(
                    f"{hit['title']}: {hit['body']} (URL: {hit['href']})"
                )
    except Exception as e:
        # Deliberately best-effort: callers get a readable marker in the
        # result list instead of having to handle search failures.
        results.append(f"[Web search error: {e}]")
    return results
|
|
|
|
|
class GaiaAgent:
    """Question-answering agent built on a Hugging Face seq2seq model.

    Questions that contain one of the trigger keywords are first passed to
    ``web_search``; the returned snippets are placed in the prompt ahead of
    the question. All other questions are sent to the model with only the
    system prompt.
    """

    # Substrings (matched against the lower-cased question) whose presence
    # triggers a web search before generation. Matching is by substring, not
    # by whole word.
    _SEARCH_KEYWORDS = (
        "wikipedia", "who", "when", "where", "youtube", "mp3", "video",
        "article", "name", "code", "city", "award", "nasa",
    )

    def __init__(self, model_id="google/flan-t5-base"):
        """Load the tokenizer and model for *model_id* and place the model on
        CUDA when available, otherwise on the CPU."""
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

    @classmethod
    def _needs_search(cls, question: str) -> bool:
        """Return True when *question* contains any trigger keyword."""
        lowered = question.lower()  # lower-case once, not once per keyword
        return any(keyword in lowered for keyword in cls._SEARCH_KEYWORDS)

    def _build_prompt(self, question: str) -> tuple[str, str]:
        """Return ``(prompt, trace)`` for *question*, searching the web if needed."""
        if self._needs_search(question):
            context = "\n".join(web_search(question))
            prompt = f"{SYSTEM_PROMPT}\n\nSearch context:\n{context}\n\nQuestion: {question}"
            return prompt, f"Search used:\n{context}"
        return f"{SYSTEM_PROMPT}\n\n{question}", "Search not used."

    def __call__(self, question: str) -> tuple[str, str]:
        """Answer *question* with the loaded model.

        Args:
            question: The question text.

        Returns:
            ``(answer, trace)``. ``answer`` is the stripped decoded model
            output and ``trace`` records whether search was used (including
            the search context when it was). On any failure, including a
            blank question, returns ``("ERROR", "Agent failed: <reason>")``
            instead of raising.
        """
        try:
            # A blank question gives the model nothing to answer; report it
            # through the same error convention rather than generating.
            if not question.strip():
                return "ERROR", "Agent failed: empty question"

            prompt, trace = self._build_prompt(question)

            # NOTE(review): truncation=True presumably trims from the end of
            # the prompt, and the question is placed last when search context
            # is present, so a long context may cut the question off; confirm
            # against the tokenizer's max length and truncation side.
            inputs = self.tokenizer(
                prompt, return_tensors="pt", truncation=True
            ).to(self.device)
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=128,
                do_sample=False,
                pad_token_id=self.tokenizer.pad_token_id,
            )
            answer = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            return answer.strip(), trace
        except Exception as e:
            # Top-level boundary: callers always receive an (answer, trace) pair.
            return "ERROR", f"Agent failed: {e}"
|
|
|
|