File size: 1,815 Bytes
13755f8
 
c3c803a
 
 
13755f8
c3c803a
 
 
13755f8
 
 
 
 
 
 
 
 
 
 
 
5617dda
c36be6c
13755f8
c36be6c
13755f8
 
c36be6c
13755f8
 
c36be6c
13755f8
 
c36be6c
13755f8
 
c36be6c
13755f8
 
c36be6c
13755f8
 
 
 
 
 
 
 
ee62c26
058b8cf
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import os
import mimetypes
from transformers import pipeline
from tools.asr_tool import transcribe_audio
from tools.excel_tool import analyze_excel
from tools.search_tool import search_duckduckgo

class GaiaAgent:
    """Route a GAIA-style question to the right tool, or answer it locally.

    Dispatch rules, checked in order against the question string:
      1. contains a URL/web reference -> DuckDuckGo search tool
      2. ends with an audio extension -> ASR transcription tool
      3. ends with an Excel extension -> Excel analysis tool
      4. otherwise -> local Mistral text-generation pipeline

    Every call returns ``(result, trace_log)`` where ``trace_log`` records
    which branch was taken and the raw tool output.
    """

    # Extensions used for file-path dispatch; compared case-insensitively.
    _AUDIO_EXTENSIONS = (".mp3", ".wav")
    _EXCEL_EXTENSIONS = (".xlsx", ".xls")

    def __init__(self):
        """Build the local LLM pipeline.

        Raises:
            ValueError: if the HUGGINGFACEHUB_API_TOKEN environment
                variable is not set (the gated Mistral model requires it).
        """
        token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
        if not token:
            raise ValueError("Missing HUGGINGFACEHUB_API_TOKEN environment variable.")

        self.llm = pipeline(
            "text-generation",
            model="mistralai/Mistral-7B-Instruct-v0.2",
            token=token,
            device="cpu",
            max_new_tokens=512,
            do_sample=False,
        )

    def __call__(self, question: str):
        """Answer *question* and return ``(result, trace_log)``."""
        trace_log = ""
        lowered = question.lower()  # lowercase once for all extension checks

        if "http" in question or "www." in question:
            trace_log += "Detected URL or web reference. Performing search...\n"
            result = search_duckduckgo(question)
            trace_log += f"Search result: {result}\n"
            return result, trace_log

        if lowered.endswith(self._AUDIO_EXTENSIONS):
            trace_log += "Detected audio file. Performing transcription...\n"
            result = transcribe_audio(question)
            trace_log += f"Transcription result: {result}\n"
            return result, trace_log

        if lowered.endswith(self._EXCEL_EXTENSIONS):
            trace_log += "Detected Excel file. Performing analysis...\n"
            result = analyze_excel(question)
            trace_log += f"Excel analysis result: {result}\n"
            return result, trace_log

        trace_log += "General question. Using local model...\n"
        # return_full_text=False: by default the text-generation pipeline
        # prepends the prompt to "generated_text", so without it the agent
        # would echo the question back inside its answer.
        response = self.llm(question, return_full_text=False)[0]["generated_text"]
        trace_log += f"LLM response: {response}\n"

        return response.strip(), trace_log