import os
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
import torch  # Used to check whether a GPU is available

# Import the custom search tool
from tools.tavily_search import search_tavily
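
# The tools.tavily_search module is not shown in this file. A minimal sketch of
# the interface this agent assumes, built on the official tavily-python client
# (the function body, env var name, and max_results value below are illustrative
# assumptions, not the actual module):
#
#     from tavily import TavilyClient
#
#     def search_tavily(query: str) -> str:
#         """Query Tavily and return a plain-text digest of the top results."""
#         client = TavilyClient(api_key=os.getenv("TAVILY_API_KEY"))
#         response = client.search(query, max_results=3)
#         return "\n".join(r["content"] for r in response["results"])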


class GaiaAgent:
    def __init__(self, model_id: str = "google/gemma-2b-it"):
        # Load the tokenizer and model manually. This gives more control.
        try:
            print(f"Loading tokenizer for {model_id}...")
            self.tokenizer = AutoTokenizer.from_pretrained(model_id, token=os.getenv("HF_TOKEN"))

            print(f"Loading model for {model_id}...")
            # Check whether a GPU is available
            device = "cuda" if torch.cuda.is_available() else "cpu"
            print(f"Using device: {device}")

            self.model = AutoModelForCausalLM.from_pretrained(
                model_id,
                torch_dtype=torch.bfloat16,  # Use bfloat16 to reduce memory usage
                device_map="auto",  # Accelerate places layers across CPU/GPU
                token=os.getenv("HF_TOKEN"),
            )
            print("Model loaded successfully.")

            # Create a text-generation pipeline
            self.text_generator = pipeline(
                "text-generation",
                model=self.model,
                tokenizer=self.tokenizer,
                # device=0 if device == "cuda" else -1  # 0 = first GPU, -1 = CPU;
                # not needed here because device_map="auto" already placed the model
            )
            print("Text-generation pipeline created.")
        except Exception as e:
            print(f"Error initializing agent: {e}")
            raise RuntimeError(f"Error loading model or tokenizer: {e}")

    def __call__(self, question: str) -> str:
        """
        Makes a GaiaAgent instance callable like a function.
        Delegates to process_task to generate the answer.
        """
        print(f"Agent received question (first 50 chars): {question[:50]}...")
        result = self.process_task(question)
        print(f"Agent returning answer: {result[:100]}...")  # Truncate so long answers don't flood the logs
        return result
    def process_task(self, task_description: str) -> str:
        # Build a prompt that instructs the model to solve the task and to call
        # the search tool using the exact <TOOL_CODE> format parsed below.
        prompt = f"""
You are a helpful and expert AI assistant with access to a search tool.
Your task is to carefully and accurately answer questions by using the search tool when necessary.
Always provide a complete and correct answer based on the information you find.

Your available tools:
1. search_tavily(query: str): Searches on Tavily and returns relevant results.
   Use this tool to find information on the internet that you don't know or need to verify.

To use a tool, write it in the following exact format:
<TOOL_CODE>
tool_name("your search query")
</TOOL_CODE>

Example:
If you need to know the capital of France:
<TOOL_CODE>
search_tavily("capital of France")
</TOOL_CODE>

When you have found all the necessary information and are ready to answer the task, provide your final answer.

Task: {task_description}
"""
        max_iterations = 3
        current_response = ""
        for i in range(max_iterations):
            full_prompt = prompt + current_response + "\n\nWhat is the next step or your final answer?"
            print(f"[{i+1}/{max_iterations}] Generating response with prompt length: {len(full_prompt)}")

            new_content = self.text_generator(
                full_prompt,
                max_new_tokens=1024,  # Keep 1024, or raise it if answers get cut off
                num_return_sequences=1,
                pad_token_id=self.tokenizer.eos_token_id,
                do_sample=True,
                top_k=50,
                top_p=0.95,
                temperature=0.8,  # Keep 0.8, or adjust as needed
                return_full_text=False,  # Return only the newly generated text;
                # more robust than slicing the prompt off the full output
            )[0]['generated_text'].strip()

            print(f"DEBUG - Extracted new_content: '{new_content}'")
if "<TOOL_CODE>" in new_content and "</TOOL_CODE>" in new_content:
start_index = new_content.find("<TOOL_CODE>") + len("<TOOL_CODE>")
end_index = new_content.find("</TOOL_CODE>")
tool_call_str = new_content[start_index:end_index].strip()
print(f"Tool call detected: {tool_call_str}")
try:
if tool_call_str.startswith("search_tavily("):
query = tool_call_str[len("search_tavily("):-1].strip().strip('"').strip("'")
tool_output = search_tavily(query)
print(f"Tool result: {tool_output[:200]}...")
current_response += f"\n\nTool Result from {tool_call_str}:\n{tool_output}\n"
else:
tool_output = f"Unknown tool: {tool_call_str}"
print(f"Error: {tool_output}")
current_response += f"\n\n{tool_output}\n"
except Exception as tool_e:
tool_output = f"Error running tool {tool_call_str}: {tool_e}"
print(f"Error: {tool_output}")
current_response += f"\n\n{tool_output}\n"
else:
final_answer = new_content
print(f"Final answer from model:\n{final_answer}")
return final_answer.strip()
return "Agent could not complete the task within the allowed iterations. Latest response: " + new_content.strip()
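

# A minimal usage sketch (illustrative; assumes HF_TOKEN and the Tavily API key
# are set in the environment, and that tools/tavily_search.py exists as above):
if __name__ == "__main__":
    agent = GaiaAgent()
    answer = agent("What is the capital of France?")
    print(answer)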