Spaces:
Sleeping
Sleeping
import os | |
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM | |
import torch # För att kontrollera enheter | |
# Importera ditt nya sökverktyg | |
from tools.tavily_search import search_tavily | |
class GaiaAgent: | |
def __init__(self, model_id: str = "google/gemma-2b-it"): | |
# Ladda tokenizer och modell manuellt. Detta ger mer kontroll. | |
try: | |
print(f"Laddar tokenizer för {model_id}...") | |
self.tokenizer = AutoTokenizer.from_pretrained(model_id, token=os.getenv("HF_TOKEN")) | |
print(f"Laddar modell för {model_id}...") | |
# Kontrollera om GPU är tillgänglig | |
device = "cuda" if torch.cuda.is_available() else "cpu" | |
print(f"Använder enhet: {device}") | |
self.model = AutoModelForCausalLM.from_pretrained( | |
model_id, | |
torch_dtype=torch.bfloat16, # Använd bfloat16 för minskat minne | |
device_map="auto", # Accelerate hanterar detta över CPU/GPU | |
token=os.getenv("HF_TOKEN") | |
) | |
print("Modell laddad framgångsrikt.") | |
# Skapa en pipeline för textgenerering | |
self.text_generator = pipeline( | |
"text-generation", | |
model=self.model, | |
tokenizer=self.tokenizer, | |
# device=0 if device == "cuda" else -1 # 0 för första GPU, -1 för CPU | |
) | |
print("Textgenereringspipeline skapad.") | |
except Exception as e: | |
print(f"Fel vid initiering av agent: {e}") | |
raise RuntimeError(f"Fel vid laddning av modell eller tokenizer: {e}") | |
def process_task(self, task_description: str) -> str: | |
# Enkel instruktion till LLM för att utföra uppgiften | |
# Vi måste bygga en prompt som instruerar modellen att använda verktyg. | |
# Instruktioner till modellen för att svara och använda verktyg | |
prompt = f""" | |
Du är en expertagent med tillgång till ett sökverktyg. | |
Använd alltid sökverktyget om du behöver information som inte finns i din träningsdata eller om du behöver validera fakta. | |
Försök alltid att svara på uppgiften heltäckande. | |
Dina tillgängliga verktyg: | |
1. search_tavily(query: str): Söker på Tavily och returnerar relevanta resultat. | |
För att använda ett verktyg, skriv det på följande format: | |
<TOOL_CODE> | |
verktygsnamn("fråga till verktyget") | |
</TOOL_CODE> | |
Exempel: | |
För att söka efter information om Mars: | |
<TOOL_CODE> | |
search_tavily("information om Mars") | |
</TOOL_CODE> | |
När du har hittat all nödvändig information och är redo att svara, skriv ditt slutgiltiga svar. | |
Uppgift: {task_description} | |
""" | |
max_iterations = 3 # Begränsa iterationer för att undvika oändliga loopar | |
current_response = "" | |
for i in range(max_iterations): | |
# Skapa prompten för den aktuella iterationen | |
full_prompt = prompt + current_response + "\n\nVad är nästa steg eller ditt slutgiltiga svar?" | |
print(f"[{i+1}/{max_iterations}] Genererar svar med promptlängd: {len(full_prompt)}") | |
# Generera svar från modellen | |
# max_new_tokens är viktig för att styra svarets längd | |
generated_text = self.text_generator( | |
full_prompt, | |
max_new_tokens=500, # Justera vid behov | |
num_return_sequences=1, | |
pad_token_id=self.tokenizer.eos_token_id, # Viktigt för T5/Gemma | |
do_sample=True, # Aktivera sampling för mer variation | |
top_k=50, top_p=0.95, # Typiska samplingparametrar | |
temperature=0.7 # Kontrollera kreativitet | |
)[0]['generated_text'] | |
# Extrahera endast den nya delen av texten (modellen genererar hela prompten + nytt svar) | |
new_content = generated_text[len(full_prompt):].strip() | |
print(f"Modellgenerering: {new_content}") | |
# Kontrollera om modellen vill använda ett verktyg | |
if "<TOOL_CODE>" in new_content and "</TOOL_CODE>" in new_content: | |
start_index = new_content.find("<TOOL_CODE>") + len("<TOOL_CODE>") | |
end_index = new_content.find("</TOOL_CODE>") | |
tool_call_str = new_content[start_index:end_index].strip() | |
print(f"Verktygskall upptäckt: {tool_call_str}") | |
try: | |
# EVAL är farligt i verkliga applikationer, men för GAIA och detta specifika verktyg är det OK. | |
# Säkerställ att endast godkända funktioner kan kallas. | |
if tool_call_str.startswith("search_tavily("): | |
# Extrahera argumenten till funktionen | |
# En mer robust parser skulle behövas för mer komplexa verktyg | |
query = tool_call_str[len("search_tavily("):-1].strip().strip('"').strip("'") | |
tool_output = search_tavily(query) | |
print(f"Verktygsresultat: {tool_output[:200]}...") # Printa kortfattat | |
current_response += f"\n\nVerktygsresultat från {tool_call_str}:\n{tool_output}\n" | |
else: | |
tool_output = f"Okänt verktyg: {tool_call_str}" | |
print(f"Fel: {tool_output}") | |
current_response += f"\n\n{tool_output}\n" | |
except Exception as tool_e: | |
tool_output = f"Fel vid körning av verktyg {tool_call_str}: {tool_e}" | |
print(f"Fel: {tool_output}") | |
current_response += f"\n\n{tool_output}\n" | |
else: | |
# Modellen har genererat ett svar utan att kalla verktyg | |
final_answer = new_content | |
print(f"Slutgiltigt svar från modellen:\n{final_answer}") | |
return final_answer.strip() | |
# Om max_iterations nås utan slutgiltigt svar | |
return "Agenten kunde inte slutföra uppgiften inom tillåtet antal iterationer. Senaste svar: " + new_content.strip() | |