# GAIA_Agent / agents / role_agent.py
# Delanoe Pirard
# Lots of changes
# 114747f
import os
import logging
import datasets
from llama_index.core import Document, VectorStoreIndex
from llama_index.core.agent.workflow import ReActAgent
from llama_index.core.retrievers import QueryFusionRetriever
from llama_index.core.retrievers.fusion_retriever import FUSION_MODES
from llama_index.core.tools import FunctionTool
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.postprocessor import SentenceTransformerRerank
from llama_index.llms.google_genai import GoogleGenAI
from llama_index.retrievers.bm25 import BM25Retriever
# Module-level logger, named after this module so handlers/levels can be
# configured per module by the application.
logger = logging.getLogger(name=__name__)
# --- Tool Function ---
# Note: the retrieval tool is implemented as the instance method
# RoleAgentInitializer.role_prompt_retriever_method, so that it can use the
# retriever and reranker built by that instance.
# --- Initializer Class ---
class RoleAgentInitializer:
    """Build the models, dataset index and retriever behind the RoleAgent.

    Construction is expensive: it loads an embedding model, a cross-encoder
    reranker and the role-prompt dataset, then builds a vector index plus a
    BM25 retriever over it and fuses them with reciprocal-rank fusion.
    ``get_agent`` wraps that retriever in a ``ReActAgent`` tool.

    Configuration comes from environment variables: ``ROLE_EMBED_MODEL``,
    ``ROLE_RERANKER_MODEL``, ``ROLE_PROMPT_DATASET``, ``ROLE_LLM_MODEL`` and
    ``GEMINI_API_KEY`` (required).
    """

    # Number of prompts kept by the reranker and returned by the tool.
    # The tool description promises "the top three", so keep this at 3
    # unless that description is updated too.
    RERANK_TOP_N = 3
    # Candidates fetched by each retriever and kept by the fusion step before
    # reranking. Must be >= RERANK_TOP_N, otherwise the reranker can never
    # return RERANK_TOP_N prompts.
    CANDIDATE_TOP_K = 10
    # Sampling temperature used for every Gemini client created here.
    LLM_TEMPERATURE = 0.05

    def __init__(self):
        """Read configuration and build the embedder, reranker and retriever.

        Raises:
            ValueError: If ``GEMINI_API_KEY`` is not set.
            Exception: Any error raised while loading models, the dataset or
                building the indexes is logged and re-raised unchanged.
        """
        logger.info("Initializing RoleAgent resources...")
        # Configuration from environment variables
        self.embed_model_name = os.getenv("ROLE_EMBED_MODEL", "Snowflake/snowflake-arctic-embed-l-v2.0")
        self.reranker_model_name = os.getenv("ROLE_RERANKER_MODEL", "Alibaba-NLP/gte-multilingual-reranker-base")
        self.dataset_name = os.getenv("ROLE_PROMPT_DATASET", "fka/awesome-chatgpt-prompts")
        self.llm_model_name = os.getenv("ROLE_LLM_MODEL", "gemini-2.5-pro-preview-03-25")
        self.gemini_api_key = os.getenv("GEMINI_API_KEY")
        if not self.gemini_api_key:
            logger.error("GEMINI_API_KEY not found in environment variables.")
            raise ValueError("GEMINI_API_KEY must be set")

        try:
            logger.info("Loading embed model: %s", self.embed_model_name)
            self.embed_model = HuggingFaceEmbedding(model_name=self.embed_model_name)

            logger.info("Loading reranker model: %s", self.reranker_model_name)
            self.reranker = SentenceTransformerRerank(
                model=self.reranker_model_name,
                top_n=self.RERANK_TOP_N,
            )

            documents = self._load_documents()
            self.retriever = self._build_retriever(documents)
            logger.info("RoleAgent resources initialized successfully.")
        except Exception as e:
            logger.exception("Error during RoleAgent resource initialization: %s", e)
            raise

    def _load_documents(self) -> list:
        """Load the prompt dataset and turn each row into one ``Document``.

        Assumes the ``train`` split has ``act`` and ``prompt`` columns (the
        schema indexed by this class); each document's text is
        ``"Act: <act>\\nPrompt: <prompt>"`` with ``act`` kept as metadata.
        """
        logger.info("Loading dataset: %s", self.dataset_name)
        prompts_dataset = datasets.load_dataset(self.dataset_name, split="train")

        logger.info("Converting dataset to LlamaIndex Documents...")
        # Iterate rows rather than indexing ``dataset["col"][i]`` per row: in
        # some `datasets` versions each column access materializes the whole
        # column, which makes the per-row indexing quadratic.
        return [
            Document(
                text=f"Act: {row['act']}\nPrompt: {row['prompt']}",
                metadata={"act": row["act"]},
            )
            for row in prompts_dataset
        ]

    def _build_retriever(self, documents: list) -> QueryFusionRetriever:
        """Index ``documents`` and fuse vector + BM25 retrieval (reciprocal rank)."""
        splitter = SentenceSplitter(chunk_size=256, chunk_overlap=20)
        logger.info("Building vector index (this may take time)...")
        index = VectorStoreIndex.from_documents(
            documents,
            embed_model=self.embed_model,
            show_progress=True,
            transformations=[splitter],
        )
        logger.info("Vector index built.")

        logger.info("Building BM25 retriever...")
        bm25_retriever = BM25Retriever.from_defaults(
            docstore=index.docstore,
            similarity_top_k=self.CANDIDATE_TOP_K,
        )
        vector_retriever = index.as_retriever(similarity_top_k=self.CANDIDATE_TOP_K)

        logger.info("Building query fusion retriever...")
        return QueryFusionRetriever(
            [vector_retriever, bm25_retriever],
            # Pass the LLM explicitly: without it QueryFusionRetriever falls
            # back to the global Settings.llm (OpenAI by default) to generate
            # its additional search queries.
            llm=self._build_llm(),
            similarity_top_k=self.CANDIDATE_TOP_K,
            mode=FUSION_MODES.RECIPROCAL_RANK,
            verbose=True,
        )

    def _build_llm(self) -> GoogleGenAI:
        """Create a Gemini client from this initializer's configuration."""
        return GoogleGenAI(
            api_key=self.gemini_api_key,
            model=self.llm_model_name,
            temperature=self.LLM_TEMPERATURE,
        )

    def role_prompt_retriever_method(self, query: str) -> str:
        """
        Retrieve the best-matching role/task prompts for a query.
        Runs the fused retriever, reranks the candidates, and returns the
        contents of at most ``RERANK_TOP_N`` nodes.
        Args:
            query (str): The user query describing the desired role, task, or prompt context.
        Returns:
            str: The retrieved prompt texts separated by blank lines,
                "No matching role prompt found." when nothing survives
                reranking, or "Error retrieving role prompt: <error>" if
                retrieval raises (errors are reported to the agent, not raised).
        """
        logger.info("Role prompt retriever called with query: %s...", query[:100])
        try:
            results = self.retriever.retrieve(query)
            reranked_results = self.reranker.postprocess_nodes(results, query_str=query)
            if not reranked_results:
                logger.warning("No matching role prompt found after reranking.")
                return "No matching role prompt found."

            top_results = reranked_results[: self.RERANK_TOP_N]
            logger.info(
                "Retrieved and reranked %d results. Returning top %d.",
                len(reranked_results),
                len(top_results),
            )
            return "\n\n".join(node.get_content() for node in top_results)
        except Exception as e:
            logger.exception("Error during role prompt retrieval: %s", e)
            return f"Error retrieving role prompt: {e}"

    def get_agent(self) -> ReActAgent:
        """Create and return the ReActAgent used for role selection.

        A new agent (with a new Gemini client) is built on each call; the
        retriever and reranker are shared from this initializer.
        """
        logger.info("Creating RoleAgent ReActAgent instance...")
        # Wrap the bound method so the tool uses this instance's retriever/reranker.
        role_prompt_retriever_tool = FunctionTool.from_defaults(
            fn=self.role_prompt_retriever_method,
            name="role_prompt_retriever",
            description="Retrieve and summarize the top three role or task prompts for "
            "a query using BM25 and embedding retrieval with reranking.",
        )

        # System prompt (consider loading from file in future)
        system_prompt = """\
You are RoleAgent, an expert context‐setter that interprets user inputs and deterministically assigns the most fitting persona or task schema to guide downstream agents. For every query:
1. **Interpret Intent**: Parse the user’s instruction to understand their goal, domain, and required expertise.
2. **Retrieve & Rank**: Use the `role_prompt_retriever` tool to fetch the top role descriptions relevant to the intent.
3. **Select Role**: Based *only* on the retrieved results, choose the single best‐matching persona (e.g. “Developer Assistant,” “SEO Strategist,” “Translation Engine,” “Terminal Emulator”) without asking the user any follow-up. If no relevant role is found, state that clearly.
4. **Respond**: Output in plain text with:
- **Role**: The selected persona (or "None Found").
- **Reason**: Briefly explain why this role was chosen based *only* on the retrieved text.
- **Prompt**: The corresponding role prompt from the retrieved text to be used by downstream agents (or "N/A" if none found).
5. **Hand-Off**: Immediately after including the chosen prompt (or N/A) in your response, invoke `planner_agent` to begin breaking down the user’s request into actionable sub-questions.
Always conclude your response with the full prompt for the next agent (or "N/A") and the invocation instruction for `planner_agent`.
"""

        agent = ReActAgent(
            name="role_agent",
            description=(
                "RoleAgent selects the most appropriate persona or task template based on the user’s query. "
                "By evaluating the question’s intent and context using a specialized retriever, it chooses or refines a prompt that aligns "
                "with the best-fitting role—whether developer, analyst, translator, planner, or otherwise—so that "
                "subsequent agents can respond effectively under the optimal role context."
            ),
            tools=[role_prompt_retriever_tool],
            llm=self._build_llm(),
            system_prompt=system_prompt,
            can_handoff_to=["planner_agent"],
        )
        logger.info("RoleAgent ReActAgent instance created.")
        return agent
# --- Global Initializer Instance (Singleton Pattern) ---
# The initializer is created lazily on the first call to get_initializer() and
# reused afterwards, so the expensive model loading and index building happen
# at most once per process. If construction raises, nothing is cached and the
# next call tries again.
_role_agent_initializer_instance = None


def get_initializer():
    """Return the shared RoleAgentInitializer, constructing it on first use."""
    global _role_agent_initializer_instance
    if _role_agent_initializer_instance is not None:
        return _role_agent_initializer_instance
    logger.info("Instantiating RoleAgentInitializer for the first time.")
    _role_agent_initializer_instance = RoleAgentInitializer()
    return _role_agent_initializer_instance
# --- Public Initialization Function ---
def initialize_role_agent() -> ReActAgent:
    """Return a RoleAgent built from the shared, lazily-created initializer.

    Heavy resources come from ``get_initializer()`` and are therefore built at
    most once; the agent itself is produced by the initializer's
    ``get_agent()`` on every call.
    """
    logger.info("initialize_role_agent called.")
    shared_initializer = get_initializer()
    agent = shared_initializer.get_agent()
    return agent
# Example usage (for testing if run directly)
if __name__ == "__main__":
    import sys

    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    logger.info("Running role_agent.py directly for testing...")
    # Ensure API key is set for testing; exit non-zero so callers/CI see the failure.
    if not os.getenv("GEMINI_API_KEY"):
        print("Error: GEMINI_API_KEY environment variable not set. Cannot run test.")
        sys.exit(1)
    try:
        test_agent = initialize_role_agent()
    except Exception as e:
        # Keep the full traceback in the log and signal failure via the exit code.
        logger.exception("Role Agent initialization failed")
        print(f"Error during testing: {e}")
        sys.exit(1)
    print("Role Agent initialized successfully for testing.")
    # To try a query, the workflow-based ReActAgent is run asynchronously, e.g.
    # inside an async function:
    #     response = await test_agent.run(user_msg="act as a linux terminal")
    #     print(f"Test query result: {response}")