import torch
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    pipeline,
    AutoProcessor,
    AutoModelForSpeechSeq2Seq
)
from ddgs import DDGS
import pandas as pd
import os
SYSTEM_PROMPT = """
You are a helpful AI assistant. Think step by step to solve the problem. If the question requires reasoning, perform it. If it refers to a search or file, use the result provided. At the end, return ONLY the final answer string. No explanations.
"""

class GaiaAgent:
    def __init__(self, model_id="google/flan-t5-base"):
        # Seq2seq text model used to generate the final answer
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

        # Whisper via HF for audio transcription
        self.asr_model_id = "openai/whisper-small"
        self.asr_processor = AutoProcessor.from_pretrained(self.asr_model_id)
        self.asr_model = AutoModelForSpeechSeq2Seq.from_pretrained(self.asr_model_id).to(self.device)
        self.pipe = pipeline(
            "automatic-speech-recognition",
            model=self.asr_model,
            tokenizer=self.asr_processor.tokenizer,
            feature_extractor=self.asr_processor.feature_extractor,
            return_timestamps=False,
            device=0 if torch.cuda.is_available() else -1
        )
    def search(self, query: str) -> str:
        """Return the body of the first DuckDuckGo result, or an empty string on failure."""
        try:
            with DDGS() as ddgs:
                results = list(ddgs.text(query, safesearch="off"))
                if results:
                    return results[0]['body']
        except Exception:
            return ""
        return ""
    def transcribe_audio(self, file_path: str) -> str:
        """Transcribe an audio file with the Whisper pipeline; empty string on failure."""
        try:
            result = self.pipe(file_path)
            return result['text']
        except Exception:
            return ""
    def handle_excel(self, file_path: str) -> str:
        """Sum the 'sales' column for rows whose 'category' is not 'drink', formatted to two decimals."""
        try:
            df = pd.read_excel(file_path)
            df.columns = [col.lower() for col in df.columns]
            if 'category' in df.columns and 'sales' in df.columns:
                food_sales = df[df['category'].str.lower() != 'drink']['sales'].sum()
                return f"{food_sales:.2f}"
        except Exception:
            return ""
        return ""
    def __call__(self, question: str, files: dict = None) -> tuple[str, str]:
        try:
            context = ""
            if files:
                # Route attached files: transcribe audio into the prompt context,
                # answer spreadsheet questions directly from the Excel handler.
                for filename, filepath in files.items():
                    if filename.endswith(".mp3") or filename.endswith(".wav"):
                        context = self.transcribe_audio(filepath)
                        break
                    elif filename.endswith(".xlsx"):
                        excel_result = self.handle_excel(filepath)
                        return excel_result.strip(), excel_result.strip()
            elif "http" in question.lower() or "wikipedia" in question.lower():
                # No files attached: fall back to a web search for link/Wikipedia questions
                context = self.search(question)

            prompt = f"{SYSTEM_PROMPT}\n\n{context}\n\nQuestion: {question.strip()}"
            inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True).to(self.device)
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=128,
                do_sample=False,
                pad_token_id=self.tokenizer.pad_token_id
            )
            output_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            final = output_text.strip()
            return final, final
        except Exception as e:
            return "ERROR", f"Agent failed: {e}"