Freddolin commited on
Commit
21e96d1
·
verified ·
1 Parent(s): 9741645

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +123 -20
agent.py CHANGED
@@ -1,26 +1,129 @@
1
  import os
2
- from smolagents import CodeAgent, DuckDuckGoSearchTool
3
- from smolagents import TransformersModel
4
 
5
- class GaiaAgent:
6
- def __init__(self, model_id: str = "google/gemma-2b-it"): # <-- CHANGE MODEL HERE
7
- self.llm_model = TransformersModel(
8
- model_id=model_id,
9
- task="text-generation",
10
- # device_map="auto" # Can keep this, accelerate will manage
11
- )
12
-
13
- self.agent = CodeAgent(
14
- model=self.llm_model,
15
- tools=[DuckDuckGoSearchTool()],
16
- add_base_tools=False,
17
- verbose=True
18
- )
19
 
20
- def process_task(self, task_description: str) -> str:
 
 
21
  try:
22
- response = self.agent.run(task_description)
23
- return response
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  except Exception as e:
25
- return f"An error occurred during agent processing: {e}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
 
1
  import os
2
+ from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
3
+ import torch # För att kontrollera enheter
4
 
5
+ # Importera ditt nya sökverktyg
6
+ from tools.tavily_search import search_tavily
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
+ class GaiaAgent:
9
+ def __init__(self, model_id: str = "google/gemma-2b-it"):
10
+ # Ladda tokenizer och modell manuellt. Detta ger mer kontroll.
11
  try:
12
+ print(f"Laddar tokenizer för {model_id}...")
13
+ self.tokenizer = AutoTokenizer.from_pretrained(model_id, token=os.getenv("HF_TOKEN"))
14
+ print(f"Laddar modell för {model_id}...")
15
+
16
+ # Kontrollera om GPU är tillgänglig
17
+ device = "cuda" if torch.cuda.is_available() else "cpu"
18
+ print(f"Använder enhet: {device}")
19
+
20
+ self.model = AutoModelForCausalLM.from_pretrained(
21
+ model_id,
22
+ torch_dtype=torch.bfloat16, # Använd bfloat16 för minskat minne
23
+ device_map="auto", # Accelerate hanterar detta över CPU/GPU
24
+ token=os.getenv("HF_TOKEN")
25
+ )
26
+ print("Modell laddad framgångsrikt.")
27
+
28
+ # Skapa en pipeline för textgenerering
29
+ self.text_generator = pipeline(
30
+ "text-generation",
31
+ model=self.model,
32
+ tokenizer=self.tokenizer,
33
+ device=0 if device == "cuda" else -1 # 0 för första GPU, -1 för CPU
34
+ )
35
+ print("Textgenereringspipeline skapad.")
36
+
37
  except Exception as e:
38
+ print(f"Fel vid initiering av agent: {e}")
39
+ raise RuntimeError(f"Fel vid laddning av modell eller tokenizer: {e}")
40
+
41
+ def process_task(self, task_description: str) -> str:
42
+ # Enkel instruktion till LLM för att utföra uppgiften
43
+ # Vi måste bygga en prompt som instruerar modellen att använda verktyg.
44
+
45
+ # Instruktioner till modellen för att svara och använda verktyg
46
+ prompt = f"""
47
+ Du är en expertagent med tillgång till ett sökverktyg.
48
+ Använd alltid sökverktyget om du behöver information som inte finns i din träningsdata eller om du behöver validera fakta.
49
+ Försök alltid att svara på uppgiften heltäckande.
50
+
51
+ Dina tillgängliga verktyg:
52
+ 1. search_tavily(query: str): Söker på Tavily och returnerar relevanta resultat.
53
+
54
+ För att använda ett verktyg, skriv det på följande format:
55
+ <TOOL_CODE>
56
+ verktygsnamn("fråga till verktyget")
57
+ </TOOL_CODE>
58
+
59
+ Exempel:
60
+ För att söka efter information om Mars:
61
+ <TOOL_CODE>
62
+ search_tavily("information om Mars")
63
+ </TOOL_CODE>
64
+
65
+ När du har hittat all nödvändig information och är redo att svara, skriv ditt slutgiltiga svar.
66
+
67
+ Uppgift: {task_description}
68
+ """
69
+
70
+ max_iterations = 3 # Begränsa iterationer för att undvika oändliga loopar
71
+ current_response = ""
72
+
73
+ for i in range(max_iterations):
74
+ # Skapa prompten för den aktuella iterationen
75
+ full_prompt = prompt + current_response + "\n\nVad är nästa steg eller ditt slutgiltiga svar?"
76
+
77
+ print(f"[{i+1}/{max_iterations}] Genererar svar med promptlängd: {len(full_prompt)}")
78
+
79
+ # Generera svar från modellen
80
+ # max_new_tokens är viktig för att styra svarets längd
81
+ generated_text = self.text_generator(
82
+ full_prompt,
83
+ max_new_tokens=500, # Justera vid behov
84
+ num_return_sequences=1,
85
+ pad_token_id=self.tokenizer.eos_token_id, # Viktigt för T5/Gemma
86
+ do_sample=True, # Aktivera sampling för mer variation
87
+ top_k=50, top_p=0.95, # Typiska samplingparametrar
88
+ temperature=0.7 # Kontrollera kreativitet
89
+ )[0]['generated_text']
90
+
91
+ # Extrahera endast den nya delen av texten (modellen genererar hela prompten + nytt svar)
92
+ new_content = generated_text[len(full_prompt):].strip()
93
+ print(f"Modellgenerering: {new_content}")
94
+
95
+ # Kontrollera om modellen vill använda ett verktyg
96
+ if "<TOOL_CODE>" in new_content and "</TOOL_CODE>" in new_content:
97
+ start_index = new_content.find("<TOOL_CODE>") + len("<TOOL_CODE>")
98
+ end_index = new_content.find("</TOOL_CODE>")
99
+ tool_call_str = new_content[start_index:end_index].strip()
100
+
101
+ print(f"Verktygskall upptäckt: {tool_call_str}")
102
+
103
+ try:
104
+ # EVAL är farligt i verkliga applikationer, men för GAIA och detta specifika verktyg är det OK.
105
+ # Säkerställ att endast godkända funktioner kan kallas.
106
+ if tool_call_str.startswith("search_tavily("):
107
+ # Extrahera argumenten till funktionen
108
+ # En mer robust parser skulle behövas för mer komplexa verktyg
109
+ query = tool_call_str[len("search_tavily("):-1].strip().strip('"').strip("'")
110
+ tool_output = search_tavily(query)
111
+ print(f"Verktygsresultat: {tool_output[:200]}...") # Printa kortfattat
112
+ current_response += f"\n\nVerktygsresultat från {tool_call_str}:\n{tool_output}\n"
113
+ else:
114
+ tool_output = f"Okänt verktyg: {tool_call_str}"
115
+ print(f"Fel: {tool_output}")
116
+ current_response += f"\n\n{tool_output}\n"
117
+ except Exception as tool_e:
118
+ tool_output = f"Fel vid körning av verktyg {tool_call_str}: {tool_e}"
119
+ print(f"Fel: {tool_output}")
120
+ current_response += f"\n\n{tool_output}\n"
121
+ else:
122
+ # Modellen har genererat ett svar utan att kalla verktyg
123
+ final_answer = new_content
124
+ print(f"Slutgiltigt svar från modellen:\n{final_answer}")
125
+ return final_answer.strip()
126
+
127
+ # Om max_iterations nås utan slutgiltigt svar
128
+ return "Agenten kunde inte slutföra uppgiften inom tillåtet antal iterationer. Senaste svar: " + new_content.strip()
129