import os
import json
import re
from typing import Tuple, Dict, Any
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM

from tools.asr_tool import transcribe_audio
from tools.excel_tool import analyze_excel
from tools.search_tool import search_duckduckgo
from tools.math_tool import calculate_math

class GaiaAgent:
    def __init__(self):
        token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
        if not token:
            raise ValueError("Missing HUGGINGFACEHUB_API_TOKEN environment variable.")

        # Specify the model and load tokenizer and model separately for better control
        model_name = "google/flan-t5-large"
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, token=token)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name, token=token)

        # Use the pipeline with the loaded model and tokenizer
        self.llm = pipeline(
            "text2text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
            device="cpu", # Consider "cuda" if you have a GPU
            max_new_tokens=256,
            do_sample=False,  # Set to True to enable sampling (temperature, top_p, top_k)
        )

        self.system_prompt = """You are a general AI assistant. I will ask you a question. Report your thoughts, and finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER]. YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string."""

    def extract_final_answer(self, text: str) -> str:
        """Extrahera det slutliga svaret från modellens output"""
        final_answer_match = re.search(r'FINAL ANSWER:\s*(.+?)(?:\n|$)', text, re.IGNORECASE)
        if final_answer_match:
            return final_answer_match.group(1).strip()
        lines = text.strip().split('\n')
        return lines[-1].strip() if lines else text.strip()

    def needs_tool(self, question: str) -> Tuple[str, bool]:
        """Bestäm vilket verktyg som behövs baserat på frågan"""
        question_lower = question.lower()

        if any(ext in question_lower for ext in ['.mp3', '.wav', '.m4a', '.flac']):
            return 'audio', True
        if any(ext in question_lower for ext in ['.xlsx', '.xls', '.csv']):
            return 'excel', True
        if any(keyword in question_lower for keyword in ['search', 'find', 'lookup', 'http', 'www.', 'wikipedia', 'albums', 'discography', 'published', 'website']):
            return 'search', True
        if any(keyword in question_lower for keyword in ['calculate', 'compute', 'sum', 'average', 'count', 'what is', 'solve']):
            return 'math', True
        return 'llm', False
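        # Routing examples (illustrative, not from the source):
        #   needs_tool("What is 2 + 2?")   -> ('math', True)   # "what is" keyword
        #   needs_tool("Open report.xlsx") -> ('excel', True)  # spreadsheet extension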

    def process_with_tools(self, question: str, tool_type: str) -> Tuple[str, str]:
        """Bearbeta frågan med specifika verktyg"""
        trace_log = f"Detected {tool_type} task. Processing...\n"

        try:
            if tool_type == 'audio':
                # Non-capturing group so findall returns the full filename,
                # not just the extension.
                audio_files = re.findall(r'\b[\w\-]+\.(?:mp3|wav|m4a|flac)\b', question, re.IGNORECASE)
                if audio_files:
                    result = transcribe_audio(audio_files[0])
                    trace_log += f"Audio transcription: {result}\n"
                    return result, trace_log
                else:
                    return "No audio file mentioned in the question.", trace_log

            elif tool_type == 'excel':
                # Non-capturing group, as above, so the full filename is returned.
                excel_files = re.findall(r'\b[\w\-]+\.(?:xlsx|xls|csv)\b', question, re.IGNORECASE)
                if excel_files:
                    result = analyze_excel(excel_files[0])
                    trace_log += f"Excel analysis: {result}\n"
                    return result, trace_log
                else:
                    return "No Excel file mentioned in the question.", trace_log

            elif tool_type == 'search':
                search_query = question # This might need refinement to extract just the search query
                result = search_duckduckgo(search_query)
                trace_log += f"Search results: {result}\n"
                return result, trace_log

            elif tool_type == 'math':
                math_expression_match = re.search(r'calculate (.+)', question, re.IGNORECASE)
                if math_expression_match:
                    expression = math_expression_match.group(1).strip()
                    result = calculate_math(expression)
                    trace_log += f"Math calculation: {result}\n"
                    return result, trace_log
                else:
                    return "No clear mathematical expression found in the question.", trace_log

        except Exception as e:
            trace_log += f"Error using {tool_type} tool: {str(e)}\n"
            return f"Error: {str(e)}", trace_log

        return "No valid input found for tool", trace_log

    def reason_with_llm(self, question: str, context: str = "") -> Tuple[str, str]:
        """Använd LLM för reasoning med kontext"""
        trace_log = "Using LLM for reasoning...\n"

        # Combine system prompt, context, and question, ensuring it fits token limit
        if context:
            prompt = f"{self.system_prompt}\n\nContext: {context}\n\nQuestion: {question}\n\nPlease analyze this step by step and provide your final answer."
        else:
            prompt = f"{self.system_prompt}\n\nQuestion: {question}\n\nPlease analyze this step by step and provide your final answer."

        # Tokenize and truncate if necessary
        inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=self.tokenizer.model_max_length)

        try:
            # Generate response using the model's generate method for more control
            # You can add generation arguments here, e.g., temperature, top_k, etc.
            outputs = self.model.generate(
                inputs.input_ids,
                max_new_tokens=256,
                do_sample=False, # Set to True to enable temperature and other sampling parameters
                # temperature=0.1, # Example: Only if do_sample is True
            )
            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

            trace_log += f"LLM response: {response}\n"
            return response, trace_log
        except Exception as e:
            trace_log += f"Error with LLM: {str(e)}\n"
            return f"Error: {str(e)}", trace_log

    def __call__(self, question: str) -> Tuple[str, str]:
        """Huvudfunktion som bearbetar frågan"""
        total_trace = f"Processing question: {question}\n"

        tool_type, needs_tool = self.needs_tool(question)
        total_trace += f"Tool needed: {tool_type}\n"

        context = ""
        if needs_tool and tool_type != 'llm':
            tool_result, tool_trace = self.process_with_tools(question, tool_type)
            total_trace += tool_trace
            context = tool_result

        llm_response, llm_trace = self.reason_with_llm(question, context)
        total_trace += llm_trace

        final_answer = self.extract_final_answer(llm_response)
        total_trace += f"Final answer extracted: {final_answer}\n"

        return final_answer, total_trace
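

# Minimal usage sketch (assumptions: the tools/ package is importable and
# HUGGINGFACEHUB_API_TOKEN is set in the environment; the question below is
# illustrative, not from the source).
if __name__ == "__main__":
    agent = GaiaAgent()
    answer, trace = agent("Please calculate 2 + 2")
    print("Answer:", answer)
    print("Trace:\n" + trace)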