File size: 3,101 Bytes
9666d9f 1d782b2 230477c 2477b72 230477c 2b8dbad 9666d9f 2b8dbad 9666d9f 1d782b2 2477b72 1d782b2 58c4724 1d782b2 230477c 1d782b2 2477b72 1d782b2 2477b72 1d782b2 2477b72 1d782b2 2477b72 9bf47dc 2477b72 2b8dbad 9666d9f 2b8dbad 1d782b2 2b8dbad 9666d9f 2b8dbad 2477b72 230477c 2b8dbad b5d03d2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 |
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from ddgs import DDGS
import re
import pandas as pd
import tempfile
import os
import whisper
# Instruction text placed at the start of every prompt GaiaAgent sends to the
# model. It asks for step-by-step reasoning but a bare final answer string, so
# the decoded output can be returned to the caller as-is.
SYSTEM_PROMPT = """
You are a helpful AI assistant. Think step by step to solve the problem. If the question requires reasoning, perform it. If it refers to a search or file, use the result provided. At the end, return ONLY the final answer string. No explanations.
"""
class GaiaAgent:
    """Answer GAIA-style questions with a seq2seq model plus a few simple tools.

    Tools used by ``__call__``:

    * ``search`` - snippet of the top DuckDuckGo text result for a query;
    * ``transcribe_audio`` - Whisper transcription of an attached audio file;
    * ``handle_excel`` - total of the ``sales`` column over rows whose
      ``category`` is not ``drink`` in an attached spreadsheet.
    """

    # Tokens held back from the context budget to absorb small differences
    # between tokenizing prompt pieces separately and tokenizing the joined
    # prompt.
    _CONTEXT_SLACK_TOKENS = 8

    def __init__(self, model_id="google/flan-t5-base"):
        """Load the seq2seq model *model_id* and the Whisper "base" model.

        The seq2seq model is moved to CUDA when available, otherwise the CPU.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.transcriber = whisper.load_model("base")

    def search(self, query: str) -> str:
        """Return the body of the first DuckDuckGo text result for *query*.

        Returns "" when there are no results or the search raises (network
        errors, rate limiting, ...). Search is best-effort context, so failure
        must not abort answering.
        """
        try:
            with DDGS() as ddgs:
                results = list(ddgs.text(query, safesearch="off"))
            if not results:
                return ""
            return results[0].get("body", "")
        except Exception:
            return ""

    def transcribe_audio(self, file_path: str) -> str:
        """Transcribe the audio file at *file_path* with Whisper.

        Returns the transcript with surrounding whitespace stripped, or "" if
        transcription fails.
        """
        try:
            result = self.transcriber.transcribe(file_path)
        except Exception:
            return ""
        return result.get("text", "").strip()

    def handle_excel(self, file_path: str) -> str:
        """Read the spreadsheet at *file_path* and total its non-drink sales.

        Returns the total formatted with two decimals, or "" if the file
        cannot be read or lacks ``category``/``sales`` columns.
        """
        try:
            df = pd.read_excel(file_path)
            return self._sum_non_drink_sales(df)
        except Exception:
            return ""

    @staticmethod
    def _sum_non_drink_sales(df) -> str:
        """Sum ``sales`` over rows whose ``category`` is not ``drink``.

        Column names and category values are compared case-insensitively and
        without surrounding whitespace. Non-string headers (e.g. ints) are
        allowed. Sales values that are not numeric (including numbers stored
        as text that cannot be parsed) are ignored. *df* itself is not
        modified. Returns "" if either column is missing.
        """
        # rename() returns a new frame, so the caller's column labels survive.
        df = df.rename(columns=lambda col: str(col).strip().lower())
        if "category" not in df.columns or "sales" not in df.columns:
            return ""
        category = df["category"].astype(str).str.strip().str.lower()
        food_rows = category != "drink"
        # Sheets often store numbers as text; coerce instead of failing on sum().
        sales = pd.to_numeric(df.loc[food_rows, "sales"], errors="coerce")
        return f"{sales.sum():.2f}"

    def _fit_context(self, context: str, reserved_text: str) -> str:
        """Trim *context* so it fits next to *reserved_text* in the model input.

        ``__call__`` tokenizes with ``truncation=True``. With the tokenizer's
        default right-side truncation, that drops tokens from the end of the
        prompt, which is where the question sits. Trimming an over-long search
        snippet or transcript first keeps the question intact.
        """
        if not context:
            return context
        limit = getattr(self.tokenizer, "model_max_length", None)
        # Tokenizers without a real limit may report a huge sentinel value; in
        # that case truncation never triggers and there is nothing to trim.
        if not isinstance(limit, int) or limit >= 10**9:
            return context
        reserved = len(self.tokenizer(reserved_text)["input_ids"])
        budget = max(limit - reserved - self._CONTEXT_SLACK_TOKENS, 0)
        context_ids = self.tokenizer(context, add_special_tokens=False)["input_ids"]
        if len(context_ids) <= budget:
            return context
        return self.tokenizer.decode(context_ids[:budget], skip_special_tokens=True)

    def __call__(self, question: str, files: dict = None) -> tuple[str, str]:
        """Answer *question*, optionally using attached *files*.

        Args:
            question: The question text.
            files: Optional mapping of file name to local file path. The first
                ``.mp3`` file is transcribed and used as context. An ``.xlsx``
                file whose non-drink sales can be totalled is answered directly
                with that total, without calling the model.

        Returns:
            An ``(answer, detail)`` pair. On success both entries are the
            answer string. On failure the pair is
            ``("ERROR", "Agent failed: <exception>")``.
        """
        try:
            context = ""
            for filename, filepath in (files or {}).items():
                lowered_name = filename.lower()
                if lowered_name.endswith(".mp3"):
                    context = self.transcribe_audio(filepath)
                    break
                if lowered_name.endswith(".xlsx"):
                    excel_result = self.handle_excel(filepath).strip()
                    # Short-circuit only when the sheet yielded a number;
                    # otherwise fall back to the model rather than answering "".
                    if excel_result:
                        return excel_result, excel_result

            # Web search is used only when no attached file supplied context.
            lowered_question = question.lower()
            if not context and (
                "http" in lowered_question or "wikipedia" in lowered_question
            ):
                context = self.search(question)

            tail = f"Question: {question.strip()}"
            context = self._fit_context(context, f"{SYSTEM_PROMPT}\n\n\n\n{tail}")
            prompt = f"{SYSTEM_PROMPT}\n\n{context}\n\n{tail}"

            inputs = self.tokenizer(
                prompt, return_tensors="pt", truncation=True
            ).to(self.device)
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=128,
                # Greedy decoding; a temperature setting would be ignored here
                # and only triggers a generation-config warning or error.
                do_sample=False,
                pad_token_id=self.tokenizer.pad_token_id,
            )
            final = self.tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
            return final, final
        except Exception as e:
            return "ERROR", f"Agent failed: {e}"
|