|
import logging
import os

import pandas as pd
import torch
from duckduckgo_search import DDGS
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoModelForSpeechSeq2Seq,
    AutoProcessor,
    AutoTokenizer,
    pipeline,
)
|
|
|
# Instruction preamble that GaiaAgent.__call__ places at the start of the
# prompt it sends to the language model: reason step by step, use any supplied
# search/file result, and reply with only the final answer string.
SYSTEM_PROMPT = """
You are a helpful AI assistant. Think step by step to solve the problem. If the question requires reasoning, perform it. If it refers to a search or file, use the result provided. At the end, return ONLY the final answer string. No explanations.
"""
|
|
|
class GaiaAgent:
    """Question-answering agent built on a seq2seq language model plus tools.

    Before asking the language model, the agent may gather context from an
    attached audio file (speech recognition) or a web-search snippet, or it
    may answer directly from an attached spreadsheet.  Every tool is
    best-effort: a failure is logged and yields an empty string instead of
    an exception.
    """

    _logger = logging.getLogger(__name__)

    # File-name suffixes (compared case-insensitively) that select a tool.
    _AUDIO_EXTENSIONS = (".mp3", ".wav")
    _EXCEL_EXTENSIONS = (".xlsx",)
    # Substrings of the lower-cased question that trigger a web search.
    _WEB_HINTS = ("http", "wikipedia")

    def __init__(self, model_id="google/flan-t5-base"):
        """Load the language model and the speech-recognition pipeline.

        Args:
            model_id: Hugging Face identifier of the seq2seq model used to
                produce answers.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
        # Prefer the GPU when one is available; both models share the device.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

        self.asr_model_id = "openai/whisper-small"
        self.asr_processor = AutoProcessor.from_pretrained(self.asr_model_id)
        self.asr_model = AutoModelForSpeechSeq2Seq.from_pretrained(
            self.asr_model_id
        ).to(self.device)
        self.pipe = pipeline(
            "automatic-speech-recognition",
            model=self.asr_model,
            tokenizer=self.asr_processor.tokenizer,
            feature_extractor=self.asr_processor.feature_extractor,
            return_timestamps=False,
            device=0 if torch.cuda.is_available() else -1,
        )

    def search(self, query: str) -> str:
        """Return the body text of the first DuckDuckGo hit for ``query``.

        Args:
            query: Search terms to send to DuckDuckGo.

        Returns:
            The ``body`` field of the first result, or ``""`` when the query
            is blank, nothing is found, or the search fails.
        """
        if not query or not query.strip():
            return ""
        try:
            with DDGS() as ddgs:
                # list() accepts both a list and a lazy iterator of results.
                hits = list(ddgs.text(query, max_results=1) or [])
            if hits:
                return hits[0].get("body") or ""
        except Exception:
            # Best-effort tool: record the failure, then fall back to "".
            self._logger.warning(
                "Web search failed for query %r", query, exc_info=True
            )
            return ""
        return ""

    def transcribe_audio(self, file_path: str) -> str:
        """Transcribe an audio file with the speech-recognition pipeline.

        Args:
            file_path: Path of the audio file handed to ``self.pipe``.

        Returns:
            The ``text`` field of the pipeline result, or ``""`` on failure.
        """
        try:
            result = self.pipe(file_path)
            return result["text"]
        except Exception:
            self._logger.warning(
                "Audio transcription failed for %r", file_path, exc_info=True
            )
            return ""

    def handle_excel(self, file_path: str) -> str:
        """Sum the ``sales`` of every spreadsheet row whose category is not ``drink``.

        Column names and category values are compared case-insensitively and
        with surrounding whitespace ignored.

        Args:
            file_path: Path of the workbook to read.

        Returns:
            The total formatted with two decimals, or ``""`` when the sheet
            lacks a ``category`` or ``sales`` column or cannot be processed.
        """
        try:
            df = pd.read_excel(file_path)
            # str() keeps non-string headers (numbers, NaN) from raising here.
            df.columns = [str(col).strip().lower() for col in df.columns]
            if "category" in df.columns and "sales" in df.columns:
                categories = df["category"].astype(str).str.strip().str.lower()
                food_sales = df[categories != "drink"]["sales"].sum()
                return f"{food_sales:.2f}"
        except Exception:
            self._logger.warning(
                "Spreadsheet processing failed for %r", file_path, exc_info=True
            )
            return ""
        return ""

    def __call__(self, question: str, files: dict = None) -> tuple[str, str]:
        """Answer ``question``, optionally using attached files or web search.

        Files are scanned in mapping order.  The first audio file is
        transcribed and used as context.  The first ``.xlsx`` file short-cuts
        the language model: the spreadsheet result (possibly empty) is
        returned directly.  When no file produced context and the question
        mentions ``http`` or ``wikipedia``, a web-search snippet is used as
        context instead.

        Args:
            question: The question text.
            files: Optional mapping of file name to file path.

        Returns:
            A pair of strings.  On success both items are the answer; on any
            failure the pair is ``("ERROR", "Agent failed: <reason>")``.
        """
        try:
            context = ""
            for filename, filepath in (files or {}).items():
                lowered_name = str(filename).lower()
                if lowered_name.endswith(self._AUDIO_EXTENSIONS):
                    context = self.transcribe_audio(filepath)
                    break
                if lowered_name.endswith(self._EXCEL_EXTENSIONS):
                    answer = self.handle_excel(filepath).strip()
                    return answer, answer

            # Web search is independent of attachments: use it whenever no
            # file supplied context and the question hints at a web source.
            lowered_question = question.lower()
            if not context and any(
                hint in lowered_question for hint in self._WEB_HINTS
            ):
                context = self.search(question)

            prompt = f"{SYSTEM_PROMPT}\n\n{context}\n\nQuestion: {question.strip()}"
            # NOTE(review): truncation presumably drops the tail of the prompt,
            # so a very long context could cut off the question - confirm.
            inputs = self.tokenizer(
                prompt, return_tensors="pt", truncation=True
            ).to(self.device)
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=128,
                do_sample=False,
                pad_token_id=self.tokenizer.pad_token_id,
            )
            final = self.tokenizer.decode(
                outputs[0], skip_special_tokens=True
            ).strip()
            return final, final
        except Exception as e:
            self._logger.warning("Agent failed", exc_info=True)
            return "ERROR", f"Agent failed: {e}"
|
|