import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from ddgs import DDGS
import pandas as pd
import whisper
SYSTEM_PROMPT = """
You are a general AI assistant. I will ask you a question. Think step by step to find the best possible answer.
Then return only the answer, without any explanation or formatting.
Do not say 'Final answer' or anything else; output only the raw answer string.
"""


class GaiaAgent:
    def __init__(self, model_id="google/flan-t5-base"):
        # Load the seq2seq model and tokenizer, moving the model to GPU when available.
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        # Whisper "base" handles the audio-transcription questions.
        self.transcriber = whisper.load_model("base")

    def search(self, query: str) -> str:
        # Return the snippet ('body') of the top DuckDuckGo result, an error
        # string on failure, or "" when there are no results.
        try:
            with DDGS() as ddgs:
                results = list(ddgs.text(query, safesearch="off"))
                if results:
                    return results[0]['body']
        except Exception as e:
            return f"Search failed: {e}"
        return ""

    def transcribe_audio(self, file_path: str) -> str:
        # Run Whisper on a local audio file and return the transcript text.
        try:
            result = self.transcriber.transcribe(file_path)
            return result['text']
        except Exception as e:
            return f"Audio transcription failed: {e}"

    def handle_excel(self, file_path: str) -> str:
        # Task-specific helper: expects 'Category' and 'Sales' columns and
        # returns total sales over all non-drink rows, formatted to two decimals.
        try:
            df = pd.read_excel(file_path)
            food_sales = df[df['Category'].str.lower() != 'drink']['Sales'].sum()
            return f"{food_sales:.2f}"
        except Exception as e:
            return f"Excel parsing failed: {e}"

    def __call__(self, question: str, files: dict = None) -> tuple[str, str]:
        try:
            # Route the question: web search for URL/Wikipedia questions,
            # file handlers for attachments, direct prompting otherwise.
            if "http" in question or "Wikipedia" in question:
                web_context = self.search(question)
                prompt = f"{SYSTEM_PROMPT}\n\n{web_context}\n\nQuestion: {question}"
            elif files:
                # Default prompt in case no attachment type matches below.
                prompt = f"{SYSTEM_PROMPT}\n\n{question}"
                for key in files:
                    if key.endswith(".mp3"):
                        audio_txt = self.transcribe_audio(files[key])
                        prompt = f"{SYSTEM_PROMPT}\n\n{audio_txt}\n\n{question}"
                        break
                    elif key.endswith(".xlsx"):
                        # Spreadsheet questions are answered without the LLM.
                        excel_result = self.handle_excel(files[key])
                        return excel_result, excel_result
            else:
                prompt = f"{SYSTEM_PROMPT}\n\n{question}"
            inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True).to(self.device)
            # Greedy decoding: temperature is dropped because it is ignored
            # (and warned about) by transformers when do_sample=False.
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=128,
                do_sample=False,
                pad_token_id=self.tokenizer.pad_token_id,
            )
            output_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            final = output_text.strip()
            return final, output_text
        except Exception as e:
            return "ERROR", f"Agent failed: {e}"