File size: 3,655 Bytes
9666d9f
e258602
 
 
 
 
 
 
7c3b683
1d782b2
 
 
230477c
2477b72
230477c
 
 
2b8dbad
9666d9f
2b8dbad
9666d9f
 
e258602
 
 
 
 
 
 
 
 
 
 
 
 
1d782b2
 
 
 
7c3b683
1d782b2
 
e258602
2477b72
1d782b2
58c4724
1d782b2
230477c
e258602
1d782b2
2477b72
 
1d782b2
 
 
 
2477b72
 
 
 
 
 
 
1d782b2
 
 
2477b72
 
 
e258602
2477b72
1d782b2
2477b72
 
 
 
 
9bf47dc
2477b72
2b8dbad
9666d9f
 
2b8dbad
 
 
9666d9f
 
2b8dbad
2477b72
230477c
 
b5d03d2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    pipeline,
    AutoProcessor,
    AutoModelForSpeechSeq2Seq
)
from duckduckgo_search import DDGS
import pandas as pd
import os

# Instruction prepended to every prompt sent to the seq2seq model in
# GaiaAgent.__call__: asks for step-by-step reasoning but requires the
# model to emit ONLY the final answer string, with no explanations.
SYSTEM_PROMPT = """
You are a helpful AI assistant. Think step by step to solve the problem. If the question requires reasoning, perform it. If it refers to a search or file, use the result provided. At the end, return ONLY the final answer string. No explanations.
"""

class GaiaAgent:
    """Question-answering agent combining a seq2seq LM, Whisper ASR,
    DuckDuckGo search, and Excel handling.

    Given a question (and optionally attached files), the agent gathers
    context — an audio transcript, a spreadsheet computation, or a web
    search snippet — then prompts a text model for the final answer.
    """

    def __init__(self, model_id="google/flan-t5-base"):
        """Load the answer model and the Whisper ASR pipeline.

        Args:
            model_id: Hugging Face id of a seq2seq LM used for answering.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

        # Whisper via HF: used by transcribe_audio() for .mp3/.wav attachments.
        self.asr_model_id = "openai/whisper-small"
        self.asr_processor = AutoProcessor.from_pretrained(self.asr_model_id)
        self.asr_model = AutoModelForSpeechSeq2Seq.from_pretrained(self.asr_model_id).to(self.device)
        self.pipe = pipeline(
            "automatic-speech-recognition",
            model=self.asr_model,
            tokenizer=self.asr_processor.tokenizer,
            feature_extractor=self.asr_processor.feature_extractor,
            return_timestamps=False,
            # transformers pipelines take a device index: 0 = first GPU, -1 = CPU.
            device=0 if torch.cuda.is_available() else -1
        )

    def search(self, query: str) -> str:
        """Return the body text of the top DuckDuckGo result for *query*.

        Best-effort: returns "" on network failure or when there are no hits.
        """
        try:
            with DDGS() as ddgs:
                # BUG FIX: the original passed the literal string "your query"
                # instead of the caller's query, so every search returned
                # results for that placeholder text.
                results = ddgs.text(query, max_results=1)
                if results:
                    return results[0]['body']
        except Exception:
            # Search is optional context; degrade silently to no context.
            return ""
        return ""

    def transcribe_audio(self, file_path: str) -> str:
        """Transcribe an audio file with Whisper; return "" on any failure."""
        try:
            result = self.pipe(file_path)
            return result['text']
        except Exception:
            return ""

    def handle_excel(self, file_path: str) -> str:
        """Sum the 'sales' column over non-'drink' rows of a spreadsheet.

        Column names are lower-cased before matching. Returns the total
        formatted to two decimals, or "" if the expected columns are
        missing or the file cannot be read. NOTE(review): this is
        hard-coded to one task shape ('category'/'sales' columns) —
        presumably a specific benchmark question; confirm before reuse.
        """
        try:
            df = pd.read_excel(file_path)
            df.columns = [col.lower() for col in df.columns]
            if 'category' in df.columns and 'sales' in df.columns:
                food_sales = df[df['category'].str.lower() != 'drink']['sales'].sum()
                return f"{food_sales:.2f}"
        except Exception:
            return ""
        return ""

    def __call__(self, question: str, files: dict = None) -> tuple[str, str]:
        """Answer *question*, optionally using attached files.

        Args:
            question: the task text.
            files: optional mapping of filename -> local path.

        Returns:
            A (final_answer, raw_output) pair; both are the same string on
            success, and ("ERROR", failure description) on any exception.
        """
        try:
            context = ""
            if files:
                for filename, filepath in files.items():
                    # Audio attachments become transcript context; the first
                    # one found wins.
                    if filename.endswith((".mp3", ".wav")):
                        context = self.transcribe_audio(filepath)
                        break
                    # Spreadsheets short-circuit: the computed number IS the answer.
                    elif filename.endswith(".xlsx"):
                        excel_result = self.handle_excel(filepath)
                        return excel_result.strip(), excel_result.strip()
            elif "http" in question.lower() or "wikipedia" in question.lower():
                # No files: fall back to a web search when the question
                # references a URL or Wikipedia.
                context = self.search(question)

            prompt = f"{SYSTEM_PROMPT}\n\n{context}\n\nQuestion: {question.strip()}"
            inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True).to(self.device)
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=128,
                do_sample=False,  # deterministic (greedy) decoding
                pad_token_id=self.tokenizer.pad_token_id
            )
            output_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            final = output_text.strip()
            return final, final
        except Exception as e:
            # Top-level boundary: never raise to the caller; report the failure.
            return "ERROR", f"Agent failed: {e}"