from transformers import pipeline

from tools.search_tool import search_duckduckgo
from tools.asr_tool import transcribe_audio
from tools.excel_tool import analyze_excel


class GaiaAgent:
    """Routes a question to a web-search, audio, or spreadsheet tool, and falls
    back to a local text-generation model for general questions."""

    def __init__(self):
        # Local instruction-tuned model used for the general-question fallback.
        self.llm = pipeline(
            "text-generation",
            model="mistralai/Mistral-7B-Instruct-v0.2",
            device="cpu",
        )

    def __call__(self, question: str):
        trace = ""
        question_lower = question.lower()

        # Web reference: delegate to the DuckDuckGo search tool.
        if "http" in question_lower or "wikipedia" in question_lower:
            trace += "Detected URL or web reference. Performing search...\n"
            result = search_duckduckgo(question)
            trace += result
            return result.strip(), trace

        # Audio file: transcribe with the ASR tool.
        elif question_lower.endswith(".mp3"):
            trace += "Detected audio reference. Transcribing audio...\n"
            result = transcribe_audio(question)
            trace += result
            return result.strip(), trace

        # Spreadsheet file: analyze with the Excel tool.
        elif question_lower.endswith(".xlsx") or question_lower.endswith(".csv"):
            trace += "Detected Excel/CSV file. Analyzing...\n"
            result = analyze_excel(question)
            trace += result
            return result.strip(), trace

        # Fallback: answer directly with the local model.
        else:
            trace += "General question. Using local model...\n"
            output = self.llm(question, max_new_tokens=256, do_sample=False)[0]["generated_text"]
            trace += output
            return output.strip(), trace
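
# Example usage sketch: assumes the tools/ package is importable and the
# Mistral-7B-Instruct weights can be downloaded or loaded locally; generation
# on CPU will be slow. The sample question is illustrative only.
if __name__ == "__main__":
    agent = GaiaAgent()
    answer, trace = agent("What does https://en.wikipedia.org/wiki/Turing_test describe?")
    print(trace)
    print("Final answer:", answer)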