# Source: Hugging Face Space file view of agent.py ("Update agent.py", commit 24816c1, 1.89 kB)
import os
from transformers import pipeline
from tools.asr_tool import transcribe_audio
from tools.excel_tool import analyze_excel
from tools.search_tool import search_duckduckgo
class GaiaAgent:
    """Route a question to a specialized tool, falling back to a local LLM.

    Routing is keyword/extension based:
      * URL-like text  -> DuckDuckGo search
      * .mp3 / .wav    -> audio transcription
      * .xlsx / .xls   -> Excel analysis
      * anything else  -> local Mistral text-generation pipeline
    """

    def __init__(self):
        """Build the fallback text-generation pipeline.

        Raises:
            ValueError: if the HUGGINGFACEHUB_API_TOKEN environment
                variable is not set (required to fetch the gated model).
        """
        token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
        if not token:
            raise ValueError("Missing HUGGINGFACEHUB_API_TOKEN environment variable.")
        self.llm = pipeline(
            "text-generation",
            model="mistralai/Mistral-7B-Instruct-v0.2",
            # `use_auth_token` is deprecated (and removed in newer
            # transformers releases); `token` is the current parameter name.
            token=token,
            device="cpu",
            max_new_tokens=512,
            do_sample=False,
        )

    def __call__(self, question: str):
        """Answer *question* and return a ``(result, trace_log)`` tuple.

        Args:
            question: Free-form question text, a URL, or a file path
                ending in .mp3/.wav/.xlsx/.xls.

        Returns:
            Tuple of (answer string, human-readable trace of the route taken).
        """
        trace_log = ""
        lowered = question.lower()  # computed once; used by every extension check

        # Search tool: any URL-like substring routes to web search.
        if "http" in question or "www." in question:
            trace_log += "Detected URL or web reference. Performing search...\n"
            result = search_duckduckgo(question)
            trace_log += f"Search result: {result}\n"
            return result, trace_log

        # Audio tool: str.endswith accepts a tuple of suffixes.
        if lowered.endswith((".mp3", ".wav")):
            trace_log += "Detected audio file. Performing transcription...\n"
            result = transcribe_audio(question)
            trace_log += f"Transcription result: {result}\n"
            return result, trace_log

        # Excel tool
        if lowered.endswith((".xlsx", ".xls")):
            trace_log += "Detected Excel file. Performing analysis...\n"
            result = analyze_excel(question)
            trace_log += f"Excel analysis result: {result}\n"
            return result, trace_log

        # LLM fallback for everything that matched no tool.
        trace_log += "General question. Using local model...\n"
        response = self.llm(question)[0]["generated_text"]
        trace_log += f"LLM response: {response}\n"
        return response.strip(), trace_log