import os
import logging

import datasets
from llama_index.core import Document, VectorStoreIndex
from llama_index.core.agent.workflow import ReActAgent
from llama_index.core.retrievers import QueryFusionRetriever
from llama_index.core.retrievers.fusion_retriever import FUSION_MODES
from llama_index.core.tools import FunctionTool
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.postprocessor import SentenceTransformerRerank
from llama_index.llms.google_genai import GoogleGenAI
from llama_index.retrievers.bm25 import BM25Retriever

# Setup logging
logger = logging.getLogger(__name__)

# --- Tool Function ---
# Note: the retrieval tool relies on being bound to an instance of
# RoleAgentInitializer (or having the retriever/reranker passed in).
# It is bound to the instance method within the class below.


# --- Initializer Class ---
class RoleAgentInitializer:
    def __init__(self):
        logger.info("Initializing RoleAgent resources...")

        # Configuration from environment variables
        self.embed_model_name = os.getenv(
            "ROLE_EMBED_MODEL", "Snowflake/snowflake-arctic-embed-l-v2.0"
        )
        self.reranker_model_name = os.getenv(
            "ROLE_RERANKER_MODEL", "Alibaba-NLP/gte-multilingual-reranker-base"
        )
        self.dataset_name = os.getenv("ROLE_PROMPT_DATASET", "fka/awesome-chatgpt-prompts")
        self.llm_model_name = os.getenv("ROLE_LLM_MODEL", "gemini-2.5-pro-preview-03-25")
        self.gemini_api_key = os.getenv("GEMINI_API_KEY")

        if not self.gemini_api_key:
            logger.error("GEMINI_API_KEY not found in environment variables.")
            raise ValueError("GEMINI_API_KEY must be set")

        # Initialize models and components
        try:
            logger.info(f"Loading embed model: {self.embed_model_name}")
            self.embed_model = HuggingFaceEmbedding(model_name=self.embed_model_name)

            logger.info(f"Loading reranker model: {self.reranker_model_name}")
            self.reranker = SentenceTransformerRerank(
                model=self.reranker_model_name, top_n=3
            )

            # Load the dataset
            logger.info(f"Loading dataset: {self.dataset_name}")
            prompts_dataset = datasets.load_dataset(self.dataset_name, split="train")

            # Convert the dataset to a list of Documents
            logger.info("Converting dataset to LlamaIndex Documents...")
            documents = [
                Document(
                    text="\n".join([
                        f"Act: {prompts_dataset['act'][i]}",
                        f"Prompt: {prompts_dataset['prompt'][i]}",
                    ]),
                    metadata={"act": prompts_dataset["act"][i]},
                )
                for i in range(len(prompts_dataset))
            ]

            splitter = SentenceSplitter(chunk_size=256, chunk_overlap=20)

            logger.info("Building vector index (this may take time)...")
            index = VectorStoreIndex.from_documents(
                documents,
                embed_model=self.embed_model,
                show_progress=True,
                transformations=[splitter],
            )
            logger.info("Vector index built.")

            logger.info("Building BM25 retriever...")
            bm25_retriever = BM25Retriever.from_defaults(
                docstore=index.docstore, similarity_top_k=2
            )
            vector_retriever = index.as_retriever(similarity_top_k=2)

            logger.info("Building query fusion retriever...")
            self.retriever = QueryFusionRetriever(
                [vector_retriever, bm25_retriever],
                similarity_top_k=2,
                mode=FUSION_MODES.RECIPROCAL_RANK,
                verbose=True,
            )

            logger.info("RoleAgent resources initialized successfully.")
        except Exception as e:
            logger.error(f"Error during RoleAgent resource initialization: {e}", exc_info=True)
            raise
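
    # Retrieval design: the lexical BM25 retriever and the dense embedding retriever
    # are combined with reciprocal-rank fusion; role_prompt_retriever_method then
    # re-scores the fused candidates with the SentenceTransformerRerank reranker
    # before returning the top matches.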
    def role_prompt_retriever_method(self, query: str) -> str:
        """Instance method to retrieve and return detailed role or task information.

        Uses the retriever and reranker initialized in this class instance.

        Args:
            query (str): The user query describing the desired role, task, or prompt context.

        Returns:
            str: A string containing the assigned role/task description, or a message
                indicating no matching prompt was found.
        """
        logger.info(f"Role prompt retriever called with query: {query[:100]}...")
        try:
            results = self.retriever.retrieve(query)
            reranked_results = self.reranker.postprocess_nodes(results, query_str=query)

            if reranked_results:
                # Return top 3 results as per original logic
                top_results_text = "\n\n".join(
                    [node.get_content() for node in reranked_results[:3]]
                )
                logger.info(
                    f"Retrieved and reranked {len(reranked_results)} results. Returning top 3."
                )
                return top_results_text
            else:
                logger.warning("No matching role prompt found after reranking.")
                return "No matching role prompt found."
        except Exception as e:
            logger.error(f"Error during role prompt retrieval: {e}", exc_info=True)
            return f"Error retrieving role prompt: {e}"

    def get_agent(self) -> ReActAgent:
        """Creates and returns the configured ReActAgent for role selection."""
        logger.info("Creating RoleAgent ReActAgent instance...")

        # Create the tool, binding the method to this instance
        role_prompt_retriever_tool = FunctionTool.from_defaults(
            fn=self.role_prompt_retriever_method,  # Use the instance method
            name="role_prompt_retriever",
            description=(
                "Retrieve and summarize the top three role or task prompts for "
                "a query using BM25 and embedding retrieval with reranking."
            ),
        )

        # System prompt (consider loading from file in future)
        system_prompt = """\
You are RoleAgent, an expert context-setter that interprets user inputs and deterministically assigns the most fitting persona or task schema to guide downstream agents.

For every query:

1. **Interpret Intent**: Parse the user’s instruction to understand their goal, domain, and required expertise.
2. **Retrieve & Rank**: Use the `role_prompt_retriever` tool to fetch the top role descriptions relevant to the intent.
3. **Select Role**: Based *only* on the retrieved results, choose the single best-matching persona (e.g. “Developer Assistant,” “SEO Strategist,” “Translation Engine,” “Terminal Emulator”) without asking the user any follow-up. If no relevant role is found, state that clearly.
4. **Respond**: Output in plain text with:
   - **Role**: The selected persona (or "None Found").
   - **Reason**: Briefly explain why this role was chosen based *only* on the retrieved text.
   - **Prompt**: The corresponding role prompt from the retrieved text to be used by downstream agents (or "N/A" if none found).
5. **Hand-Off**: Immediately after including the chosen prompt (or N/A) in your response, invoke `planner_agent` to begin breaking down the user’s request into actionable sub-questions.

Always conclude your response with the full prompt for the next agent (or "N/A") and the invocation instruction for `planner_agent`.
"""
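
        # Note: the very low temperature (0.05) keeps role selection close to
        # deterministic, matching the system prompt's requirement to assign a
        # role without asking follow-up questions.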
        llm = GoogleGenAI(
            api_key=self.gemini_api_key,
            model=self.llm_model_name,
            temperature=0.05,
        )

        agent = ReActAgent(
            name="role_agent",
            description=(
                "RoleAgent selects the most appropriate persona or task template based on the user’s query. "
                "By evaluating the question’s intent and context using a specialized retriever, it chooses or "
                "refines a prompt that aligns with the best-fitting role—whether developer, analyst, translator, "
                "planner, or otherwise—so that subsequent agents can respond effectively under the optimal role context."
            ),
            tools=[role_prompt_retriever_tool],
            llm=llm,
            system_prompt=system_prompt,
            can_handoff_to=["planner_agent"],
        )
        logger.info("RoleAgent ReActAgent instance created.")
        return agent


# --- Global Initializer Instance (Singleton Pattern) ---
# The initializer is created lazily on first use (see get_initializer below),
# so expensive operations (model loading, index building) happen only once.
_role_agent_initializer_instance = None


def get_initializer():
    global _role_agent_initializer_instance
    if _role_agent_initializer_instance is None:
        logger.info("Instantiating RoleAgentInitializer for the first time.")
        _role_agent_initializer_instance = RoleAgentInitializer()
    return _role_agent_initializer_instance


# --- Public Initialization Function ---
def initialize_role_agent() -> ReActAgent:
    """Initializes and returns the Role Agent.

    Uses a singleton pattern to ensure resources are loaded only once.
    """
    logger.info("initialize_role_agent called.")
    initializer = get_initializer()
    return initializer.get_agent()


# Example usage (for testing if run directly)
if __name__ == "__main__":
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )
    logger.info("Running role_agent.py directly for testing...")

    # Ensure API key is set for testing
    if not os.getenv("GEMINI_API_KEY"):
        print("Error: GEMINI_API_KEY environment variable not set. Cannot run test.")
    else:
        try:
            test_agent = initialize_role_agent()
            print("Role Agent initialized successfully for testing.")
            # A simple test query could be added here; see the demo sketch
            # below this block for an example invocation.
        except Exception as e:
            print(f"Error during testing: {e}")
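

# --- Optional async demo (illustrative sketch) ---
# The workflow-based ReActAgent is normally driven asynchronously. This sketch
# assumes the agent exposes an awaitable `run()` accepting a `user_msg` keyword,
# as in recent llama_index agent-workflow releases; the function name and query
# text are illustrative only. Invoke it with, e.g.:
#   import asyncio; asyncio.run(demo_role_agent())
# (requires GEMINI_API_KEY to be set).
async def demo_role_agent(query: str = "Act as a Linux terminal") -> None:
    agent = initialize_role_agent()
    # Assumed workflow-agent entry point; adjust to your installed llama_index version.
    result = await agent.run(user_msg=query)
    print(result)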