Freddolin's picture
Update agent.py
b44b8a4 verified
raw
history blame
5.99 kB
import ast
import os

import torch  # Used to check device availability and select the dtype
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM

# Import the project's search tool
from tools.tavily_search import search_tavily
class GaiaAgent:
    """Question-answering agent driving a local causal LM with a simple tool loop.

    The model is prompted to either emit a ``<TOOL_CODE>...</TOOL_CODE>`` block
    containing a ``search_tavily("...")`` call, or a final answer. Tool results
    are appended to the prompt and the model is queried again, up to a bounded
    number of iterations.
    """

    _TOOL_OPEN = "<TOOL_CODE>"
    _TOOL_CLOSE = "</TOOL_CODE>"

    # Fixed instruction text sent before every task; the "Task: ..." line is
    # appended in process_task. Joining with "\n" and a leading "" reproduces
    # a prompt that starts with a newline.
    _PROMPT_HEADER = "\n".join((
        "",
        "You are a helpful and expert AI assistant with access to a search tool.",
        "Your task is to carefully and accurately answer questions by using the search tool when necessary.",
        "Always provide a complete and correct answer based on the information you find.",
        "Your available tools:",
        "1. search_tavily(query: str): Searches on Tavily and returns relevant results.",
        "Use this tool to find information on the internet that you don't know or need to verify.",
        "To use a tool, write it in the following exact format:",
        "<TOOL_CODE>",
        'tool_name("your search query")',
        "</TOOL_CODE>",
        "Example:",
        "If you need to know the capital of France:",
        "<TOOL_CODE>",
        'search_tavily("capital of France")',
        "</TOOL_CODE>",
        "When you have found all the necessary information and are ready to answer the task, provide your final answer.",
    ))

    def __init__(self, model_id: str = "google/gemma-2b-it"):
        """Load the tokenizer, model and text-generation pipeline for ``model_id``.

        Args:
            model_id: Hugging Face Hub model identifier. The ``HF_TOKEN``
                environment variable (if set) is passed as the access token.

        Raises:
            RuntimeError: If the tokenizer, model or pipeline cannot be created.
                The underlying exception is chained as ``__cause__``.
        """
        # Load tokenizer and model explicitly rather than via the pipeline, for more control.
        try:
            print(f"Laddar tokenizer för {model_id}...")
            self.tokenizer = AutoTokenizer.from_pretrained(model_id, token=os.getenv("HF_TOKEN"))
            print(f"Laddar modell för {model_id}...")
            # Reported for information only; placement is delegated to device_map="auto".
            device = "cuda" if torch.cuda.is_available() else "cpu"
            print(f"Använder enhet: {device}")
            self.model = AutoModelForCausalLM.from_pretrained(
                model_id,
                torch_dtype=torch.bfloat16,  # bfloat16 to reduce memory use
                device_map="auto",  # let Accelerate place weights across GPU/CPU
                token=os.getenv("HF_TOKEN"),
            )
            print("Modell laddad framgångsrikt.")
            # No `device` argument: the model was already placed by device_map above.
            self.text_generator = pipeline(
                "text-generation",
                model=self.model,
                tokenizer=self.tokenizer,
            )
            print("Textgenereringspipeline skapad.")
        except Exception as e:
            print(f"Fel vid initiering av agent: {e}")
            raise RuntimeError(f"Fel vid laddning av modell eller tokenizer: {e}") from e

    def __call__(self, question: str) -> str:
        """Answer ``question``; lets an agent instance be used like a function.

        Delegates to :meth:`process_task` and logs truncated input and output.
        """
        print(f"Agent received question (first 50 chars): {question[:50]}...")
        result = self.process_task(question)
        print(f"Agent returning answer: {result[:100]}...")  # truncated to keep logs short
        return result

    def process_task(self, task_description: str, max_iterations: int = 3) -> str:
        """Answer ``task_description`` by alternating model generations and tool calls.

        Each iteration sends the instruction prompt plus all tool results gathered
        so far to the model. If the reply contains a ``<TOOL_CODE>...</TOOL_CODE>``
        block, the call inside it is executed and its output (or an error message)
        is appended to the context. Otherwise the reply is returned as the final answer.

        Args:
            task_description: The question or task to solve.
            max_iterations: Maximum number of model generations before giving up.

        Returns:
            The model's (stripped) final answer. If no final answer was produced
            within ``max_iterations``, a message starting with "Agent could not
            complete the task within the allowed iterations." followed by the
            last model reply.
        """
        prompt = f"{self._PROMPT_HEADER}\nTask: {task_description}\n"
        current_response = ""
        # Defined up front so the fallback return works even when max_iterations < 1.
        new_content = ""
        for i in range(max_iterations):
            full_prompt = prompt + current_response + "\n\nWhat is the next step or your final answer?"
            print(f"[{i+1}/{max_iterations}] Generating response with prompt length: {len(full_prompt)}")
            outputs = self.text_generator(
                full_prompt,
                max_new_tokens=1024,
                num_return_sequences=1,
                pad_token_id=self.tokenizer.eos_token_id,
                do_sample=True,
                top_k=50,
                top_p=0.95,
                temperature=0.8,
                # Return only the newly generated text, so no prompt-length slicing is needed.
                return_full_text=False,
            )
            new_content = outputs[0]["generated_text"].strip()
            print(f"DEBUG - Extracted new_content: '{new_content}'")

            tool_call_str = self._extract_tool_call(new_content)
            if tool_call_str is None:
                print(f"Final answer from model:\n{new_content}")
                return new_content
            print(f"Tool call detected: {tool_call_str}")
            current_response += self._run_tool(tool_call_str)
        return "Agent could not complete the task within the allowed iterations. Latest response: " + new_content.strip()

    @classmethod
    def _extract_tool_call(cls, text: str) -> "str | None":
        """Return the stripped body of the first ``<TOOL_CODE>...</TOOL_CODE>`` block, or None."""
        start = text.find(cls._TOOL_OPEN)
        if start == -1:
            return None
        start += len(cls._TOOL_OPEN)
        # Look for the closing tag only after the opening one, so a stray earlier
        # closing tag cannot yield an inverted (empty) slice.
        end = text.find(cls._TOOL_CLOSE, start)
        if end == -1:
            return None
        return text[start:end].strip()

    @staticmethod
    def _parse_search_query(tool_call_str: str) -> "str | None":
        """Return the query of a ``search_tavily(...)`` call, or None for any other tool.

        Well-formed calls such as ``search_tavily("q")`` or
        ``search_tavily(query='q')`` are parsed with ``ast`` (nothing is
        executed), which handles escaped quotes and the keyword form. Other text
        starting with ``search_tavily(`` falls back to lenient quote stripping.
        """
        prefix = "search_tavily("
        if not tool_call_str.startswith(prefix):
            return None
        try:
            node = ast.parse(tool_call_str, mode="eval").body
        except (SyntaxError, ValueError):
            node = None
        if (
            isinstance(node, ast.Call)
            and isinstance(node.func, ast.Name)
            and node.func.id == "search_tavily"
        ):
            args = list(node.args) + [kw.value for kw in node.keywords if kw.arg == "query"]
            if len(args) == 1 and isinstance(args[0], ast.Constant) and isinstance(args[0].value, str):
                return args[0].value
        # Lenient fallback for calls the parser rejects (e.g. unbalanced quotes).
        inner = tool_call_str[len(prefix):]
        if inner.endswith(")"):
            inner = inner[:-1]
        return inner.strip().strip('"').strip("'")

    def _run_tool(self, tool_call_str: str) -> str:
        """Execute one tool call and return the text to append to the model context.

        Unknown tools and exceptions raised by the tool are reported back to the
        model as text instead of aborting the run.
        """
        query = self._parse_search_query(tool_call_str)
        if query is None:
            message = f"Unknown tool: {tool_call_str}"
        else:
            try:
                # str() so that non-string tool results can be sliced and logged safely.
                tool_output = str(search_tavily(query))
            except Exception as tool_e:  # report tool failures to the model
                message = f"Error running tool {tool_call_str}: {tool_e}"
            else:
                print(f"Tool result: {tool_output[:200]}...")
                return f"\n\nTool Result from {tool_call_str}:\n{tool_output}\n"
        print(f"Error: {message}")
        return f"\n\n{message}\n"