import torch from transformers import ( AutoTokenizer, AutoModelForSeq2SeqLM, pipeline, AutoProcessor, AutoModelForSpeechSeq2Seq ) from duckduckgo_search import DDGS import pandas as pd import os SYSTEM_PROMPT = """ You are a helpful AI assistant. Think step by step to solve the problem. If the question requires reasoning, perform it. If it refers to a search or file, use the result provided. At the end, return ONLY the final answer string. No explanations. """ class GaiaAgent: def __init__(self, model_id="google/flan-t5-base"): self.tokenizer = AutoTokenizer.from_pretrained(model_id) self.model = AutoModelForSeq2SeqLM.from_pretrained(model_id) self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.model.to(self.device) # Whisper via HF self.asr_model_id = "openai/whisper-small" self.asr_processor = AutoProcessor.from_pretrained(self.asr_model_id) self.asr_model = AutoModelForSpeechSeq2Seq.from_pretrained(self.asr_model_id).to(self.device) self.pipe = pipeline( "automatic-speech-recognition", model=self.asr_model, tokenizer=self.asr_processor.tokenizer, feature_extractor=self.asr_processor.feature_extractor, return_timestamps=False, device=0 if torch.cuda.is_available() else -1 ) def search(self, query: str) -> str: try: with DDGS() as ddgs: results = ddgs.text("your query", max_results=1) if results: return results[0]['body'] except Exception: return "" return "" def transcribe_audio(self, file_path: str) -> str: try: result = self.pipe(file_path) return result['text'] except Exception: return "" def handle_excel(self, file_path: str) -> str: try: df = pd.read_excel(file_path) df.columns = [col.lower() for col in df.columns] if 'category' in df.columns and 'sales' in df.columns: food_sales = df[df['category'].str.lower() != 'drink']['sales'].sum() return f"{food_sales:.2f}" except Exception: return "" return "" def __call__(self, question: str, files: dict = None) -> tuple[str, str]: try: context = "" if files: for filename, filepath in files.items(): if filename.endswith(".mp3") or filename.endswith(".wav"): context = self.transcribe_audio(filepath) break elif filename.endswith(".xlsx"): excel_result = self.handle_excel(filepath) return excel_result.strip(), excel_result.strip() elif "http" in question.lower() or "wikipedia" in question.lower(): context = self.search(question) prompt = f"{SYSTEM_PROMPT}\n\n{context}\n\nQuestion: {question.strip()}" inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True).to(self.device) outputs = self.model.generate( **inputs, max_new_tokens=128, do_sample=False, pad_token_id=self.tokenizer.pad_token_id ) output_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True) final = output_text.strip() return final, final except Exception as e: return "ERROR", f"Agent failed: {e}"