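"""GaiaAgent: a small tool-using agent around a local Hugging Face model.

The agent loads a causal LM (google/gemma-2b-it by default) into a
text-generation pipeline, then loops for a few iterations in which the
model may either emit a <TOOL_CODE>...</TOOL_CODE> block (currently only
search_tavily) or give its final answer.
"""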
import os
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
import torch  # used to check which device is available

# Import the Tavily search tool
from tools.tavily_search import search_tavily
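# Assumed interface (defined in tools/tavily_search.py, not in this file):
# search_tavily(query: str) -> str, returning search results as plain text.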

class GaiaAgent:
    def __init__(self, model_id: str = "google/gemma-2b-it"):
        # Load the tokenizer and model manually. This gives more control.
        try:
            print(f"Laddar tokenizer för {model_id}...")
            self.tokenizer = AutoTokenizer.from_pretrained(model_id, token=os.getenv("HF_TOKEN"))
            print(f"Laddar modell för {model_id}...")
            
            # Kontrollera om GPU är tillgänglig
            device = "cuda" if torch.cuda.is_available() else "cpu"
            print(f"Använder enhet: {device}")

            self.model = AutoModelForCausalLM.from_pretrained(
                model_id,
                torch_dtype=torch.bfloat16,  # bfloat16 to reduce memory use
                device_map="auto",  # let Accelerate place layers across GPU/CPU
                token=os.getenv("HF_TOKEN")
            )
            print("Model loaded successfully.")

            # Create a text-generation pipeline
            self.text_generator = pipeline(
                "text-generation",
                model=self.model,
                tokenizer=self.tokenizer,
                # device is deliberately omitted: device_map="auto" has already placed the model
            )
            print("Textgenereringspipeline skapad.")

        except Exception as e:
            print(f"Fel vid initiering av agent: {e}")
            raise RuntimeError(f"Fel vid laddning av modell eller tokenizer: {e}")

    def __call__(self, question: str) -> str:
        """
        Make a GaiaAgent instance callable like a function.
        Delegates to process_task to generate the answer.
        """
        print(f"Agent received question (first 50 chars): {question[:50]}...")
        result = self.process_task(question)
        print(f"Agent returning answer: {result[:100]}...") # För att inte fylla loggarna med för långa svar
        return result
    
    def process_task(self, task_description: str) -> str:
        """Run the prompt/tool loop and return the model's final answer."""
        # Build a prompt that tells the model how and when to call the tool.
        prompt = f"""
        You are a helpful and expert AI assistant with access to a search tool.
        Your task is to carefully and accurately answer questions by using the search tool when necessary.
        Always provide a complete and correct answer based on the information you find.

        Your available tools:
        1. search_tavily(query: str): Searches on Tavily and returns relevant results.
           Use this tool to find information on the internet that you don't know or need to verify.

        To use a tool, write it in the following exact format:
        <TOOL_CODE>
        tool_name("your search query")
        </TOOL_CODE>
        Example:
        If you need to know the capital of France:
        <TOOL_CODE>
        search_tavily("capital of France")
        </TOOL_CODE>

        When you have found all the necessary information and are ready to answer the task, provide your final answer.
        
        Task: {task_description}
        """

        max_iterations = 3
        current_response = ""

        for i in range(max_iterations):
            full_prompt = prompt + current_response + "\n\nWhat is the next step or your final answer?"
            
            print(f"[{i+1}/{max_iterations}] Generating response with prompt length: {len(full_prompt)}")
            
            generated_text = self.text_generator(
                full_prompt,
                max_new_tokens=1024,  # keep 1024, or raise it if answers get truncated
                num_return_sequences=1,
                pad_token_id=self.tokenizer.eos_token_id,
                do_sample=True,
                top_k=50, top_p=0.95,
                temperature=0.8,  # keep 0.8, or tune as needed
                return_full_text=False  # return only the new text, not the echoed prompt
            )[0]['generated_text']

            new_content = generated_text.strip()
            print(f"DEBUG - Extracted new_content: '{new_content}'")

            if "<TOOL_CODE>" in new_content and "</TOOL_CODE>" in new_content:
                start_index = new_content.find("<TOOL_CODE>") + len("<TOOL_CODE>")
                end_index = new_content.find("</TOOL_CODE>")
                tool_call_str = new_content[start_index:end_index].strip()
                
                print(f"Tool call detected: {tool_call_str}")
                
                try:
                    if tool_call_str.startswith("search_tavily("):
                        query = tool_call_str[len("search_tavily("):-1].strip().strip('"').strip("'")
                        tool_output = search_tavily(query)
                        print(f"Tool result: {tool_output[:200]}...")
                        current_response += f"\n\nTool Result from {tool_call_str}:\n{tool_output}\n"
                    else:
                        tool_output = f"Unknown tool: {tool_call_str}"
                        print(f"Error: {tool_output}")
                        current_response += f"\n\n{tool_output}\n"
                except Exception as tool_e:
                    tool_output = f"Error running tool {tool_call_str}: {tool_e}"
                    print(f"Error: {tool_output}")
                    current_response += f"\n\n{tool_output}\n"
            else:
                final_answer = new_content
                print(f"Final answer from model:\n{final_answer}")
                return final_answer.strip()

        return "Agent could not complete the task within the allowed iterations. Latest response: " + new_content.strip()