Freddolin's picture
Update agent.py
e258602 verified
raw
history blame
3.64 kB
import torch
from transformers import (
AutoTokenizer,
AutoModelForSeq2SeqLM,
pipeline,
AutoProcessor,
AutoModelForSpeechSeq2Seq
)
from ddgs import DDGS
import pandas as pd
import os
# Instruction preamble placed at the start of every prompt sent to the text model.
# The leading and trailing newlines are part of the string on purpose.
SYSTEM_PROMPT = (
    "\n"
    "You are a helpful AI assistant. Think step by step to solve the problem. "
    "If the question requires reasoning, perform it. "
    "If it refers to a search or file, use the result provided. "
    "At the end, return ONLY the final answer string. No explanations.\n"
)
class GaiaAgent:
    """Answer GAIA-style questions with a seq2seq text model plus simple helpers.

    Audio attachments (.mp3/.wav) are transcribed with a speech-recognition
    pipeline and used as context. ``.xlsx`` attachments are answered directly
    by a hard-coded "non-drink sales total" rule when that rule applies.
    Questions containing "http" or "wikipedia" with no attachments get the
    first non-empty DuckDuckGo result body as context. The answer itself is
    generated by the text model.
    """

    # Extra tokens held back when budgeting the context. This absorbs small
    # differences caused by tokenizing the prompt pieces separately.
    _PROMPT_TOKEN_MARGIN = 8

    def __init__(self, model_id="google/flan-t5-base", asr_model_id="openai/whisper-small"):
        """Load the text model and the speech-recognition pipeline.

        Args:
            model_id: Hugging Face id of the seq2seq model that answers questions.
            asr_model_id: Hugging Face id of the speech-to-text model used for
                audio attachments.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
        self.model.to(self.device)

        # Whisper via HF
        self.asr_model_id = asr_model_id
        self.asr_processor = AutoProcessor.from_pretrained(self.asr_model_id)
        self.asr_model = AutoModelForSpeechSeq2Seq.from_pretrained(self.asr_model_id).to(self.device)
        self.pipe = pipeline(
            "automatic-speech-recognition",
            model=self.asr_model,
            tokenizer=self.asr_processor.tokenizer,
            feature_extractor=self.asr_processor.feature_extractor,
            return_timestamps=False,
            device=0 if torch.cuda.is_available() else -1,
        )

    def search(self, query: str) -> str:
        """Return the first non-empty DuckDuckGo text-result body for *query*.

        This is best-effort. Any failure during the search returns "", so the
        caller can still answer without context.
        """
        try:
            with DDGS() as ddgs:
                results = list(ddgs.text(query, safesearch="off"))
        except Exception:
            return ""
        for result in results:
            body = result.get("body") if isinstance(result, dict) else None
            if body:
                return body
        return ""

    def transcribe_audio(self, file_path: str) -> str:
        """Transcribe the audio file at *file_path*; return "" on any failure."""
        try:
            result = self.pipe(file_path)
            return result["text"]
        except Exception:
            return ""

    def handle_excel(self, file_path: str) -> str:
        """Read the spreadsheet at *file_path* and apply the sales rule.

        Returns the total of the "sales" column over rows whose "category" is
        not "drink", formatted with two decimals. Returns "" if the file cannot
        be read or lacks those columns.
        """
        try:
            df = pd.read_excel(file_path)
            return self._food_sales_from_frame(df)
        except Exception:
            return ""

    @staticmethod
    def _food_sales_from_frame(df: pd.DataFrame) -> str:
        """Compute the non-drink sales total of *df* as "%.2f", or "" if not applicable."""
        # Header names may be non-strings (e.g. ints) or carry stray whitespace.
        df = df.rename(columns=lambda c: str(c).strip().lower())
        if "category" not in df.columns or "sales" not in df.columns:
            return ""
        categories = df["category"].astype(str).str.strip().str.lower()
        # Sales stored as text are parsed; unparseable cells become NaN and are skipped.
        sales = pd.to_numeric(df["sales"], errors="coerce")
        food_sales = sales[categories != "drink"].sum()
        return f"{food_sales:.2f}"

    def _build_prompt(self, context: str, question: str) -> str:
        """Assemble the model prompt, trimming *context* so the question always fits.

        The tokenizer truncates from the right. An untrimmed long context would
        therefore push the trailing "Question: ..." out of the model input.
        """
        head = f"{SYSTEM_PROMPT}\n\n"
        tail = f"\n\nQuestion: {question.strip()}"
        max_len = getattr(self.tokenizer, "model_max_length", None)
        # NOTE(review): tokenizers without a known limit report a very large
        # sentinel value; trimming is skipped in that case.
        if context and isinstance(max_len, int) and 0 < max_len < 1_000_000:
            reserved = len(self.tokenizer(head + tail).input_ids) + self._PROMPT_TOKEN_MARGIN
            budget = max(max_len - reserved, 0)
            context_ids = self.tokenizer(context, add_special_tokens=False).input_ids
            if len(context_ids) > budget:
                context = self.tokenizer.decode(context_ids[:budget], skip_special_tokens=True)
        return f"{head}{context}{tail}"

    def __call__(self, question: str, files: dict = None) -> tuple[str, str]:
        """Answer *question*, optionally using attached files.

        Args:
            question: The question text.
            files: Optional mapping of attachment filename -> local file path.

        Returns:
            ``(answer, answer)`` on success, or ``("ERROR", message)`` if an
            exception escapes the agent logic.
        """
        try:
            context = ""
            lowered_question = question.lower()
            if files:
                for filename, filepath in files.items():
                    lowered_name = filename.lower()
                    if lowered_name.endswith((".mp3", ".wav")):
                        # The first audio attachment becomes the model context.
                        context = self.transcribe_audio(filepath)
                        break
                    if lowered_name.endswith(".xlsx"):
                        excel_result = self.handle_excel(filepath).strip()
                        # Short-circuit only when the sales rule produced a value;
                        # otherwise continue so the model still gets to answer.
                        if excel_result:
                            return excel_result, excel_result
            elif "http" in lowered_question or "wikipedia" in lowered_question:
                # Web search runs only when no attachments were provided.
                context = self.search(question)

            prompt = self._build_prompt(context, question)
            # truncation=True stays as a safety net; _build_prompt keeps the question in range.
            inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True).to(self.device)
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=128,
                do_sample=False,
                pad_token_id=self.tokenizer.pad_token_id,
            )
            final = self.tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
            return final, final
        except Exception as e:
            return "ERROR", f"Agent failed: {e}"