File size: 6,288 Bytes
13755f8
21e96d1
 
c3c803a
21e96d1
 
bf58062
21e96d1
 
 
15b9880
21e96d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8bfc54d
21e96d1
 
 
15b9880
21e96d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c3ff8d8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import os
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
import torch # För att kontrollera enheter

# Importera ditt nya sökverktyg
from tools.tavily_search import search_tavily

class GaiaAgent:
    def __init__(self, model_id: str = "google/gemma-2b-it"):
        # Ladda tokenizer och modell manuellt. Detta ger mer kontroll.
        try:
            print(f"Laddar tokenizer för {model_id}...")
            self.tokenizer = AutoTokenizer.from_pretrained(model_id, token=os.getenv("HF_TOKEN"))
            print(f"Laddar modell för {model_id}...")
            
            # Kontrollera om GPU är tillgänglig
            device = "cuda" if torch.cuda.is_available() else "cpu"
            print(f"Använder enhet: {device}")

            self.model = AutoModelForCausalLM.from_pretrained(
                model_id,
                torch_dtype=torch.bfloat16, # Använd bfloat16 för minskat minne
                device_map="auto", # Accelerate hanterar detta över CPU/GPU
                token=os.getenv("HF_TOKEN")
            )
            print("Modell laddad framgångsrikt.")

            # Skapa en pipeline för textgenerering
            self.text_generator = pipeline(
                "text-generation",
                model=self.model,
                tokenizer=self.tokenizer,
                # device=0 if device == "cuda" else -1 # 0 för första GPU, -1 för CPU
            )
            print("Textgenereringspipeline skapad.")

        except Exception as e:
            print(f"Fel vid initiering av agent: {e}")
            raise RuntimeError(f"Fel vid laddning av modell eller tokenizer: {e}")

    def process_task(self, task_description: str) -> str:
        # Enkel instruktion till LLM för att utföra uppgiften
        # Vi måste bygga en prompt som instruerar modellen att använda verktyg.
        
        # Instruktioner till modellen för att svara och använda verktyg
        prompt = f"""
        Du är en expertagent med tillgång till ett sökverktyg.
        Använd alltid sökverktyget om du behöver information som inte finns i din träningsdata eller om du behöver validera fakta.
        Försök alltid att svara på uppgiften heltäckande.

        Dina tillgängliga verktyg:
        1. search_tavily(query: str): Söker på Tavily och returnerar relevanta resultat.

        För att använda ett verktyg, skriv det på följande format:
        <TOOL_CODE>
        verktygsnamn("fråga till verktyget")
        </TOOL_CODE>

        Exempel:
        För att söka efter information om Mars:
        <TOOL_CODE>
        search_tavily("information om Mars")
        </TOOL_CODE>

        När du har hittat all nödvändig information och är redo att svara, skriv ditt slutgiltiga svar.
        
        Uppgift: {task_description}
        """

        max_iterations = 3 # Begränsa iterationer för att undvika oändliga loopar
        current_response = ""

        for i in range(max_iterations):
            # Skapa prompten för den aktuella iterationen
            full_prompt = prompt + current_response + "\n\nVad är nästa steg eller ditt slutgiltiga svar?"
            
            print(f"[{i+1}/{max_iterations}] Genererar svar med promptlängd: {len(full_prompt)}")
            
            # Generera svar från modellen
            # max_new_tokens är viktig för att styra svarets längd
            generated_text = self.text_generator(
                full_prompt, 
                max_new_tokens=500, # Justera vid behov
                num_return_sequences=1,
                pad_token_id=self.tokenizer.eos_token_id, # Viktigt för T5/Gemma
                do_sample=True, # Aktivera sampling för mer variation
                top_k=50, top_p=0.95, # Typiska samplingparametrar
                temperature=0.7 # Kontrollera kreativitet
            )[0]['generated_text']

            # Extrahera endast den nya delen av texten (modellen genererar hela prompten + nytt svar)
            new_content = generated_text[len(full_prompt):].strip()
            print(f"Modellgenerering: {new_content}")

            # Kontrollera om modellen vill använda ett verktyg
            if "<TOOL_CODE>" in new_content and "</TOOL_CODE>" in new_content:
                start_index = new_content.find("<TOOL_CODE>") + len("<TOOL_CODE>")
                end_index = new_content.find("</TOOL_CODE>")
                tool_call_str = new_content[start_index:end_index].strip()
                
                print(f"Verktygskall upptäckt: {tool_call_str}")
                
                try:
                    # EVAL är farligt i verkliga applikationer, men för GAIA och detta specifika verktyg är det OK.
                    # Säkerställ att endast godkända funktioner kan kallas.
                    if tool_call_str.startswith("search_tavily("):
                        # Extrahera argumenten till funktionen
                        # En mer robust parser skulle behövas för mer komplexa verktyg
                        query = tool_call_str[len("search_tavily("):-1].strip().strip('"').strip("'")
                        tool_output = search_tavily(query)
                        print(f"Verktygsresultat: {tool_output[:200]}...") # Printa kortfattat
                        current_response += f"\n\nVerktygsresultat från {tool_call_str}:\n{tool_output}\n"
                    else:
                        tool_output = f"Okänt verktyg: {tool_call_str}"
                        print(f"Fel: {tool_output}")
                        current_response += f"\n\n{tool_output}\n"
                except Exception as tool_e:
                    tool_output = f"Fel vid körning av verktyg {tool_call_str}: {tool_e}"
                    print(f"Fel: {tool_output}")
                    current_response += f"\n\n{tool_output}\n"
            else:
                # Modellen har genererat ett svar utan att kalla verktyg
                final_answer = new_content
                print(f"Slutgiltigt svar från modellen:\n{final_answer}")
                return final_answer.strip()

        # Om max_iterations nås utan slutgiltigt svar
        return "Agenten kunde inte slutföra uppgiften inom tillåtet antal iterationer. Senaste svar: " + new_content.strip()