Freddolin's picture
Update agent.py
2477b72 verified
raw
history blame
3.1 kB
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from ddgs import DDGS
import re
import pandas as pd
import tempfile
import os
import whisper
# Instruction preamble placed at the start of the model prompt built in
# GaiaAgent.__call__. Kept as adjacent literals so the exact text is easy to diff.
SYSTEM_PROMPT = (
    "\n"
    "You are a helpful AI assistant. Think step by step to solve the problem. "
    "If the question requires reasoning, perform it. If it refers to a search or file, "
    "use the result provided. At the end, return ONLY the final answer string. "
    "No explanations.\n"
)
class GaiaAgent:
    """Answer GAIA-style questions with a seq2seq LLM plus light tool use.

    Attached ``.mp3`` files are transcribed with Whisper, and ``.xlsx`` files are
    summarised by :meth:`handle_excel`. A question containing "http" or
    "wikipedia" with no file-derived context gets the first ``DDGS`` text-search
    snippet as context. Calling the instance returns ``(answer, answer)`` on
    success and ``("ERROR", "Agent failed: ...")`` if anything raises.
    """

    def __init__(self, model_id="google/flan-t5-base"):
        """Load tokenizer/model *model_id* onto GPU if available, plus Whisper ``base``."""
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.transcriber = whisper.load_model("base")

    def search(self, query: str) -> str:
        """Return the ``body`` snippet of the first text-search result for *query*.

        Best-effort: any search failure, or an empty result list, yields "".
        """
        try:
            with DDGS() as ddgs:
                results = list(ddgs.text(query, safesearch="off"))
        except Exception:
            # Network / rate-limit errors must not abort answering the question.
            return ""
        return results[0].get("body", "") if results else ""

    def transcribe_audio(self, file_path: str) -> str:
        """Transcribe the audio at *file_path* with Whisper; "" on any failure."""
        try:
            result = self.transcriber.transcribe(file_path)
        except Exception:
            return ""
        return str(result.get("text", "")).strip()

    @staticmethod
    def _summarize_sales(df: pd.DataFrame) -> str:
        """Return the sum of ``sales`` over rows whose category is not "drink".

        Column names are matched after ``str(...).strip().lower()``, so headers
        like " Sales " or non-string headers elsewhere in the sheet are handled.
        Category values are compared stripped and case-insensitively; missing
        categories count as non-drink. Returns "" when either column is absent,
        otherwise the total formatted with two decimals. *df* is not modified.
        """
        # rename() returns a new frame, so the caller's column labels stay intact.
        df = df.rename(columns=lambda col: str(col).strip().lower())
        if "category" not in df.columns or "sales" not in df.columns:
            return ""
        is_drink = df["category"].astype(str).str.strip().str.lower() == "drink"
        food_sales = df.loc[~is_drink, "sales"].sum()
        return f"{food_sales:.2f}"

    def handle_excel(self, file_path: str) -> str:
        """Return total non-drink sales from the spreadsheet at *file_path*.

        Returns "" when the file cannot be read or lacks ``category``/``sales``
        columns.
        """
        try:
            df = pd.read_excel(file_path)
            return self._summarize_sales(df)
        except Exception:
            return ""

    @staticmethod
    def _build_prompt(context: str, question: str) -> str:
        """Assemble the model input: system prompt, then context, then the question."""
        return f"{SYSTEM_PROMPT}\n\n{context}\n\nQuestion: {question.strip()}"

    def _fit_context(self, context: str, question: str) -> str:
        """Trim *context* so the full prompt fits ``tokenizer.model_max_length``.

        The tokenizer truncates from the end, and the question comes last. So
        without this, a long transcript or search snippet would cut off the
        question itself. Only the context is shortened here.
        """
        if not context:
            return ""
        max_len = self.tokenizer.model_max_length
        # Tokenizers without a configured limit report a huge placeholder value;
        # in that case there is nothing to fit against.
        if not isinstance(max_len, int) or max_len > 100_000:
            return context
        used = len(self.tokenizer(self._build_prompt("", question))["input_ids"])
        budget = max_len - used
        if budget <= 0:
            return ""
        ids = self.tokenizer(context, add_special_tokens=False)["input_ids"]
        if len(ids) <= budget:
            return context
        return self.tokenizer.decode(ids[:budget], skip_special_tokens=True)

    def __call__(self, question: str, files: dict = None) -> tuple[str, str]:
        """Answer *question*, optionally using attached *files*.

        Args:
            question: The question text.
            files: Optional mapping of file name -> local path. The first
                ``.mp3`` that transcribes to non-empty text becomes the context.
                The first ``.xlsx`` that yields a sales total is returned directly
                as the answer. Extension matching is case-insensitive.

        Returns:
            ``(answer, answer)`` on success, or ``("ERROR", "Agent failed: ...")``
            if any step raises.
        """
        try:
            context = ""
            for filename, filepath in (files or {}).items():
                name = filename.lower()
                if name.endswith(".mp3"):
                    context = self.transcribe_audio(filepath)
                    if context:
                        break
                elif name.endswith(".xlsx"):
                    excel_result = self.handle_excel(filepath).strip()
                    # An empty result is not an answer; fall back to the model.
                    if excel_result:
                        return excel_result, excel_result

            lowered = question.lower()
            if not context and ("http" in lowered or "wikipedia" in lowered):
                context = self.search(question)

            prompt = self._build_prompt(self._fit_context(context, question), question)
            # truncation=True remains as a backstop against re-tokenization drift.
            inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True).to(self.device)
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=128,
                do_sample=False,  # greedy decoding; temperature would be ignored
                pad_token_id=self.tokenizer.pad_token_id,
            )
            final = self.tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
            return final, final
        except Exception as e:
            return "ERROR", f"Agent failed: {e}"