|
import os

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

from tools.tavily_search import search_tavily
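
# Note: search_tavily is a project-local helper, not defined in this file.
# As a rough sketch (assuming the tavily-python package and a TAVILY_API_KEY
# environment variable), tools/tavily_search.py might look like this:
#
#     from tavily import TavilyClient
#
#     _client = TavilyClient(api_key=os.getenv("TAVILY_API_KEY"))
#
#     def search_tavily(query: str) -> str:
#         # Join the content snippets from the top search results.
#         response = _client.search(query)
#         return "\n".join(r["content"] for r in response["results"])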
|
|
|
class GaiaAgent:

    def __init__(self, model_id: str = "google/gemma-2b-it"):
        try:
            print(f"Loading tokenizer for {model_id}...")
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_id, token=os.getenv("HF_TOKEN")
            )

            print(f"Loading model for {model_id}...")
            # device_map="auto" handles weight placement; the device variable
            # below is only used for logging.
            device = "cuda" if torch.cuda.is_available() else "cpu"
            print(f"Using device: {device}")

            self.model = AutoModelForCausalLM.from_pretrained(
                model_id,
                torch_dtype=torch.bfloat16,
                device_map="auto",
                token=os.getenv("HF_TOKEN"),
            )
            print("Model loaded successfully.")

            self.text_generator = pipeline(
                "text-generation",
                model=self.model,
                tokenizer=self.tokenizer,
            )
            print("Text-generation pipeline created.")
        except Exception as e:
            print(f"Error initializing agent: {e}")
            raise RuntimeError(f"Error loading model or tokenizer: {e}") from e
|
    def __call__(self, question: str) -> str:
        """
        Makes a GaiaAgent instance callable like a function.
        Delegates to process_task to generate the answer.
        """
        print(f"Agent received question (first 50 chars): {question[:50]}...")
        result = self.process_task(question)
        print(f"Agent returning answer: {result[:100]}...")
        return result
|
    def process_task(self, task_description: str) -> str:
        prompt = f"""
You are a helpful and expert AI assistant with access to a search tool.
Your task is to carefully and accurately answer questions by using the search tool when necessary.
Always provide a complete and correct answer based on the information you find.

You must follow a Thought, Tool, Observation, Answer (TTOA) pattern.

**Thought:** First, carefully consider the task. What information do you need to answer the question? Do you need to use a tool?
**Tool:** If you need to search, use the search_tavily tool. The format is: <TOOL_CODE>search_tavily("your search query")</TOOL_CODE>
**Observation:** After a tool call, you will receive an observation (the tool's output).
**Answer:** Once you have gathered all necessary information, provide your final, concise answer directly.

Your available tools:
1. search_tavily(query: str): Searches Tavily and returns relevant results.

Example Interaction:
Task: What is the capital of France?
Thought: I need to find the capital of France. I should use the search_tavily tool.
Tool: <TOOL_CODE>search_tavily("capital of France")</TOOL_CODE>
Observation: The capital of France is Paris.
Answer: The capital of France is Paris.

Now, let's start.

Task: {task_description}
"""
|
        max_iterations = 3
        current_response = ""

        for i in range(max_iterations):
            # Re-prompt with the accumulated observations each iteration.
            full_prompt = prompt + current_response + "\n\nThought:"
            print(f"[{i+1}/{max_iterations}] Generating response with prompt length: {len(full_prompt)}")

            generated_text = self.text_generator(
                full_prompt,
                max_new_tokens=1024,
                num_return_sequences=1,
                pad_token_id=self.tokenizer.eos_token_id,
                do_sample=True,
                top_k=50,
                top_p=0.95,
                temperature=0.7,
            )[0]["generated_text"]

            # The pipeline returns the prompt plus its continuation; keep only
            # the newly generated text.
            new_content = generated_text[len(full_prompt):].strip()
            print(f"DEBUG - Full generated_text:\n---START---\n{generated_text}\n---END---")
            print(f"DEBUG - Extracted new_content: '{new_content}'")

            if "Answer:" in new_content:
                final_answer = new_content.split("Answer:", 1)[1].strip()
                print(f"Final answer from model:\n{final_answer}")
                return final_answer
            elif "<TOOL_CODE>" in new_content and "</TOOL_CODE>" in new_content:
                start_index = new_content.find("<TOOL_CODE>") + len("<TOOL_CODE>")
                end_index = new_content.find("</TOOL_CODE>")
                tool_call_str = new_content[start_index:end_index].strip()
                print(f"Tool call detected: {tool_call_str}")

                try:
                    if tool_call_str.startswith("search_tavily("):
                        # Naive argument parsing: strip the call syntax and any
                        # surrounding quotes from the query string.
                        query = tool_call_str[len("search_tavily("):-1].strip().strip('"').strip("'")
                        tool_output = search_tavily(query)
                        print(f"Tool result: {tool_output[:200]}...")
                    else:
                        tool_output = f"Unknown tool: {tool_call_str}"
                        print(f"Error: {tool_output}")
                except Exception as tool_e:
                    tool_output = f"Error running tool {tool_call_str}: {tool_e}"
                    print(f"Error: {tool_output}")
                current_response += f"\n\nObservation: {tool_output}\n"
            else:
                # Neither a final answer nor a tool call; keep the text and let
                # the next iteration continue the reasoning.
                current_response += f"\n\n{new_content}\n"
                print(f"Model generated non-tool/non-answer content. Appending: {new_content[:100]}...")

        if new_content:
            return "Agent could not complete the task within the allowed iterations. Latest response: " + new_content.strip()
        return "Agent could not complete the task within the allowed iterations. No meaningful content generated."
|
|