from transformers import pipeline
from tools.asr_tool import transcribe_audio
from tools.excel_tool import analyze_excel
from tools.search_tool import search_duckduckgo

class GaiaAgent:
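    """Minimal GAIA-style agent that routes each question to an audio,
    Excel, web-search, or plain-text path using simple heuristics, and
    returns both the answer and a trace of the steps taken."""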
    def __init__(self):
        print("Loading model...")
        self.qa_pipeline = pipeline("text2text-generation", model="google/flan-t5-base")

    def __call__(self, question: str):
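        """Answer `question` and return an (answer, trace) tuple, where
        trace is a human-readable log of which tool path was used."""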
        trace = ""

        # Handle audio
        if question.lower().strip().endswith(('.mp3', '.wav')):
            trace += "Audio detected. Running transcription...\n"
            text = transcribe_audio(question.strip())
            trace += f"Transcribed text: {text}\n"
            answer = self.qa_pipeline(text, max_new_tokens=64)[0]['generated_text']
            return answer.strip(), trace

        # Handle Excel
        if question.lower().strip().endswith(('.xls', '.xlsx')):
            trace += "Excel detected. Running analysis...\n"
            answer = analyze_excel(question.strip())
            trace += f"Extracted value: {answer}\n"
            # analyze_excel may return a non-string value, so coerce before stripping
            return str(answer).strip(), trace

        # Handle web search
        if any(keyword in question.lower() for keyword in ["wikipedia", "video", "youtube", "article"]):
            trace += "Performing DuckDuckGo search...\n"
            summary = search_duckduckgo(question)
            trace += f"Summary from search: {summary}\n"
            answer = self.qa_pipeline(summary + "\n" + question, max_new_tokens=64)[0]['generated_text']
            return answer.strip(), trace

        trace += "General question. Using local model...\n"
        answer = self.qa_pipeline(question, max_new_tokens=64)[0]['generated_text']
        return answer.strip(), trace
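
# Illustrative local smoke test (a sketch, not part of the evaluation harness;
# assumes the tools/ modules are importable and a plain-text question is asked).
if __name__ == "__main__":
    agent = GaiaAgent()
    answer, trace = agent("What is the capital of France?")
    print("Trace:\n" + trace)
    print("Answer:", answer)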