import os
import re
import tempfile

import pandas as pd
import torch
import whisper
from ddgs import DDGS
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM


SYSTEM_PROMPT = """
You are a general AI assistant. I will ask you a question. Think step by step to find the best possible answer.
Then return only the answer, without any explanation or formatting.
Do not say 'Final answer' or anything else. Just output the raw answer string.
"""


class GaiaAgent:
    def __init__(self, model_id="google/flan-t5-base"):
        # Load the seq2seq model and tokenizer, and move the model to GPU if available.
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        # Whisper "base" checkpoint for audio transcription.
        self.transcriber = whisper.load_model("base")

    def search(self, query: str) -> str:
        # Run a DuckDuckGo text search and return the snippet of the first hit.
        try:
            with DDGS() as ddgs:
                results = list(ddgs.text(query, safesearch="off"))
                if results:
                    return results[0]["body"]
        except Exception as e:
            return f"Search failed: {e}"
        return ""

    def transcribe_audio(self, file_path: str) -> str:
        # Transcribe an audio file with Whisper and return the plain text.
        try:
            result = self.transcriber.transcribe(file_path)
            return result["text"]
        except Exception as e:
            return f"Audio transcription failed: {e}"

    def handle_excel(self, file_path: str) -> str:
        # Tailored to the GAIA spreadsheet task: sum the 'Sales' column over all
        # rows whose 'Category' is not 'drink' (i.e. total food sales).
        try:
            df = pd.read_excel(file_path)
            food_sales = df[df["Category"].str.lower() != "drink"]["Sales"].sum()
            return f"{food_sales:.2f}"
        except Exception as e:
            return f"Excel parsing failed: {e}"

    def __call__(self, question: str, files: dict = None) -> tuple[str, str]:
        try:
            # Route the question: use web search for link/Wikipedia questions,
            # file-specific handlers for attachments, otherwise answer directly.
            if "http" in question or "Wikipedia" in question:
                web_context = self.search(question)
                prompt = f"{SYSTEM_PROMPT}\n\n{web_context}\n\nQuestion: {question}"
            elif files:
                prompt = f"{SYSTEM_PROMPT}\n\n{question}"
                for key in files:
                    if key.endswith(".mp3"):
                        audio_txt = self.transcribe_audio(files[key])
                        prompt = f"{SYSTEM_PROMPT}\n\n{audio_txt}\n\n{question}"
                        break
                    elif key.endswith(".xlsx"):
                        excel_result = self.handle_excel(files[key])
                        return excel_result, excel_result
            else:
                prompt = f"{SYSTEM_PROMPT}\n\n{question}"

            inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True).to(self.device)
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=128,
                do_sample=False,  # greedy decoding; temperature is omitted since sampling is disabled
                pad_token_id=self.tokenizer.pad_token_id,
            )
            output_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            final = output_text.strip()
            return final, output_text
        except Exception as e:
            return "ERROR", f"Agent failed: {e}"
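

# A minimal usage sketch (illustrative, not part of the original agent): instantiate
# the agent and call it on a question. The question text and file path below are
# hypothetical examples only.
if __name__ == "__main__":
    agent = GaiaAgent()

    # Plain question with no links or attachments: answered directly by the model.
    answer, raw = agent("What is the capital of France?")
    print(answer)

    # Question with an attached audio file: the dict maps the file name to its local path.
    # answer, raw = agent("What ingredients are mentioned in the recording?",
    #                     files={"recipe.mp3": "/path/to/recipe.mp3"})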