File size: 1,431 Bytes
c3c803a
c36be6c
c3c803a
 
 
 
 
c36be6c
5617dda
c36be6c
c3c803a
c36be6c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5617dda
c3c803a
c36be6c
 
 
 
ee62c26
058b8cf
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from transformers import pipeline
from tools.search_tool import search_duckduckgo
from tools.asr_tool import transcribe_audio
from tools.excel_tool import analyze_excel

class GaiaAgent:
    """Route a question to a tool (search, ASR, spreadsheet) or a local LLM.

    Calling an instance returns ``(answer, trace)``: ``answer`` is the
    stripped tool/model output and ``trace`` is a human-readable log line
    describing the routing decision followed by the raw output.
    """

    # Upper bound on tokens the local model may generate per question.
    max_new_tokens = 256

    def __init__(self, model: str = "mistralai/Mistral-7B-Instruct-v0.2", device: str = "cpu"):
        """Load the local text-generation pipeline.

        Args:
            model: Model id passed to ``transformers.pipeline``.
            device: Device passed to ``transformers.pipeline`` (CPU by default).
        """
        self.llm = pipeline("text-generation", model=model, device=device)

    @staticmethod
    def _run_tool(trace: str, tool, question: str) -> tuple:
        """Call ``tool(question)``, append its output to ``trace``, return ``(answer, trace)``.

        ``tool`` is expected to return a string (its result is concatenated
        onto ``trace`` and ``.strip()``-ed).
        """
        result = tool(question)
        trace += result
        return result.strip(), trace

    def __call__(self, question: str) -> tuple:
        """Answer ``question`` and return ``(answer, trace)``.

        Routing order (first match wins):
          1. contains "http" or "wikipedia" (case-insensitive) -> web search
          2. ends with ".mp3"                                  -> audio transcription
          3. ends with ".xlsx" or ".csv"                       -> spreadsheet analysis
          4. anything else                                     -> local text-generation model
        """
        question_lower = question.lower()
        # Ignore trailing whitespace/newlines so "... clip.mp3\n" still routes by extension.
        tail = question_lower.rstrip()

        if "http" in question_lower or "wikipedia" in question_lower:
            return self._run_tool(
                "Detected URL or web reference. Performing search...\n",
                search_duckduckgo,
                question,
            )

        if tail.endswith(".mp3"):
            return self._run_tool(
                "Detected audio reference. Transcribing audio...\n",
                transcribe_audio,
                question,
            )

        if tail.endswith((".xlsx", ".csv")):
            return self._run_tool(
                "Detected Excel file. Analyzing...\n",
                analyze_excel,
                question,
            )

        trace = "General question. Using local model...\n"
        # By default the text-generation pipeline returns prompt + completion in
        # "generated_text"; return_full_text=False keeps only the newly generated
        # text so the answer is not prefixed with the question itself.
        # NOTE(review): the question is sent as a raw prompt, not wrapped in the
        # instruct model's chat template — consider applying it.
        output = self.llm(
            question,
            max_new_tokens=self.max_new_tokens,
            do_sample=False,
            return_full_text=False,
        )[0]["generated_text"]
        trace += output
        return output.strip(), trace