File size: 3,273 Bytes
9666d9f
1d782b2
 
 
 
 
 
 
 
230477c
 
9bf47dc
 
 
230477c
 
 
2b8dbad
9666d9f
2b8dbad
9666d9f
 
1d782b2
 
 
 
 
 
 
 
 
 
 
58c4724
1d782b2
230477c
1d782b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9bf47dc
 
 
2b8dbad
9666d9f
 
2b8dbad
 
1d782b2
2b8dbad
9666d9f
 
2b8dbad
1d782b2
230477c
 
2b8dbad
b5d03d2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from ddgs import DDGS
import re
import pandas as pd
import tempfile
import os
import whisper


# Instruction text placed at the start of every prompt sent to the model.
# It asks for step-by-step reasoning but requires the reply to be only the
# raw answer string, with no explanation, label, or formatting.
SYSTEM_PROMPT = """
You are a general AI assistant. I will ask you a question. Think step by step to find the best possible answer.
Then return only the answer without any explanation or formatting.
Do not say 'Final answer' or anything else. Just output the raw answer string.
"""

class GaiaAgent:
    """Answer questions with a seq2seq model, optionally aided by simple tools.

    The agent can enrich the prompt with a web-search snippet or an audio
    transcript, and answers spreadsheet questions directly from the sheet.
    Calling the instance returns ``(final_answer, raw_model_output)``.
    """

    def __init__(self, model_id="google/flan-t5-base"):
        """Load the tokenizer, seq2seq model and Whisper transcriber.

        Args:
            model_id: Hugging Face model identifier passed to
                ``from_pretrained`` for both tokenizer and model.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
        # Use the GPU when one is available, otherwise fall back to CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.transcriber = whisper.load_model("base")

    def search(self, query: str) -> str:
        """Return the body text of the first web-search hit for *query*.

        Returns:
            The ``'body'`` field of the first result, an empty string when
            the search yields nothing, or a ``"Search failed: ..."`` message
            when the lookup raises.
        """
        try:
            with DDGS() as ddgs:
                results = list(ddgs.text(query, safesearch="off"))
                if results:
                    return results[0]['body']
        except Exception as e:
            # Best effort: report the failure as text instead of raising.
            return f"Search failed: {e}"
        return ""

    def transcribe_audio(self, file_path: str) -> str:
        """Transcribe the audio file at *file_path* and return its text.

        Returns:
            The ``'text'`` entry of the transcriber result, or an
            ``"Audio transcription failed: ..."`` message on any error.
        """
        try:
            result = self.transcriber.transcribe(file_path)
            return result['text']
        except Exception as e:
            return f"Audio transcription failed: {e}"

    def handle_excel(self, file_path: str) -> str:
        """Sum the ``Sales`` column over every row that is not a drink.

        The sheet is expected to provide ``Category`` and ``Sales`` columns;
        rows whose lower-cased category equals ``'drink'`` are excluded.

        Returns:
            The total formatted with two decimals, or an
            ``"Excel parsing failed: ..."`` message on any error.
        """
        try:
            df = pd.read_excel(file_path)
            not_drink = df['Category'].str.lower() != 'drink'
            food_sales = df.loc[not_drink, 'Sales'].sum()
            return f"{food_sales:.2f}"
        except Exception as e:
            return f"Excel parsing failed: {e}"

    def __call__(self, question: str, files: dict = None) -> tuple[str, str]:
        """Answer *question*, using web search or attached files when relevant.

        Args:
            question: The question text.
            files: Optional mapping whose keys are checked for a ``.mp3`` or
                ``.xlsx`` suffix and whose values are handed to the matching
                handler (presumably file name -> local path; confirm with
                the caller).

        Returns:
            ``(final_answer, raw_output)``. For a spreadsheet attachment both
            items are the computed total. On any unexpected error the result
            is ``("ERROR", "Agent failed: ...")``.
        """
        try:
            if "http" in question or "Wikipedia" in question:
                # Questions that reference the web take priority over files.
                web_context = self.search(question)
                prompt = f"{SYSTEM_PROMPT}\n\n{web_context}\n\nQuestion: {question}"
            else:
                transcript = None
                for name, path in (files or {}).items():
                    # Compare extensions case-insensitively so that e.g.
                    # "CLIP.MP3" or "Sales.XLSX" are not silently ignored.
                    lowered = name.lower()
                    if lowered.endswith(".mp3"):
                        transcript = self.transcribe_audio(path)
                        break
                    if lowered.endswith(".xlsx"):
                        # Spreadsheet questions are answered without the model.
                        excel_result = self.handle_excel(path)
                        return excel_result, excel_result

                if transcript is None:
                    prompt = f"{SYSTEM_PROMPT}\n\n{question}"
                else:
                    prompt = f"{SYSTEM_PROMPT}\n\n{transcript}\n\n{question}"

            # NOTE(review): truncation cuts the END of the prompt, which is
            # where the question sits; a long transcript could push it out.
            inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True).to(self.device)
            # Sampling is disabled, so no temperature is passed: it is a
            # sampling-only flag and 0.0 is not a valid value for it.
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=128,
                do_sample=False,
                pad_token_id=self.tokenizer.pad_token_id,
            )
            output_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            final = output_text.strip()
            return final, output_text
        except Exception as e:
            # Top-level boundary: callers always get a tuple, never a raise.
            return "ERROR", f"Agent failed: {e}"