# role_agent.py — builds the RoleAgent (role/persona selection agent) and its
# retrieval resources. (Web-viewer scrape residue removed from the top of file.)
import os
import logging
import datasets
from llama_index.core import Document, VectorStoreIndex
from llama_index.core.agent.workflow import ReActAgent
from llama_index.core.retrievers import QueryFusionRetriever
from llama_index.core.retrievers.fusion_retriever import FUSION_MODES
from llama_index.core.tools import FunctionTool
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.postprocessor import SentenceTransformerRerank
from llama_index.llms.google_genai import GoogleGenAI
from llama_index.retrievers.bm25 import BM25Retriever
# Module-level logger; configured by the application (or by the __main__
# block below when this file is run directly).
logger = logging.getLogger(__name__)
# NOTE: the role-prompt retrieval tool lives as an instance method on
# RoleAgentInitializer so it can use the retriever/reranker built in
# __init__; get_agent() wraps the bound method in a FunctionTool.
class RoleAgentInitializer:
    """Holds the expensive RoleAgent resources and builds the role-selection agent.

    On construction this loads an embedding model, a cross-encoder reranker,
    and a prompt dataset, then builds a hybrid (vector + BM25) fusion
    retriever over the prompts. ``get_agent()`` wraps everything into a
    ReActAgent that assigns a persona/role prompt for a user query.
    """

    def __init__(self):
        """Read configuration from the environment and build the retrieval stack.

        Raises:
            ValueError: if the GEMINI_API_KEY environment variable is unset.
        """
        logger.info("Initializing RoleAgent resources...")
        # Every model/dataset choice is overridable via environment variables.
        self.embed_model_name = os.getenv("ROLE_EMBED_MODEL", "Snowflake/snowflake-arctic-embed-l-v2.0")
        self.reranker_model_name = os.getenv("ROLE_RERANKER_MODEL", "Alibaba-NLP/gte-multilingual-reranker-base")
        self.dataset_name = os.getenv("ROLE_PROMPT_DATASET", "fka/awesome-chatgpt-prompts")
        self.llm_model_name = os.getenv("ROLE_LLM_MODEL", "gemini-2.5-pro-preview-03-25")
        self.gemini_api_key = os.getenv("GEMINI_API_KEY")
        if not self.gemini_api_key:
            logger.error("GEMINI_API_KEY not found in environment variables.")
            raise ValueError("GEMINI_API_KEY must be set")
        try:
            logger.info("Loading embed model: %s", self.embed_model_name)
            self.embed_model = HuggingFaceEmbedding(model_name=self.embed_model_name)

            logger.info("Loading reranker model: %s", self.reranker_model_name)
            # Reranker keeps the 3 best nodes after fusion retrieval.
            self.reranker = SentenceTransformerRerank(
                model=self.reranker_model_name,
                top_n=3,
            )

            logger.info("Loading dataset: %s", self.dataset_name)
            prompt_rows = datasets.load_dataset(self.dataset_name, split="train")

            logger.info("Converting dataset to LlamaIndex Documents...")
            # One Document per (act, prompt) pair; the act doubles as metadata.
            docs = [
                Document(
                    text=f"Act: {act}\nPrompt: {prompt}",
                    metadata={"act": act},
                )
                for act, prompt in zip(prompt_rows["act"], prompt_rows["prompt"])
            ]

            sentence_splitter = SentenceSplitter(chunk_size=256, chunk_overlap=20)
            logger.info("Building vector index (this may take time)...")
            prompt_index = VectorStoreIndex.from_documents(
                docs,
                embed_model=self.embed_model,
                show_progress=True,
                transformations=[sentence_splitter],
            )
            logger.info("Vector index built.")

            logger.info("Building BM25 retriever...")
            keyword_retriever = BM25Retriever.from_defaults(
                docstore=prompt_index.docstore,
                similarity_top_k=2,
            )
            dense_retriever = prompt_index.as_retriever(similarity_top_k=2)

            logger.info("Building query fusion retriever...")
            # Fuse dense + keyword hits via reciprocal-rank fusion.
            self.retriever = QueryFusionRetriever(
                [dense_retriever, keyword_retriever],
                similarity_top_k=2,
                mode=FUSION_MODES.RECIPROCAL_RANK,
                verbose=True,
            )
            logger.info("RoleAgent resources initialized successfully.")
        except Exception as e:
            logger.error(f"Error during RoleAgent resource initialization: {e}", exc_info=True)
            raise

    def role_prompt_retriever_method(self, query: str) -> str:
        """Retrieve, rerank, and return the best role/task prompts for a query.

        Args:
            query (str): The user query describing the desired role, task, or
                prompt context.

        Returns:
            str: Up to three reranked prompt texts joined by blank lines, a
                "not found" notice, or an error message if retrieval fails.
        """
        logger.info(f"Role prompt retriever called with query: {query[:100]}...")
        try:
            candidates = self.retriever.retrieve(query)
            ranked = self.reranker.postprocess_nodes(candidates, query_str=query)
            if not ranked:
                logger.warning("No matching role prompt found after reranking.")
                return "No matching role prompt found."
            logger.info(f"Retrieved and reranked {len(ranked)} results. Returning top 3.")
            return "\n\n".join(node.get_content() for node in ranked[:3])
        except Exception as e:
            # Tool functions report failures as text so the agent can recover.
            logger.error(f"Error during role prompt retrieval: {e}", exc_info=True)
            return f"Error retrieving role prompt: {e}"

    def get_agent(self) -> ReActAgent:
        """Build and return the ReActAgent configured for role selection."""
        logger.info("Creating RoleAgent ReActAgent instance...")
        # Expose the bound retrieval method to the agent as a tool.
        retrieval_tool = FunctionTool.from_defaults(
            fn=self.role_prompt_retriever_method,
            name="role_prompt_retriever",
            description=(
                "Retrieve and summarize the top three role or task prompts for "
                "a query using BM25 and embedding retrieval with reranking."
            ),
        )
        # System prompt (consider loading from file in future).
        system_prompt = """\
You are RoleAgent, an expert context‐setter that interprets user inputs and deterministically assigns the most fitting persona or task schema to guide downstream agents. For every query:
1. **Interpret Intent**: Parse the user’s instruction to understand their goal, domain, and required expertise.
2. **Retrieve & Rank**: Use the `role_prompt_retriever` tool to fetch the top role descriptions relevant to the intent.
3. **Select Role**: Based *only* on the retrieved results, choose the single best‐matching persona (e.g. “Developer Assistant,” “SEO Strategist,” “Translation Engine,” “Terminal Emulator”) without asking the user any follow-up. If no relevant role is found, state that clearly.
4. **Respond**: Output in plain text with:
- **Role**: The selected persona (or "None Found").
- **Reason**: Briefly explain why this role was chosen based *only* on the retrieved text.
- **Prompt**: The corresponding role prompt from the retrieved text to be used by downstream agents (or "N/A" if none found).
5. **Hand-Off**: Immediately after including the chosen prompt (or N/A) in your response, invoke `planner_agent` to begin breaking down the user’s request into actionable sub-questions.
Always conclude your response with the full prompt for the next agent (or "N/A") and the invocation instruction for `planner_agent`.
"""
        gemini_llm = GoogleGenAI(
            api_key=self.gemini_api_key,
            model=self.llm_model_name,
            temperature=0.05,  # near-deterministic role selection
        )
        react_agent = ReActAgent(
            name="role_agent",
            description=(
                "RoleAgent selects the most appropriate persona or task template based on the user’s query. "
                "By evaluating the question’s intent and context using a specialized retriever, it chooses or refines a prompt that aligns "
                "with the best-fitting role—whether developer, analyst, translator, planner, or otherwise—so that "
                "subsequent agents can respond effectively under the optimal role context."
            ),
            tools=[retrieval_tool],
            llm=gemini_llm,
            system_prompt=system_prompt,
            can_handoff_to=["planner_agent"],
        )
        logger.info("RoleAgent ReActAgent instance created.")
        return react_agent
# --- Global Initializer Instance (Singleton Pattern) ---
# The initializer is created lazily and cached at module level so that the
# expensive work (model loading, index building) happens at most once.
_role_agent_initializer_instance = None


def get_initializer():
    """Return the shared RoleAgentInitializer, creating it on first use."""
    global _role_agent_initializer_instance
    if _role_agent_initializer_instance is not None:
        return _role_agent_initializer_instance
    logger.info("Instantiating RoleAgentInitializer for the first time.")
    _role_agent_initializer_instance = RoleAgentInitializer()
    return _role_agent_initializer_instance
# --- Public Initialization Function ---
def initialize_role_agent() -> ReActAgent:
    """Initializes and returns the Role Agent.

    Uses a singleton pattern to ensure resources are loaded only once.
    """
    logger.info("initialize_role_agent called.")
    return get_initializer().get_agent()
# Smoke test when the module is executed directly.
if __name__ == "__main__":
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    )
    logger.info("Running role_agent.py directly for testing...")
    if not os.getenv("GEMINI_API_KEY"):
        # Without the API key the initializer raises, so bail out early.
        print("Error: GEMINI_API_KEY environment variable not set. Cannot run test.")
    else:
        try:
            test_agent = initialize_role_agent()
            print("Role Agent initialized successfully for testing.")
            # A chat round-trip (e.g. test_agent.chat("act as a linux terminal"))
            # could be added here for deeper testing.
        except Exception as e:
            print(f"Error during testing: {e}")