import torch
import pandas as pd
import whisper
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from ddgs import DDGS

SYSTEM_PROMPT = """
You are a general AI assistant. I will ask you a question.
Think step by step to find the best possible answer.
Then return only the answer without any explanation or formatting.
Do not say 'Final answer' or anything else. Just output the raw answer string.
"""


class GaiaAgent:
    def __init__(self, model_id="google/flan-t5-base"):
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        # Whisper "base" is a reasonable speed/accuracy trade-off for short clips.
        self.transcriber = whisper.load_model("base")

    def search(self, query: str) -> str:
        """Return the body of the top DuckDuckGo text result, or an error string."""
        try:
            with DDGS() as ddgs:
                results = list(ddgs.text(query, safesearch="off"))
            if results:
                return results[0]["body"]
        except Exception as e:
            return f"Search failed: {e}"
        return ""

    def transcribe_audio(self, file_path: str) -> str:
        """Transcribe an audio file with Whisper."""
        try:
            result = self.transcriber.transcribe(file_path)
            return result["text"]
        except Exception as e:
            return f"Audio transcription failed: {e}"

    def handle_excel(self, file_path: str) -> str:
        """Sum non-drink sales; hard-coded to the GAIA food-sales task's columns."""
        try:
            df = pd.read_excel(file_path)
            food_sales = df[df["Category"].str.lower() != "drink"]["Sales"].sum()
            return f"{food_sales:.2f}"
        except Exception as e:
            return f"Excel parsing failed: {e}"

    def __call__(self, question: str, files: dict | None = None) -> tuple[str, str]:
        try:
            # Default prompt: question only. Overridden below when web context
            # or an audio transcript is available.
            prompt = f"{SYSTEM_PROMPT}\n\nQuestion: {question}"
            if "http" in question or "Wikipedia" in question:
                # Augment the prompt with a web snippet for link/Wikipedia questions.
                web_context = self.search(question)
                prompt = f"{SYSTEM_PROMPT}\n\n{web_context}\n\nQuestion: {question}"
            elif files:
                for key, path in files.items():
                    if key.endswith(".mp3"):
                        audio_txt = self.transcribe_audio(path)
                        prompt = f"{SYSTEM_PROMPT}\n\n{audio_txt}\n\n{question}"
                        break
                    if key.endswith(".xlsx"):
                        # The spreadsheet task is answered directly, with no LLM call.
                        excel_result = self.handle_excel(path)
                        return excel_result, excel_result

            inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True).to(self.device)
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=128,
                do_sample=False,  # greedy decoding; temperature is ignored when sampling is off
                pad_token_id=self.tokenizer.pad_token_id,
            )
            output_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            final = output_text.strip()
            return final, output_text
        except Exception as e:
            return "ERROR", f"Agent failed: {e}"
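

# --- Usage sketch (not part of the agent itself) ---
# A minimal example of driving the agent locally. The question strings and the
# "clip.mp3" filename/path are hypothetical placeholders; this assumes the
# caller passes attachments as a {filename: local_path} dict, which is the
# shape __call__ expects.
if __name__ == "__main__":
    agent = GaiaAgent()

    # Plain text question: answered by the seq2seq model alone.
    answer, raw = agent("What is the capital of France?")
    print(answer)

    # Question with an attached audio file: transcribed with Whisper first,
    # then the transcript is prepended to the prompt.
    # answer, raw = agent(
    #     "What is said in the recording?",
    #     files={"clip.mp3": "/tmp/clip.mp3"},
    # )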