|
from transformers import pipeline |
|
from tools.asr_tool import transcribe_audio |
|
from tools.excel_tool import analyze_excel |
|
from tools.search_tool import search_duckduckgo |
|
import mimetypes |
|
|
|
class GaiaAgent:
    """Answer a question by dispatching it to one of four handlers.

    The handler is chosen from the question text alone:

    1. text ending in ``.mp3`` / ``.wav``  -> transcribe, then ask the model;
    2. text ending in ``.xls`` / ``.xlsx`` -> run the spreadsheet analysis;
    3. text mentioning a search keyword    -> web search, then ask the model;
    4. anything else                       -> ask the local model directly.

    Every call returns ``(answer, trace)`` where ``trace`` is a
    newline-terminated log of the steps that were taken.
    """

    # Suffixes are compared against the lower-cased, stripped question.
    _AUDIO_EXTENSIONS = ('.mp3', '.wav')
    _EXCEL_EXTENSIONS = ('.xls', '.xlsx')
    # Substring match: any of these appearing anywhere triggers a search.
    _SEARCH_KEYWORDS = ("wikipedia", "video", "youtube", "article")
    _MAX_NEW_TOKENS = 64

    def __init__(self, model_name: str = "google/flan-t5-base"):
        """Load the text2text-generation pipeline.

        Args:
            model_name: Model identifier handed to ``pipeline``. Defaults to
                the model this agent has always used.
        """
        print("Loading model...")
        self.qa_pipeline = pipeline("text2text-generation", model=model_name)

    def _generate(self, prompt):
        """Run ``prompt`` through the pipeline and return the stripped text."""
        outputs = self.qa_pipeline(prompt, max_new_tokens=self._MAX_NEW_TOKENS)
        return outputs[0]['generated_text'].strip()

    def __call__(self, question: str):
        """Answer ``question`` and report how the answer was produced.

        Args:
            question: Either a natural-language question or a file path
                ending in a supported audio/spreadsheet extension.

        Returns:
            A ``(answer, trace)`` tuple of strings.
        """
        stripped = question.strip()
        # Lower-cased copy is used only for routing; handlers receive the
        # original casing so file paths stay intact.
        normalized = stripped.lower()
        steps = []

        if normalized.endswith(self._AUDIO_EXTENSIONS):
            steps.append("Audio detected. Running transcription...\n")
            text = transcribe_audio(stripped)
            steps.append(f"Transcribed text: {text}\n")
            return self._generate(text), "".join(steps)

        if normalized.endswith(self._EXCEL_EXTENSIONS):
            steps.append("Excel detected. Running analysis...\n")
            value = analyze_excel(stripped)
            steps.append(f"Extracted value: {value}\n")
            # analyze_excel's return type is not visible from this file;
            # coerce to str so a numeric result does not fail on .strip().
            return str(value).strip(), "".join(steps)

        if any(keyword in normalized for keyword in self._SEARCH_KEYWORDS):
            steps.append("Performing DuckDuckGo search...\n")
            # The search and the prompt use the question exactly as given.
            summary = search_duckduckgo(question)
            steps.append(f"Summary from search: {summary}\n")
            return self._generate(summary + "\n" + question), "".join(steps)

        steps.append("General question. Using local model...\n")
        return self._generate(question), "".join(steps)
|
|