File size: 10,350 Bytes
a23082c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114747f
a23082c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114747f
a23082c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
import logging
import os
import sys
import textwrap

import datasets
from llama_index.core import Document, VectorStoreIndex
from llama_index.core.agent.workflow import ReActAgent
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.postprocessor import SentenceTransformerRerank
from llama_index.core.retrievers import QueryFusionRetriever
from llama_index.core.retrievers.fusion_retriever import FUSION_MODES
from llama_index.core.tools import FunctionTool
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.llms.google_genai import GoogleGenAI
from llama_index.retrievers.bm25 import BM25Retriever


# Module-level logger, named after this module so its records carry the module path.
logger = logging.getLogger(name=__name__)

# --- Tool Function ---
# The retrieval tool is not a free function. It is the instance method
# RoleAgentInitializer.role_prompt_retriever_method, so it can use the
# retriever and reranker that the instance builds. get_agent() wraps the
# bound method in a FunctionTool named "role_prompt_retriever".

# --- Initializer Class --- 
class RoleAgentInitializer:
    """Build and hold the expensive resources behind the RoleAgent.

    ``__init__`` does the following:
    - loads the embedding and reranker models;
    - converts the prompt dataset into LlamaIndex ``Document`` objects;
    - builds a vector index plus a BM25 retriever over those documents;
    - fuses the two retrievers with reciprocal-rank fusion.

    ``get_agent`` wraps ``role_prompt_retriever_method`` in a tool and returns
    a ``ReActAgent`` that uses it.

    Configuration is read from these environment variables:
    ROLE_EMBED_MODEL, ROLE_RERANKER_MODEL, ROLE_PROMPT_DATASET,
    ROLE_LLM_MODEL and GEMINI_API_KEY (required).

    Raises:
        ValueError: from ``__init__`` when GEMINI_API_KEY is not set.
    """

    # Number of candidates each first-stage retriever (and the fusion step)
    # returns. This must be >= RERANK_TOP_N, otherwise the reranker cannot
    # ever produce RERANK_TOP_N results.
    RETRIEVAL_TOP_K = 10
    # Number of prompts kept after reranking and returned by the tool.
    RERANK_TOP_N = 3
    # Sampling temperature for every Gemini LLM this class creates.
    LLM_TEMPERATURE = 0.05

    def __init__(self):
        logger.info("Initializing RoleAgent resources...")
        # Configuration from environment variables
        self.embed_model_name = os.getenv("ROLE_EMBED_MODEL", "Snowflake/snowflake-arctic-embed-l-v2.0")
        self.reranker_model_name = os.getenv("ROLE_RERANKER_MODEL", "Alibaba-NLP/gte-multilingual-reranker-base")
        self.dataset_name = os.getenv("ROLE_PROMPT_DATASET", "fka/awesome-chatgpt-prompts")
        self.llm_model_name = os.getenv("ROLE_LLM_MODEL", "gemini-2.5-pro-preview-03-25")
        self.gemini_api_key = os.getenv("GEMINI_API_KEY")

        if not self.gemini_api_key:
            logger.error("GEMINI_API_KEY not found in environment variables.")
            raise ValueError("GEMINI_API_KEY must be set")

        try:
            logger.info("Loading embed model: %s", self.embed_model_name)
            self.embed_model = HuggingFaceEmbedding(model_name=self.embed_model_name)

            logger.info("Loading reranker model: %s", self.reranker_model_name)
            self.reranker = SentenceTransformerRerank(
                model=self.reranker_model_name,
                top_n=self.RERANK_TOP_N,
            )

            documents = self._load_documents()
            self.retriever = self._build_retriever(documents)
            logger.info("RoleAgent resources initialized successfully.")
        except Exception as e:
            logger.exception("Error during RoleAgent resource initialization: %s", e)
            raise

    def _load_documents(self):
        """Load the prompt dataset and return one Document per (act, prompt) row."""
        logger.info("Loading dataset: %s", self.dataset_name)
        prompts_dataset = datasets.load_dataset(self.dataset_name, split="train")

        logger.info("Converting dataset to LlamaIndex Documents...")
        # Fetch each column once. In many `datasets` versions, `ds["col"]`
        # returns the full column, so indexing it per row inside a loop is O(n^2).
        acts = prompts_dataset["act"]
        prompts = prompts_dataset["prompt"]
        return [
            Document(
                text=f"Act: {act}\nPrompt: {prompt}",
                metadata={"act": act},
            )
            for act, prompt in zip(acts, prompts)
        ]

    def _build_retriever(self, documents):
        """Index ``documents`` and return a hybrid (vector + BM25) fusion retriever."""
        splitter = SentenceSplitter(chunk_size=256, chunk_overlap=20)

        logger.info("Building vector index (this may take time)...")
        index = VectorStoreIndex.from_documents(
            documents,
            embed_model=self.embed_model,
            show_progress=True,
            transformations=[splitter],
        )
        logger.info("Vector index built.")

        logger.info("Building BM25 retriever...")
        bm25_retriever = BM25Retriever.from_defaults(
            docstore=index.docstore,
            similarity_top_k=self.RETRIEVAL_TOP_K,
        )
        vector_retriever = index.as_retriever(similarity_top_k=self.RETRIEVAL_TOP_K)

        logger.info("Building query fusion retriever...")
        return QueryFusionRetriever(
            [vector_retriever, bm25_retriever],
            # Pass the configured Gemini model explicitly. Without `llm`, the
            # fusion retriever uses the global `Settings.llm` (OpenAI unless
            # configured elsewhere) to generate its extra queries.
            llm=self._build_llm(),
            similarity_top_k=self.RETRIEVAL_TOP_K,
            mode=FUSION_MODES.RECIPROCAL_RANK,
            verbose=True,
        )

    def _build_llm(self):
        """Return a new Gemini LLM client built from this instance's settings."""
        return GoogleGenAI(
            api_key=self.gemini_api_key,
            model=self.llm_model_name,
            temperature=self.LLM_TEMPERATURE,
        )

    def role_prompt_retriever_method(self, query: str) -> str:
        """
        Retrieve the role/task prompts that best match the query.

        The method runs the hybrid fusion retriever, reranks the candidates,
        and joins the text of up to RERANK_TOP_N nodes with blank lines.
        Errors are reported back to the agent as a string rather than raised.
        Args:
            query (str): The user query describing the desired role, task, or prompt context.
        Returns:
            str: A string containing the matching role/task prompts, or a message indicating no matching prompt was found (or that retrieval failed).
        """
        logger.info("Role prompt retriever called with query: %s...", query[:100])
        try:
            candidates = self.retriever.retrieve(query)
            reranked = self.reranker.postprocess_nodes(candidates, query_str=query)
        except Exception as e:
            logger.exception("Error during role prompt retrieval: %s", e)
            return f"Error retrieving role prompt: {e}"

        if not reranked:
            logger.warning("No matching role prompt found after reranking.")
            return "No matching role prompt found."

        top_nodes = reranked[: self.RERANK_TOP_N]
        logger.info(
            "Retrieved and reranked %d results. Returning top %d.",
            len(reranked),
            len(top_nodes),
        )
        return "\n\n".join(node.get_content() for node in top_nodes)

    def get_agent(self) -> ReActAgent:
        """Create and return the ReActAgent that selects a role for downstream agents.

        Each call builds a new agent with:
        - a single ``role_prompt_retriever`` tool bound to this instance;
        - a new Gemini LLM;
        - permission to hand off to ``planner_agent``.
        """
        logger.info("Creating RoleAgent ReActAgent instance...")

        # Bind the instance method so the tool uses this instance's retriever/reranker.
        role_prompt_retriever_tool = FunctionTool.from_defaults(
            fn=self.role_prompt_retriever_method,
            name="role_prompt_retriever",
            description=(
                f"Retrieve up to {self.RERANK_TOP_N} role or task prompts for "
                "a query using BM25 and embedding retrieval with reranking."
            ),
        )

        # Dedent so the LLM does not receive every line prefixed with the
        # method's source indentation. (Consider loading from a file in future.)
        system_prompt = textwrap.dedent("""\
        You are RoleAgent, an expert context‐setter that interprets user inputs and deterministically assigns the most fitting persona or task schema to guide downstream agents. For every query:

        1. **Interpret Intent**: Parse the user’s instruction to understand their goal, domain, and required expertise.  
        2. **Retrieve & Rank**: Use the `role_prompt_retriever` tool to fetch the top role descriptions relevant to the intent.  
        3. **Select Role**: Based *only* on the retrieved results, choose the single best‐matching persona (e.g. “Developer Assistant,” “SEO Strategist,” “Translation Engine,” “Terminal Emulator”) without asking the user any follow-up. If no relevant role is found, state that clearly.
        4. **Respond**: Output in plain text with:  
           - **Role**: The selected persona (or "None Found").  
           - **Reason**: Briefly explain why this role was chosen based *only* on the retrieved text.  
           - **Prompt**: The corresponding role prompt from the retrieved text to be used by downstream agents (or "N/A" if none found).

        5. **Hand-Off**: Immediately after including the chosen prompt (or N/A) in your response, invoke `planner_agent` to begin breaking down the user’s request into actionable sub-questions.

        Always conclude your response with the full prompt for the next agent (or "N/A") and the invocation instruction for `planner_agent`.
        """).strip()

        agent = ReActAgent(
            name="role_agent",
            description=(
                "RoleAgent selects the most appropriate persona or task template based on the user’s query. "
                "By evaluating the question’s intent and context using a specialized retriever, it chooses or refines a prompt that aligns "
                "with the best-fitting role—whether developer, analyst, translator, planner, or otherwise—so that "
                "subsequent agents can respond effectively under the optimal role context."
            ),
            tools=[role_prompt_retriever_tool],
            llm=self._build_llm(),
            system_prompt=system_prompt,
            can_handoff_to=["planner_agent"],
        )
        logger.info("RoleAgent ReActAgent instance created.")
        return agent

# --- Global Initializer Instance (Singleton Pattern) ---
# The initializer is created lazily, on the first call to get_initializer(),
# and is then reused. This means the expensive work in __init__ (model
# loading, index building) is not repeated on later calls.
# If construction raises, nothing is cached and the next call retries.
# NOTE(review): there is no lock, so concurrent first calls from several
# threads could each build an initializer.
_role_agent_initializer_instance = None


def get_initializer() -> "RoleAgentInitializer":
    """Return the shared RoleAgentInitializer, creating it on first use."""
    global _role_agent_initializer_instance
    if _role_agent_initializer_instance is None:
        logger.info("Instantiating RoleAgentInitializer for the first time.")
        _role_agent_initializer_instance = RoleAgentInitializer()
    return _role_agent_initializer_instance

# --- Public Initialization Function ---
def initialize_role_agent() -> ReActAgent:
    """Return a RoleAgent ReActAgent backed by the shared initializer.

    Models and indexes are built only by the first call (through
    get_initializer). Every call returns a newly created agent from
    ``get_agent()``.
    """
    logger.info("initialize_role_agent called.")
    return get_initializer().get_agent()

# Manual smoke test: build the RoleAgent when this file is run directly.
# Exits with status 1 on failure so scripted runs notice.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    logger.info("Running role_agent.py directly for testing...")

    # RoleAgentInitializer refuses to start without the key, so fail fast here.
    if not os.getenv("GEMINI_API_KEY"):
        print("Error: GEMINI_API_KEY environment variable not set. Cannot run test.", file=sys.stderr)
        sys.exit(1)

    try:
        initialize_role_agent()
    except Exception as e:
        print(f"Error during testing: {e}", file=sys.stderr)
        sys.exit(1)
    else:
        print("Role Agent initialized successfully for testing.")
        # NOTE(review): the workflow-based ReActAgent has no `.chat()`. Running a
        # test query needs an event loop (e.g. awaiting `agent.run(user_msg=...)`).
        # Confirm this against the installed llama_index version.