import os

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

from tools.tavily_search import search_tavily
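
# NOTE: tools/tavily_search.py is not included in this file. A minimal sketch
# of what search_tavily is assumed to look like, using the tavily-python
# client (the module body and return formatting here are assumptions, not the
# actual code):
#
#     from tavily import TavilyClient
#
#     def search_tavily(query: str) -> str:
#         client = TavilyClient(api_key=os.getenv("TAVILY_API_KEY"))
#         results = client.search(query)["results"]
#         return "\n\n".join(f"{r['title']}: {r['content']}" for r in results)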


class GaiaAgent:
    def __init__(self, model_id: str = "google/gemma-2b-it"):
        try:
            print(f"Loading tokenizer for {model_id}...")
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_id, token=os.getenv("HF_TOKEN")
            )

            print(f"Loading model for {model_id}...")
            device = "cuda" if torch.cuda.is_available() else "cpu"
            print(f"Using device: {device}")

            # device_map="auto" lets accelerate place the weights; the device
            # printed above is informational only.
            self.model = AutoModelForCausalLM.from_pretrained(
                model_id,
                torch_dtype=torch.bfloat16,
                device_map="auto",
                token=os.getenv("HF_TOKEN"),
            )
            print("Model loaded successfully.")

            self.text_generator = pipeline(
                "text-generation",
                model=self.model,
                tokenizer=self.tokenizer,
            )
            print("Text generation pipeline created.")

        except Exception as e:
            print(f"Error initializing agent: {e}")
            raise RuntimeError(f"Error loading model or tokenizer: {e}") from e

    def __call__(self, question: str) -> str:
        """
        Make a GaiaAgent instance callable like a function; delegates to
        process_task to generate the answer.
        """
        print(f"Agent received question (first 50 chars): {question[:50]}...")
        result = self.process_task(question)
        print(f"Agent returning answer: {result[:100]}...")
        return result

    def process_task(self, task_description: str) -> str:
        # The system prompt describes the available tool and the exact
        # <TOOL_CODE> format the model must use to invoke it.
        prompt = f"""
You are a helpful and expert AI assistant with access to a search tool.
Your task is to carefully and accurately answer questions by using the search tool when necessary.
Always provide a complete and correct answer based on the information you find.

Your available tools:
1. search_tavily(query: str): Searches Tavily and returns relevant results.
   Use this tool to find information on the internet that you don't know or need to verify.

To use a tool, write it in the following exact format:
<TOOL_CODE>
tool_name("your search query")
</TOOL_CODE>

Example:
If you need to know the capital of France:
<TOOL_CODE>
search_tavily("capital of France")
</TOOL_CODE>

When you have found all the necessary information and are ready to answer the task, provide your final answer.

Task: {task_description}
"""

        max_iterations = 3
        current_response = ""

        for i in range(max_iterations):
            full_prompt = prompt + current_response + "\n\nWhat is the next step or your final answer?"

            print(f"[{i + 1}/{max_iterations}] Generating response with prompt length: {len(full_prompt)}")

            # return_full_text=False makes the pipeline return only the newly
            # generated tokens; this is more robust than slicing the prompt
            # off the output, since tokenization can alter whitespace.
            new_content = self.text_generator(
                full_prompt,
                max_new_tokens=1024,
                num_return_sequences=1,
                pad_token_id=self.tokenizer.eos_token_id,
                do_sample=True,
                top_k=50,
                top_p=0.95,
                temperature=0.8,
                return_full_text=False,
            )[0]["generated_text"].strip()

            print(f"DEBUG - Extracted new_content: '{new_content}'")

            # If the model emitted a tool call, execute it and feed the result
            # back into the conversation for the next iteration.
            if "<TOOL_CODE>" in new_content and "</TOOL_CODE>" in new_content:
                start_index = new_content.find("<TOOL_CODE>") + len("<TOOL_CODE>")
                end_index = new_content.find("</TOOL_CODE>")
                tool_call_str = new_content[start_index:end_index].strip()

                print(f"Tool call detected: {tool_call_str}")

                try:
                    if tool_call_str.startswith("search_tavily("):
                        # Strip the call syntax and surrounding quotes to
                        # recover the raw query string.
                        query = tool_call_str[len("search_tavily("):-1].strip().strip('"').strip("'")
                        tool_output = search_tavily(query)
                        print(f"Tool result: {tool_output[:200]}...")
                        current_response += f"\n\nTool Result from {tool_call_str}:\n{tool_output}\n"
                    else:
                        tool_output = f"Unknown tool: {tool_call_str}"
                        print(f"Error: {tool_output}")
                        current_response += f"\n\n{tool_output}\n"
                except Exception as tool_e:
                    tool_output = f"Error running tool {tool_call_str}: {tool_e}"
                    print(f"Error: {tool_output}")
                    current_response += f"\n\n{tool_output}\n"
            else:
                # No tool call: treat the output as the final answer.
                final_answer = new_content
                print(f"Final answer from model:\n{final_answer}")
                return final_answer.strip()

        return (
            "Agent could not complete the task within the allowed iterations. "
            "Latest response: " + new_content.strip()
        )
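

if __name__ == "__main__":
    # Minimal usage sketch (assumption: HF_TOKEN and the Tavily API key are
    # set in the environment; the question is illustrative only).
    agent = GaiaAgent()
    print(agent("What is the capital of France?"))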