Spaces:
Running
Running
import os | |
import logging | |
import datasets | |
from llama_index.core import Document, VectorStoreIndex | |
from llama_index.core.agent.workflow import ReActAgent | |
from llama_index.core.retrievers import QueryFusionRetriever | |
from llama_index.core.retrievers.fusion_retriever import FUSION_MODES | |
from llama_index.core.tools import FunctionTool | |
from llama_index.embeddings.huggingface import HuggingFaceEmbedding | |
from llama_index.core.node_parser import SentenceSplitter | |
from llama_index.core.postprocessor import SentenceTransformerRerank | |
from llama_index.llms.google_genai import GoogleGenAI | |
from llama_index.retrievers.bm25 import BM25Retriever | |
# Setup logging | |
logger = logging.getLogger(__name__) | |
# --- Tool Function ---
# The retrieval tool is implemented as RoleAgentInitializer.role_prompt_retriever_method;
# get_agent() wraps that bound method in a FunctionTool so every tool call uses the
# retriever and reranker held by the initializer instance.
# --- Initializer Class --- | |
class RoleAgentInitializer:
    """Build and hold the resources behind RoleAgent.

    Construction loads the embedding and reranker models, downloads the
    role-prompt dataset, indexes it, and fuses a vector retriever with a BM25
    retriever. ``get_agent()`` wraps retrieval in a tool and returns a
    ReActAgent that uses it.

    Configuration (environment variables):
        ROLE_EMBED_MODEL, ROLE_RERANKER_MODEL, ROLE_PROMPT_DATASET,
        ROLE_LLM_MODEL: model / dataset identifiers (defaults in ``__init__``).
        ROLE_RERANK_TOP_N: prompts kept after reranking (default 3).
        ROLE_RETRIEVAL_TOP_K: candidates fetched before reranking (default 5,
            raised to at least ROLE_RERANK_TOP_N).
        GEMINI_API_KEY: required.
    """

    def __init__(self):
        """Load models and data and build the retrieval pipeline and Gemini LLM.

        Raises:
            ValueError: If GEMINI_API_KEY is not set, or if ROLE_RERANK_TOP_N /
                ROLE_RETRIEVAL_TOP_K are set to non-integer values.
            Exception: Any error while loading models or the dataset, or while
                building the index, is logged and re-raised.
        """
        logger.info("Initializing RoleAgent resources...")
        # Configuration from environment variables
        self.embed_model_name = os.getenv("ROLE_EMBED_MODEL", "Snowflake/snowflake-arctic-embed-l-v2.0")
        self.reranker_model_name = os.getenv("ROLE_RERANKER_MODEL", "Alibaba-NLP/gte-multilingual-reranker-base")
        self.dataset_name = os.getenv("ROLE_PROMPT_DATASET", "fka/awesome-chatgpt-prompts")
        self.llm_model_name = os.getenv("ROLE_LLM_MODEL", "gemini-2.5-pro-preview-03-25")
        # Number of prompts the reranker keeps and the tool returns.
        self.rerank_top_n = int(os.getenv("ROLE_RERANK_TOP_N", "3"))
        # Candidates fetched by each retriever and kept after fusion. Clamped to at
        # least rerank_top_n: a smaller pool would stop the reranker from ever
        # returning rerank_top_n prompts.
        self.retrieval_top_k = max(int(os.getenv("ROLE_RETRIEVAL_TOP_K", "5")), self.rerank_top_n)
        self.gemini_api_key = os.getenv("GEMINI_API_KEY")
        if not self.gemini_api_key:
            logger.error("GEMINI_API_KEY not found in environment variables.")
            raise ValueError("GEMINI_API_KEY must be set")

        try:
            logger.info("Loading embed model: %s", self.embed_model_name)
            self.embed_model = HuggingFaceEmbedding(model_name=self.embed_model_name)
            logger.info("Loading reranker model: %s", self.reranker_model_name)
            self.reranker = SentenceTransformerRerank(
                model=self.reranker_model_name,
                top_n=self.rerank_top_n,
            )
            # Built here (not in get_agent) so the fusion retriever can use it too.
            self.llm = self._build_llm()
            documents = self._load_documents()
            self.retriever = self._build_retriever(documents)
            logger.info("RoleAgent resources initialized successfully.")
        except Exception as e:
            logger.error("Error during RoleAgent resource initialization: %s", e, exc_info=True)
            raise

    @staticmethod
    def _format_role_entry(act, prompt) -> str:
        """Return the indexed text for one dataset row ("Act: ..." then "Prompt: ...")."""
        return "\n".join([f"Act: {act}", f"Prompt: {prompt}"])

    def _build_llm(self) -> "GoogleGenAI":
        """Create the Gemini LLM shared by the agent and the fusion retriever."""
        return GoogleGenAI(
            api_key=self.gemini_api_key,
            model=self.llm_model_name,
            temperature=0.05,
        )

    def _load_documents(self) -> "list[Document]":
        """Download the role-prompt dataset and turn each row into a Document.

        Each row's ``act`` and ``prompt`` columns form the document text; ``act``
        is also stored as metadata.
        """
        logger.info("Loading dataset: %s", self.dataset_name)
        prompts_dataset = datasets.load_dataset(self.dataset_name, split="train")
        logger.info("Converting dataset to LlamaIndex Documents...")
        # Fetch each column once. On `datasets` versions where column access
        # returns a Python list, indexing prompts_dataset["act"][i] per row
        # converts the whole column for every row (quadratic in dataset size).
        acts = prompts_dataset["act"]
        prompts = prompts_dataset["prompt"]
        return [
            Document(text=self._format_role_entry(act, prompt), metadata={"act": act})
            for act, prompt in zip(acts, prompts)
        ]

    def _build_retriever(self, documents) -> "QueryFusionRetriever":
        """Index *documents* and return a fused vector + BM25 retriever over them."""
        splitter = SentenceSplitter(chunk_size=256, chunk_overlap=20)
        logger.info("Building vector index (this may take time)...")
        index = VectorStoreIndex.from_documents(
            documents,
            embed_model=self.embed_model,
            show_progress=True,
            transformations=[splitter],
        )
        logger.info("Vector index built.")
        logger.info("Building BM25 retriever...")
        bm25_retriever = BM25Retriever.from_defaults(
            docstore=index.docstore,
            similarity_top_k=self.retrieval_top_k,
        )
        vector_retriever = index.as_retriever(similarity_top_k=self.retrieval_top_k)
        logger.info("Building query fusion retriever...")
        return QueryFusionRetriever(
            [vector_retriever, bm25_retriever],
            # Explicit LLM: the fusion retriever uses an LLM to generate query
            # variants and otherwise falls back to the global Settings.llm,
            # which this module never configures.
            llm=self.llm,
            similarity_top_k=self.retrieval_top_k,
            mode=FUSION_MODES.RECIPROCAL_RANK,
            verbose=True,
        )

    def role_prompt_retriever_method(self, query: str) -> str:
        """
        Retrieve the role or task prompts that best match the query.

        Runs the fused vector + BM25 retriever, reranks the candidates with this
        instance's reranker and returns the text of up to ``rerank_top_n`` nodes,
        separated by blank lines. Errors are reported as a message string rather
        than raised, so the calling agent can continue.
        Args:
            query (str): The user query describing the desired role, task, or prompt context.
        Returns:
            str: A string containing the assigned role/task description, or a message indicating no matching prompt was found.
        """
        logger.info("Role prompt retriever called with query: %s...", query[:100])
        try:
            results = self.retriever.retrieve(query)
            reranked_results = self.reranker.postprocess_nodes(results, query_str=query)
            if not reranked_results:
                logger.warning("No matching role prompt found after reranking.")
                return "No matching role prompt found."
            top_results = reranked_results[: self.rerank_top_n]
            logger.info(
                "Retrieved and reranked %d results. Returning top %d.",
                len(reranked_results),
                len(top_results),
            )
            return "\n\n".join(node.get_content() for node in top_results)
        except Exception as e:
            logger.error("Error during role prompt retrieval: %s", e, exc_info=True)
            return f"Error retrieving role prompt: {e}"

    def get_agent(self) -> "ReActAgent":
        """Create and return a ReActAgent configured for role selection.

        The agent's only tool, ``role_prompt_retriever``, is bound to this
        instance's ``role_prompt_retriever_method``. The agent uses the Gemini LLM
        built in ``__init__`` and may hand off to ``planner_agent``. A new agent
        object is constructed on every call.
        """
        logger.info("Creating RoleAgent ReActAgent instance...")
        # Bound method: tool calls run against this instance's retriever/reranker.
        role_prompt_retriever_tool = FunctionTool.from_defaults(
            fn=self.role_prompt_retriever_method,
            name="role_prompt_retriever",
            description=f"Retrieve the top {self.rerank_top_n} role or task prompts for "
            "a query using BM25 and embedding retrieval with reranking.",
        )
        # System prompt (consider loading from file in future)
        system_prompt = """\
You are RoleAgent, an expert context‐setter that interprets user inputs and deterministically assigns the most fitting persona or task schema to guide downstream agents. For every query:
1. **Interpret Intent**: Parse the user’s instruction to understand their goal, domain, and required expertise.
2. **Retrieve & Rank**: Use the `role_prompt_retriever` tool to fetch the top role descriptions relevant to the intent.
3. **Select Role**: Based *only* on the retrieved results, choose the single best‐matching persona (e.g. “Developer Assistant,” “SEO Strategist,” “Translation Engine,” “Terminal Emulator”) without asking the user any follow-up. If no relevant role is found, state that clearly.
4. **Respond**: Output in plain text with:
- **Role**: The selected persona (or "None Found").
- **Reason**: Briefly explain why this role was chosen based *only* on the retrieved text.
- **Prompt**: The corresponding role prompt from the retrieved text to be used by downstream agents (or "N/A" if none found).
5. **Hand-Off**: Immediately after including the chosen prompt (or N/A) in your response, invoke `planner_agent` to begin breaking down the user’s request into actionable sub-questions.
Always conclude your response with the full prompt for the next agent (or "N/A") and the invocation instruction for `planner_agent`.
"""
        agent = ReActAgent(
            name="role_agent",
            description=(
                "RoleAgent selects the most appropriate persona or task template based on the user’s query. "
                "By evaluating the question’s intent and context using a specialized retriever, it chooses or refines a prompt that aligns "
                "with the best-fitting role—whether developer, analyst, translator, planner, or otherwise—so that "
                "subsequent agents can respond effectively under the optimal role context."
            ),
            tools=[role_prompt_retriever_tool],
            llm=self.llm,
            system_prompt=system_prompt,
            can_handoff_to=["planner_agent"],
        )
        logger.info("RoleAgent ReActAgent instance created.")
        return agent
# --- Global Initializer Instance (Singleton Pattern) ---
# The initializer is created lazily on the first call to get_initializer(), not at
# import time, so the expensive work (model loading, index building) happens only
# when the agent is first requested and is then reused for the rest of the process.
_role_agent_initializer_instance = None


def get_initializer() -> "RoleAgentInitializer":
    """Return the shared RoleAgentInitializer, creating it on first use.

    If construction raises, nothing is cached, so the next call retries.
    NOTE: creation is not guarded by a lock; concurrent first calls could each
    build an initializer.

    Returns:
        RoleAgentInitializer: The cached initializer instance.
    """
    global _role_agent_initializer_instance
    if _role_agent_initializer_instance is None:
        logger.info("Instantiating RoleAgentInitializer for the first time.")
        _role_agent_initializer_instance = RoleAgentInitializer()
    return _role_agent_initializer_instance
# --- Public Initialization Function ---
def initialize_role_agent() -> ReActAgent:
    """Return a RoleAgent built from the shared, lazily created initializer.

    Heavy resources come from ``get_initializer()``, which builds them only on
    its first successful call; agent construction itself is delegated to the
    initializer's ``get_agent()``.
    """
    logger.info("initialize_role_agent called.")
    return get_initializer().get_agent()
# Example usage (for testing if run directly)
if __name__ == "__main__":
    import sys

    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    logger.info("Running role_agent.py directly for testing...")
    # Ensure API key is set for testing; exit non-zero so scripts/CI see the failure.
    if not os.getenv("GEMINI_API_KEY"):
        print("Error: GEMINI_API_KEY environment variable not set. Cannot run test.")
        sys.exit(1)
    try:
        test_agent = initialize_role_agent()
    except Exception as e:
        # Log the full traceback instead of only the message, then fail the run.
        logger.exception("Error during testing: %s", e)
        sys.exit(1)
    print("Role Agent initialized successfully for testing.")
    # NOTE(review): agents from llama_index.core.agent.workflow are driven through
    # the async run(...) API rather than .chat(...); e.g. inside an async function:
    #   result = await test_agent.run(user_msg="act as a linux terminal")
    #   print(f"Test query result: {result}")