Spaces:
Sleeping
Sleeping
File size: 6,867 Bytes
13755f8 21e96d1 c3c803a 21e96d1 bf58062 21e96d1 15b9880 21e96d1 8bfc54d 21e96d1 15b9880 21e96d1 b6da6a3 21e96d1 c3ff8d8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
import os
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
import torch # För att kontrollera enheter
# Importera ditt nya sökverktyg
from tools.tavily_search import search_tavily
class GaiaAgent:
def __init__(self, model_id: str = "google/gemma-2b-it"):
# Ladda tokenizer och modell manuellt. Detta ger mer kontroll.
try:
print(f"Laddar tokenizer för {model_id}...")
self.tokenizer = AutoTokenizer.from_pretrained(model_id, token=os.getenv("HF_TOKEN"))
print(f"Laddar modell för {model_id}...")
# Kontrollera om GPU är tillgänglig
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Använder enhet: {device}")
self.model = AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype=torch.bfloat16, # Använd bfloat16 för minskat minne
device_map="auto", # Accelerate hanterar detta över CPU/GPU
token=os.getenv("HF_TOKEN")
)
print("Modell laddad framgångsrikt.")
# Skapa en pipeline för textgenerering
self.text_generator = pipeline(
"text-generation",
model=self.model,
tokenizer=self.tokenizer,
# device=0 if device == "cuda" else -1 # 0 för första GPU, -1 för CPU
)
print("Textgenereringspipeline skapad.")
except Exception as e:
print(f"Fel vid initiering av agent: {e}")
raise RuntimeError(f"Fel vid laddning av modell eller tokenizer: {e}")
# --- THIS IS THE MISSING __CALL__ METHOD ---
def __call__(self, question: str) -> str:
"""
Denna metod gör att en instans av GaiaAgent kan kallas som en funktion.
Den kommer att anropa din process_task metod för att generera svaret.
"""
print(f"Agent received question (first 50 chars): {question[:50]}...")
result = self.process_task(question)
print(f"Agent returning answer: {result[:100]}...") # För att inte fylla loggarna med för långa svar
return result
# --- END OF MISSING METHOD ---
def process_task(self, task_description: str) -> str:
# Enkel instruktion till LLM för att utföra uppgiften
# Vi måste bygga en prompt som instruerar modellen att använda verktyg.
# Instruktioner till modellen för att svara och använda verktyg
prompt = f"""
Du är en expertagent med tillgång till ett sökverktyg.
Använd alltid sökverktyget om du behöver information som inte finns i din träningsdata eller om du behöver validera fakta.
Försök alltid att svara på uppgiften heltäckande.
Dina tillgängliga verktyg:
1. search_tavily(query: str): Söker på Tavily och returnerar relevanta resultat.
För att använda ett verktyg, skriv det på följande format:
<TOOL_CODE>
verktygsnamn("fråga till verktyget")
</TOOL_CODE>
Exempel:
För att söka efter information om Mars:
<TOOL_CODE>
search_tavily("information om Mars")
</TOOL_CODE>
När du har hittat all nödvändig information och är redo att svara, skriv ditt slutgiltiga svar.
Uppgift: {task_description}
"""
max_iterations = 3 # Begränsa iterationer för att undvika oändliga loopar
current_response = ""
for i in range(max_iterations):
# Skapa prompten för den aktuella iterationen
full_prompt = prompt + current_response + "\n\nVad är nästa steg eller ditt slutgiltiga svar?"
print(f"[{i+1}/{max_iterations}] Genererar svar med promptlängd: {len(full_prompt)}")
# Generera svar från modellen
# max_new_tokens är viktig för att styra svarets längd
generated_text = self.text_generator(
full_prompt,
max_new_tokens=500, # Justera vid behov
num_return_sequences=1,
pad_token_id=self.tokenizer.eos_token_id, # Viktigt för T5/Gemma
do_sample=True, # Aktivera sampling för mer variation
top_k=50, top_p=0.95, # Typiska samplingparametrar
temperature=0.7 # Kontrollera kreativitet
)[0]['generated_text']
# Extrahera endast den nya delen av texten (modellen genererar hela prompten + nytt svar)
new_content = generated_text[len(full_prompt):].strip()
print(f"Modellgenerering: {new_content}")
# Kontrollera om modellen vill använda ett verktyg
if "<TOOL_CODE>" in new_content and "</TOOL_CODE>" in new_content:
start_index = new_content.find("<TOOL_CODE>") + len("<TOOL_CODE>")
end_index = new_content.find("</TOOL_CODE>")
tool_call_str = new_content[start_index:end_index].strip()
print(f"Verktygskall upptäckt: {tool_call_str}")
try:
# EVAL är farligt i verkliga applikationer, men för GAIA och detta specifika verktyg är det OK.
# Säkerställ att endast godkända funktioner kan kallas.
if tool_call_str.startswith("search_tavily("):
# Extrahera argumenten till funktionen
# En mer robust parser skulle behövas för mer komplexa verktyg
query = tool_call_str[len("search_tavily("):-1].strip().strip('"').strip("'")
tool_output = search_tavily(query)
print(f"Verktygsresultat: {tool_output[:200]}...") # Printa kortfattat
current_response += f"\n\nVerktygsresultat från {tool_call_str}:\n{tool_output}\n"
else:
tool_output = f"Okänt verktyg: {tool_call_str}"
print(f"Fel: {tool_output}")
current_response += f"\n\n{tool_output}\n"
except Exception as tool_e:
tool_output = f"Fel vid körning av verktyg {tool_call_str}: {tool_e}"
print(f"Fel: {tool_output}")
current_response += f"\n\n{tool_output}\n"
else:
# Modellen har genererat ett svar utan att kalla verktyg
final_answer = new_content
print(f"Slutgiltigt svar från modellen:\n{final_answer}")
return final_answer.strip()
# Om max_iterations nås utan slutgiltigt svar
return "Agenten kunde inte slutföra uppgiften inom tillåtet antal iterationer. Senaste svar: " + new_content.strip()
|