File size: 1,806 Bytes
058b8cf
ee62c26
 
 
 
230477c
 
ee62c26
 
058b8cf
 
 
 
 
ee62c26
058b8cf
 
 
 
 
 
 
 
 
 
 
ee62c26
058b8cf
 
 
 
 
ee62c26
058b8cf
 
 
 
 
 
 
 
 
ee62c26
058b8cf
ee62c26
058b8cf
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import mimetypes
from transformers import pipeline
from tools.asr_tool import transcribe_audio
from tools.excel_tool import analyze_excel
from tools.search_tool import search_duckduckgo

class GaiaAgent:
    """Question-answering agent that feeds tool output into an LLM prompt.

    Attached audio files are transcribed, spreadsheets are analysed, and
    questions that appear to need outside information trigger a DuckDuckGo
    search. Everything gathered is placed in the prompt as context.
    """

    # Lower-case substrings that make a question trigger a web search.
    _SEARCH_KEYWORDS = ("http", "wikipedia", "youtube", "search")

    def __init__(self):
        """Load the Mistral-7B-Instruct text-generation pipeline on CPU."""
        print("Loading model...")
        self.llm = pipeline("text-generation", model="mistralai/Mistral-7B-Instruct-v0.2", max_new_tokens=512, device="cpu")

    @staticmethod
    def _classify_file(name):
        """Return "audio", "spreadsheet", or None for the file path *name*."""
        mime, _ = mimetypes.guess_type(name)
        if mime and mime.startswith("audio"):
            return "audio"
        # The extension fallback must not be gated on `mime`: on hosts whose
        # MIME tables don't know .xlsx, guess_type returns (None, None).
        if (mime and "spreadsheet" in mime) or name.lower().endswith(".xlsx"):
            return "spreadsheet"
        return None

    @classmethod
    def _needs_search(cls, question):
        """Return True if *question* mentions a search-worthy keyword (case-insensitive)."""
        lowered = question.lower()
        return any(keyword in lowered for keyword in cls._SEARCH_KEYWORDS)

    def __call__(self, question: str, files: list = None):
        """Answer *question*, optionally using attached *files* as context.

        Args:
            question: The user's question.
            files: Optional file-like objects; only their ``name`` attribute
                (a filesystem path) is used. Audio files are transcribed,
                spreadsheets are analysed, anything else is skipped.

        Returns:
            A ``(response, trace)`` tuple: the model's answer text and a
            newline-joined log of the steps taken.
        """
        trace = []
        context_parts = []

        for attachment in files or ():
            kind = self._classify_file(attachment.name)
            if kind == "audio":
                transcription = transcribe_audio(attachment.name)
                trace.append(f"Transcribed audio: {transcription}")
                context_parts.append(f"\nTranscription: {transcription}")
            elif kind == "spreadsheet":
                result = analyze_excel(attachment.name)
                trace.append(f"Excel analysis: {result}")
                context_parts.append(f"\nSpreadsheet data: {result}")
            else:
                # Record skipped attachments instead of dropping them silently.
                trace.append(f"Skipped unsupported file: {attachment.name}")

        if self._needs_search(question):
            trace.append("Performing DuckDuckGo search...")
            search_result = search_duckduckgo(question)
            trace.append(f"Summary from search: {search_result}")
            context_parts.append(f"\nSearch Result: {search_result}")

        context = "".join(context_parts)
        # Include the original question
        prompt = f"""
Answer the question based on the context below.
Context: {context}
Question: {question}
Answer:
"""
        # Text-generation pipelines return prompt + completion by default
        # (return_full_text=True), so keep only what follows the last "Answer:".
        response = self.llm(prompt)[0]['generated_text'].split("Answer:")[-1].strip()
        trace.append(response)

        return response, "\n".join(trace)