|
import os |
|
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM |
|
import torch |
|
|
|
|
|
from tools.tavily_search import search_tavily |
|
|
|
class GaiaAgent: |
|
def __init__(self, model_id: str = "google/gemma-2b-it"): |
|
|
|
try: |
|
print(f"Laddar tokenizer för {model_id}...") |
|
self.tokenizer = AutoTokenizer.from_pretrained(model_id, token=os.getenv("HF_TOKEN")) |
|
print(f"Laddar modell för {model_id}...") |
|
|
|
|
|
device = "cuda" if torch.cuda.is_available() else "cpu" |
|
print(f"Använder enhet: {device}") |
|
|
|
self.model = AutoModelForCausalLM.from_pretrained( |
|
model_id, |
|
torch_dtype=torch.bfloat16, |
|
device_map="auto", |
|
token=os.getenv("HF_TOKEN") |
|
) |
|
print("Modell laddad framgångsrikt.") |
|
|
|
|
|
self.text_generator = pipeline( |
|
"text-generation", |
|
model=self.model, |
|
tokenizer=self.tokenizer, |
|
|
|
) |
|
print("Textgenereringspipeline skapad.") |
|
|
|
except Exception as e: |
|
print(f"Fel vid initiering av agent: {e}") |
|
raise RuntimeError(f"Fel vid laddning av modell eller tokenizer: {e}") |
|
|
|
|
|
def __call__(self, question: str) -> str: |
|
""" |
|
Denna metod gör att en instans av GaiaAgent kan kallas som en funktion. |
|
Den kommer att anropa din process_task metod för att generera svaret. |
|
""" |
|
print(f"Agent received question (first 50 chars): {question[:50]}...") |
|
result = self.process_task(question) |
|
print(f"Agent returning answer: {result[:100]}...") |
|
return result |
|
|
|
|
|
def process_task(self, task_description: str) -> str: |
|
|
|
|
|
|
|
|
|
prompt = f""" |
|
Du är en expertagent med tillgång till ett sökverktyg. |
|
Använd alltid sökverktyget om du behöver information som inte finns i din träningsdata eller om du behöver validera fakta. |
|
Försök alltid att svara på uppgiften heltäckande. |
|
|
|
Dina tillgängliga verktyg: |
|
1. search_tavily(query: str): Söker på Tavily och returnerar relevanta resultat. |
|
|
|
För att använda ett verktyg, skriv det på följande format: |
|
<TOOL_CODE> |
|
verktygsnamn("fråga till verktyget") |
|
</TOOL_CODE> |
|
|
|
Exempel: |
|
För att söka efter information om Mars: |
|
<TOOL_CODE> |
|
search_tavily("information om Mars") |
|
</TOOL_CODE> |
|
|
|
När du har hittat all nödvändig information och är redo att svara, skriv ditt slutgiltiga svar. |
|
|
|
Uppgift: {task_description} |
|
""" |
|
|
|
max_iterations = 3 |
|
current_response = "" |
|
|
|
for i in range(max_iterations): |
|
|
|
full_prompt = prompt + current_response + "\n\nVad är nästa steg eller ditt slutgiltiga svar?" |
|
|
|
print(f"[{i+1}/{max_iterations}] Genererar svar med promptlängd: {len(full_prompt)}") |
|
|
|
|
|
|
|
generated_text = self.text_generator( |
|
full_prompt, |
|
max_new_tokens=1024, |
|
num_return_sequences=1, |
|
pad_token_id=self.tokenizer.eos_token_id, |
|
do_sample=True, |
|
top_k=50, top_p=0.95, |
|
temperature=0.8 |
|
)[0]['generated_text'] |
|
|
|
|
|
new_content = generated_text[len(full_prompt):].strip() |
|
print(f"Modellgenerering: {new_content}") |
|
|
|
|
|
if "<TOOL_CODE>" in new_content and "</TOOL_CODE>" in new_content: |
|
start_index = new_content.find("<TOOL_CODE>") + len("<TOOL_CODE>") |
|
end_index = new_content.find("</TOOL_CODE>") |
|
tool_call_str = new_content[start_index:end_index].strip() |
|
|
|
print(f"Verktygskall upptäckt: {tool_call_str}") |
|
|
|
try: |
|
|
|
|
|
if tool_call_str.startswith("search_tavily("): |
|
|
|
|
|
query = tool_call_str[len("search_tavily("):-1].strip().strip('"').strip("'") |
|
tool_output = search_tavily(query) |
|
print(f"Verktygsresultat: {tool_output[:200]}...") |
|
current_response += f"\n\nVerktygsresultat från {tool_call_str}:\n{tool_output}\n" |
|
else: |
|
tool_output = f"Okänt verktyg: {tool_call_str}" |
|
print(f"Fel: {tool_output}") |
|
current_response += f"\n\n{tool_output}\n" |
|
except Exception as tool_e: |
|
tool_output = f"Fel vid körning av verktyg {tool_call_str}: {tool_e}" |
|
print(f"Fel: {tool_output}") |
|
current_response += f"\n\n{tool_output}\n" |
|
else: |
|
|
|
final_answer = new_content |
|
print(f"Slutgiltigt svar från modellen:\n{final_answer}") |
|
return final_answer.strip() |
|
|
|
|
|
return "Agenten kunde inte slutföra uppgiften inom tillåtet antal iterationer. Senaste svar: " + new_content.strip() |
|
|