|
import os |
|
import mimetypes |
|
from transformers import pipeline |
|
from tools.asr_tool import transcribe_audio |
|
from tools.excel_tool import analyze_excel |
|
from tools.search_tool import search_duckduckgo |
|
|
|
class GaiaAgent:
    """Answer a question by routing it to a tool or to a local LLM.

    Calling an instance returns ``(answer, trace_log)``; ``trace_log`` is a
    newline-terminated, human-readable record of which route was taken and
    the raw result it produced.
    """

    # Suffixes (matched case-insensitively) that route a question to a file tool.
    AUDIO_SUFFIXES = (".mp3", ".wav")
    EXCEL_SUFFIXES = (".xlsx", ".xls")

    def __init__(self):
        """Load the Mistral-7B-Instruct text-generation pipeline on CPU.

        Raises:
            ValueError: if ``HUGGINGFACEHUB_API_TOKEN`` is unset or empty.
        """
        token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
        if not token:
            raise ValueError("Missing HUGGINGFACEHUB_API_TOKEN environment variable.")

        self.llm = pipeline(
            "text-generation",
            model="mistralai/Mistral-7B-Instruct-v0.2",
            token=token,
            device="cpu",
            max_new_tokens=512,
            do_sample=False,  # greedy decoding, so repeated calls give the same answer
        )

    @staticmethod
    def _run_tool(argument, tool, announcement, label):
        """Call *tool* on *argument*; return ``(result, trace_log)``."""
        result = tool(argument)
        return result, f"{announcement}{label}: {result}\n"

    def __call__(self, question: str):
        """Answer *question* and return ``(answer, trace_log)``.

        Routes are checked in this order:

        1. text contains ``"http"`` or ``"www."`` -> ``search_duckduckgo``;
        2. text ends with an audio suffix -> ``transcribe_audio``;
        3. text ends with an Excel suffix -> ``analyze_excel``;
        4. anything else -> the local LLM (answer is whitespace-stripped).

        Surrounding whitespace is removed before routing and before the text
        is handed on, so ``"data.xlsx\\n"`` is still treated as an Excel path.
        An empty or whitespace-only question returns ``""`` without invoking
        the model.
        """
        question = (question or "").strip()
        if not question:
            return "", "Empty question. Nothing to answer.\n"

        if "http" in question or "www." in question:
            return self._run_tool(
                question,
                search_duckduckgo,
                "Detected URL or web reference. Performing search...\n",
                "Search result",
            )

        lowered = question.lower()
        if lowered.endswith(self.AUDIO_SUFFIXES):
            return self._run_tool(
                question,
                transcribe_audio,
                "Detected audio file. Performing transcription...\n",
                "Transcription result",
            )

        if lowered.endswith(self.EXCEL_SUFFIXES):
            return self._run_tool(
                question,
                analyze_excel,
                "Detected Excel file. Performing analysis...\n",
                "Excel analysis result",
            )

        trace_log = "General question. Using local model...\n"
        # By default the text-generation pipeline puts prompt + continuation
        # in "generated_text". Without return_full_text=False the returned
        # answer would start with the question itself.
        # NOTE(review): Mistral-Instruct models expect an [INST] ... [/INST]
        # prompt format; the raw question is sent as-is. Consider the
        # tokenizer's chat template.
        outputs = self.llm(question, return_full_text=False)
        response = outputs[0]["generated_text"]
        trace_log += f"LLM response: {response}\n"

        return response.strip(), trace_log
|
|
|
|