File size: 3,101 Bytes
9666d9f
1d782b2
 
 
 
 
 
 
 
230477c
2477b72
230477c
 
 
2b8dbad
9666d9f
2b8dbad
9666d9f
 
1d782b2
 
 
 
 
 
 
 
 
2477b72
1d782b2
58c4724
1d782b2
230477c
1d782b2
 
2477b72
 
1d782b2
 
 
 
2477b72
 
 
 
 
 
 
1d782b2
 
 
2477b72
 
 
 
 
1d782b2
2477b72
 
 
 
 
9bf47dc
2477b72
2b8dbad
9666d9f
 
2b8dbad
 
1d782b2
2b8dbad
9666d9f
 
2b8dbad
2477b72
230477c
 
2b8dbad
b5d03d2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import logging
import os
import re
import tempfile

import pandas as pd
import torch
import whisper
from ddgs import DDGS
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Instruction preamble placed at the start of every prompt that
# GaiaAgent.__call__ sends to the model. The leading and trailing newlines are
# part of the literal and therefore part of the prompt text.
SYSTEM_PROMPT = """
You are a helpful AI assistant. Think step by step to solve the problem. If the question requires reasoning, perform it. If it refers to a search or file, use the result provided. At the end, return ONLY the final answer string. No explanations.
"""

class GaiaAgent:
    """Answer questions with a seq2seq model, optionally using attached files or web search.

    The agent is callable: ``agent(question, files)`` returns a
    ``(answer, trace)`` tuple. Helper methods are best-effort: they log a
    warning and return an empty string instead of raising.
    """

    # Shared logger so helper failures are visible instead of silently dropped.
    _log = logging.getLogger(__name__)

    def __init__(self, model_id="google/flan-t5-base"):
        """Load the tokenizer, the seq2seq model and the Whisper ``base`` model.

        Args:
            model_id: Identifier passed to ``from_pretrained`` for both the
                tokenizer and the seq2seq model.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
        # Run on the GPU when one is available, otherwise stay on the CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.transcriber = whisper.load_model("base")

    def search(self, query: str) -> str:
        """Return the ``body`` text of the first web search hit for *query*.

        Returns:
            The snippet text, or ``""`` when there are no results, the hit has
            no body, or the search fails for any reason.
        """
        try:
            with DDGS() as ddgs:
                results = list(ddgs.text(query, safesearch="off"))
            if not results:
                return ""
            # A missing or None body must still yield a string.
            return results[0].get("body") or ""
        except Exception as exc:
            self._log.warning("Web search failed for %r: %s", query, exc)
            return ""

    def transcribe_audio(self, file_path: str) -> str:
        """Transcribe the audio file at *file_path* with the loaded Whisper model.

        Returns:
            The transcript text, or ``""`` when transcription fails.
        """
        try:
            result = self.transcriber.transcribe(file_path)
            return result['text']
        except Exception as exc:
            self._log.warning("Transcription failed for %r: %s", file_path, exc)
            return ""

    def handle_excel(self, file_path: str) -> str:
        """Sum the ``sales`` column over rows whose ``category`` is not ``drink``.

        Column names and category values are compared case-insensitively.

        Returns:
            The total formatted with two decimals (e.g. ``"12.50"``), or ``""``
            when the file cannot be read, the required columns are missing, or
            the total cannot be computed.
        """
        try:
            df = pd.read_excel(file_path)
            # Headers may be non-strings (e.g. a numeric year column), so
            # convert before lower-casing instead of calling .lower() directly.
            df.columns = [str(col).lower() for col in df.columns]
            if 'category' not in df.columns or 'sales' not in df.columns:
                return ""
            not_drink = df['category'].astype(str).str.lower() != 'drink'
            food_sales = df.loc[not_drink, 'sales'].sum()
            return f"{food_sales:.2f}"
        except Exception as exc:
            self._log.warning("Excel handling failed for %r: %s", file_path, exc)
            return ""

    def __call__(self, question: str, files: dict = None) -> tuple[str, str]:
        """Answer *question*, using attached files or a web search for context.

        Args:
            question: The question text.
            files: Optional mapping of file name to file path. The first
                ``.mp3`` entry is transcribed and used as context; the first
                ``.xlsx`` entry short-circuits and its computed total is
                returned directly. Extensions are matched case-insensitively.

        Returns:
            ``(answer, trace)``. On success both elements are the same string.
            On any failure the result is ``("ERROR", "Agent failed: <reason>")``.
        """
        try:
            context = ""
            for filename, filepath in (files or {}).items():
                lowered_name = str(filename).lower()
                if lowered_name.endswith(".mp3"):
                    context = self.transcribe_audio(filepath)
                    break
                if lowered_name.endswith(".xlsx"):
                    excel_result = self.handle_excel(filepath).strip()
                    return excel_result, excel_result

            # Search whenever no context was obtained, so an unsupported
            # attachment does not disable the lookup for URL/Wikipedia questions.
            lowered_question = question.lower()
            if not context and ("http" in lowered_question or "wikipedia" in lowered_question):
                context = self.search(question)

            prompt = f"{SYSTEM_PROMPT}\n\n{context}\n\nQuestion: {question.strip()}"
            # NOTE(review): truncation cuts the end of the prompt, which is
            # where the question sits; a very long context can push it out.
            inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True).to(self.device)
            # Greedy decoding: no temperature is passed because it only
            # applies when sampling is enabled.
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=128,
                do_sample=False,
                pad_token_id=self.tokenizer.pad_token_id,
            )
            output_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            final = output_text.strip()
            return final, final
        except Exception as e:
            self._log.warning("Agent failed: %s", e)
            return "ERROR", f"Agent failed: {e}"