import pandas as pd
import json
import os
import asyncio
import logging
import numpy as np
import textwrap  # Used by _build_system_prompt
from datetime import datetime  # Not currently used; kept from the original module
from typing import Dict, List, Optional, Union, Any
import traceback

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(module)s - %(message)s')
try:
    from google import genai
    from google.genai import types  # Provides SafetySetting, HarmCategory, GenerateContentConfig, etc.
    from google.genai import errors
    # If GenerationConfig or EmbedContentConfig live in a different submodule in your SDK
    # version, adjust these imports. With the older google-generativeai SDK, GenerationConfig
    # is usually passed as a dict or genai.types.GenerationConfig, and EmbedContentConfig is
    # implicit in the task_type argument.
    GENAI_AVAILABLE = True
    logging.info("Google Generative AI library imported successfully.")
except ImportError:
    logging.warning("Google Generative AI library not found. Please install it: pip install google-genai")
    GENAI_AVAILABLE = False

    # Dummy classes for graceful degradation (simplified)
    class genai:
        Client = None
        # If using google-generativeai, the relevant placeholders would instead be:
        # GenerativeModel = None
        # def configure(*args, **kwargs): pass
        # def embed_content(*args, **kwargs): return {}

    class types:  # Placeholders for the types referenced elsewhere in this module
        EmbedContentConfig = None
        GenerationConfig = None
        SafetySetting = None
        Candidate = type('Candidate', (), {'FinishReason': type('FinishReason', (), {'STOP': 'STOP'})})  # Dummy FinishReason

    class HarmCategory:
        HARM_CATEGORY_UNSPECIFIED = "HARM_CATEGORY_UNSPECIFIED"
        HARM_CATEGORY_HATE_SPEECH = "HARM_CATEGORY_HATE_SPEECH"
        HARM_CATEGORY_HARASSMENT = "HARM_CATEGORY_HARASSMENT"
        HARM_CATEGORY_SEXUALLY_EXPLICIT = "HARM_CATEGORY_SEXUALLY_EXPLICIT"
        HARM_CATEGORY_DANGEROUS_CONTENT = "HARM_CATEGORY_DANGEROUS_CONTENT"

    class HarmBlockThreshold:
        BLOCK_NONE = "BLOCK_NONE"
        BLOCK_LOW_AND_ABOVE = "BLOCK_LOW_AND_ABOVE"
        BLOCK_MEDIUM_AND_ABOVE = "BLOCK_MEDIUM_AND_ABOVE"
        BLOCK_ONLY_HIGH = "BLOCK_ONLY_HIGH"

    class generation_types:  # Dummy for BlockedPromptException
        BlockedPromptException = type('BlockedPromptException', (Exception,), {})
# --- Custom Exceptions ---
class ValidationError(Exception):
    """Custom validation error for agent inputs."""
    pass


class RateLimitError(Exception):
    """Placeholder for rate limit errors (not currently raised)."""
    pass


class AgentNotReadyError(Exception):
    """Raised when the agent is used before it is properly initialized."""
    pass
# --- Configuration Constants ---
GEMINI_API_KEY = os.getenv('GEMINI_API_KEY', "")
# google-generativeai accepts the bare model name; client.models.generate_content may need
# the "models/" prefix (e.g. "models/gemini-1.5-flash-latest"), which is added where needed below.
LLM_MODEL_NAME = "gemini-1.5-flash-latest"
GEMINI_EMBEDDING_MODEL_NAME = "text-embedding-004"  # Similarly, may need the "models/" prefix

GENERATION_CONFIG_PARAMS = {
    "temperature": 0.7,
    "top_p": 0.95,
    "top_k": 40,
    "max_output_tokens": 8192,  # Ensure this is supported by the selected model
    "candidate_count": 1,
}

# Populate with entries such as
# {'category': HarmCategory.HARM_CATEGORY_X, 'threshold': HarmBlockThreshold.BLOCK_Y}
DEFAULT_SAFETY_SETTINGS = []
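
# Illustrative example only (commented out). It assumes the HarmCategory / HarmBlockThreshold
# names are available, either from the installed SDK or from the fallback classes above, and
# relies on EmployerBrandingAgent.__init__ converting these dicts to types.SafetySetting objects.
# DEFAULT_SAFETY_SETTINGS = [
#     {'category': HarmCategory.HARM_CATEGORY_HARASSMENT,
#      'threshold': HarmBlockThreshold.BLOCK_ONLY_HIGH},
#     {'category': HarmCategory.HARM_CATEGORY_HATE_SPEECH,
#      'threshold': HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE},
# ]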
# Default RAG documents
DEFAULT_RAG_DOCUMENTS = pd.DataFrame({
    'text': [
        "Employer branding focuses on how an organization is perceived as an employer by potential and current employees.",
        "Key metrics for employer branding include employee engagement, candidate quality, and retention rates.",
        "LinkedIn is a crucial platform for showcasing company culture and attracting talent.",
        "Analyzing follower demographics and post engagement helps refine employer branding strategies.",
        "Content strategy should align with company values to attract the right talent.",
        "Employee advocacy programs can significantly boost employer brand reach and authenticity."
    ]
})
# --- Client Initialization ---
client = None
if GEMINI_API_KEY and GENAI_AVAILABLE:
    try:
        # Specific to the google-genai SDK. With google-generativeai this would be
        # genai.configure(api_key=...) instead; that path is handled as a fallback in the agent.
        client = genai.Client(api_key=GEMINI_API_KEY)
        logging.info("Google GenAI client initialized successfully (using genai.Client).")
    except Exception as e:
        logging.error(f"Failed to initialize Google GenAI client (using genai.Client): {e}")
        client = None
else:
    if not GEMINI_API_KEY:
        logging.warning("GEMINI_API_KEY environment variable not set.")
    if not GENAI_AVAILABLE:
        logging.warning("Google GenAI library not available.")
# --- Utility function to get DataFrame schema representation ---
def get_df_schema_representation(df: pd.DataFrame, df_name: str) -> str:
    """Generates a string representation of a DataFrame's schema and a small sample."""
    if not isinstance(df, pd.DataFrame):
        return f"Item '{df_name}' is not a DataFrame.\n"
    if df.empty:
        return f"DataFrame '{df_name}': Empty\n"

    # System columns excluded from the schema representation
    system_columns = ['Created Date', 'Modified Date', '_id']
    filtered_columns = [col for col in df.columns if col not in system_columns]

    schema_parts = [f"DataFrame '{df_name}':"]
    schema_parts.append(f"  Shape: {df.shape}")
    schema_parts.append("  Columns:")
    for col in filtered_columns:
        col_type = str(df[col].dtype)
        null_count = df[col].isnull().sum()
        unique_count = df[col].nunique()
        schema_parts.append(f"    - {col} (Type: {col_type}, Nulls: {null_count}/{len(df)}, Uniques: {unique_count})")

    # Note any excluded system columns
    excluded_columns = [col for col in df.columns if col in system_columns]
    if excluded_columns:
        schema_parts.append(f"  Note: System columns excluded from display: {', '.join(excluded_columns)}")

    if not df.empty and filtered_columns:
        schema_parts.append("  Sample Data (first 2 rows):")
        try:
            # Build the sample from the filtered (non-system) columns only
            sample_df = df[filtered_columns].head(2)
            sample_df_str = sample_df.to_string(index=True, max_colwidth=50)
            indented_sample_df = "\n".join(["    " + line for line in sample_df_str.split('\n')])
            schema_parts.append(indented_sample_df)
        except Exception as e:
            schema_parts.append(f"    Could not generate sample data: {e}")
    elif not df.empty and not filtered_columns:
        schema_parts.append("  Sample Data: Only system columns present, no business data to display")

    return "\n".join(schema_parts) + "\n"
def get_all_schemas_representation(dataframes: Dict[str, pd.DataFrame]) -> str:
    """Generates a string representation of all DataFrame schemas."""
    if not dataframes:
        return "No DataFrames available to the agent."
    full_representation = ["=== Available DataFrame Schemas for Analysis ==="]
    for name, df_instance in dataframes.items():
        full_representation.append(get_df_schema_representation(df_instance, name))
    return "\n".join(full_representation)
class AdvancedRAGSystem:
    def __init__(self, documents_df: pd.DataFrame, embedding_model_name: str):
        self.documents_df = documents_df.copy() if not documents_df.empty else DEFAULT_RAG_DOCUMENTS.copy()
        # Ensure a 'text' column exists; RAG cannot work without it
        if 'text' not in self.documents_df.columns and not self.documents_df.empty:
            logging.warning("'text' column not found in RAG documents. RAG might not work.")
            # Create an empty text column so later code does not fail
            self.documents_df['text'] = ""
        self.embedding_model_name = embedding_model_name  # e.g. "text-embedding-004" or "models/text-embedding-004"
        self.embeddings: Optional[np.ndarray] = None
        self.is_initialized = False
        logging.info(f"AdvancedRAGSystem initialized with {len(self.documents_df)} documents. Model: {self.embedding_model_name}")
    def _embed_single_document_sync(self, text: str) -> Optional[np.ndarray]:
        if not client:
            raise ConnectionError("GenAI client not initialized for RAG embedding.")
        if not text or not isinstance(text, str):
            logging.warning("Cannot embed empty or non-string text for RAG.")
            return None
        try:
            embed_config_payload = None
            if GENAI_AVAILABLE and hasattr(types, 'EmbedContentConfig'):
                embed_config_payload = types.EmbedContentConfig(task_type="RETRIEVAL_DOCUMENT")

            response = client.models.embed_content(
                model=f"models/{self.embedding_model_name}" if not self.embedding_model_name.startswith("models/") else self.embedding_model_name,
                contents=text,
                config=embed_config_payload
            )

            # Logging kept at DEBUG level to avoid noise while still exposing the response structure
            logging.debug(f"Embedding response type: {type(response)}")

            if hasattr(response, 'embeddings') and isinstance(response.embeddings, list) and len(response.embeddings) > 0:
                embedding_obj = response.embeddings[0]
                logging.debug(f"Embedding object type: {type(embedding_obj)}")

                # The embedding vector may be exposed as '.values' or '.embedding' depending on SDK version
                if hasattr(embedding_obj, 'values'):
                    return np.array(embedding_obj.values)
                elif hasattr(embedding_obj, 'embedding'):
                    return np.array(embedding_obj.embedding)
                else:
                    available = [attr for attr in dir(embedding_obj) if not attr.startswith('_')]
                    logging.error(f"ContentEmbedding object has no 'values' or 'embedding' attribute. Available attributes: {available}")
                    return None
            else:
                logging.error("Unexpected embedding response structure: no 'embeddings' list found.")
                return None
        except Exception as e:
            logging.error(f"Error in _embed_single_document_sync for text '{text[:50]}...': {e}", exc_info=True)
            raise
    async def initialize_embeddings(self):
        if self.documents_df.empty or 'text' not in self.documents_df.columns:
            logging.warning("RAG documents DataFrame is empty or lacks 'text' column. Skipping embedding.")
            self.embeddings = np.array([])
            self.is_initialized = True  # Initialized, but with no embeddings
            return

        if not client and not (GENAI_AVAILABLE and os.getenv('GEMINI_API_KEY')):
            logging.error("GenAI client not available for RAG embedding initialization.")
            self.embeddings = np.array([])
            return

        logging.info(f"Starting RAG document embedding for {len(self.documents_df)} documents...")
        embedded_docs_list = []
        for index, row in self.documents_df.iterrows():
            text_to_embed = row.get('text', '')
            if not text_to_embed or not isinstance(text_to_embed, str):
                logging.warning(f"Skipping RAG document at index {index} due to invalid/empty text.")
                continue
            try:
                # Run the synchronous embedding call in a worker thread
                embedding_array = await asyncio.to_thread(self._embed_single_document_sync, text_to_embed)
                if embedding_array is not None and embedding_array.size > 0:
                    embedded_docs_list.append(embedding_array)
                else:
                    logging.warning(f"Empty or failed embedding for RAG document at index {index}.")
            except Exception as e:
                logging.error(f"Error embedding RAG document at index {index}: {e}")
                continue  # Continue with the remaining documents

        if not embedded_docs_list:
            self.embeddings = np.array([])
            logging.warning("No RAG documents were successfully embedded.")
        else:
            try:
                # All embeddings must share the same shape before vstack
                first_shape = embedded_docs_list[0].shape
                if not all(emb.shape == first_shape for emb in embedded_docs_list):
                    logging.error("Inconsistent embedding shapes found. Cannot stack for RAG.")
                    self.embeddings = np.array([])
                    return  # Exit without marking the system as initialized
                self.embeddings = np.vstack(embedded_docs_list)
                logging.info(f"Successfully embedded {len(embedded_docs_list)} RAG documents. Embeddings shape: {self.embeddings.shape}")
            except ValueError as ve:
                logging.error(f"Error stacking embeddings (likely due to inconsistent shapes): {ve}")
                self.embeddings = np.array([])

        self.is_initialized = True
    def _calculate_cosine_similarity(self, embeddings_matrix: np.ndarray, query_vector: np.ndarray) -> np.ndarray:
        """Returns cosine similarities between each row of embeddings_matrix and query_vector.

        For an (N, D) embeddings matrix and a (D,) or (1, D) query vector, the result is a
        length-N array of values in [-1, 1].
        """
        if embeddings_matrix.ndim == 1:  # Handle a single document embedding
            embeddings_matrix = embeddings_matrix.reshape(1, -1)
        if query_vector.ndim == 1:
            query_vector = query_vector.reshape(1, -1)
        if embeddings_matrix.size == 0 or query_vector.size == 0:
            return np.array([])

        # Normalize the rows of embeddings_matrix (small epsilon guards against division by zero)
        norm_matrix = np.linalg.norm(embeddings_matrix, axis=1, keepdims=True)
        normalized_embeddings_matrix = np.divide(embeddings_matrix, norm_matrix + 1e-8, where=norm_matrix != 0)

        # Normalize the query vector
        norm_query = np.linalg.norm(query_vector, axis=1, keepdims=True)
        normalized_query_vector = np.divide(query_vector, norm_query + 1e-8, where=norm_query != 0)

        # Cosine similarity is the dot product of the normalized vectors
        return np.dot(normalized_embeddings_matrix, normalized_query_vector.T).flatten()
    async def retrieve_relevant_info(self, query: str, top_k: int = 3, min_similarity: float = 0.3) -> str:
        if not self.is_initialized:
            logging.debug("RAG system not initialized. Cannot retrieve info.")
            return ""
        if self.embeddings is None or self.embeddings.size == 0:
            logging.debug("RAG embeddings not available. Cannot retrieve info.")
            return ""
        if not query or not isinstance(query, str):
            logging.debug("Empty or invalid query for RAG retrieval.")
            return ""
        if not client and not (GENAI_AVAILABLE and os.getenv('GEMINI_API_KEY')):
            logging.error("GenAI client not available for RAG query embedding.")
            return ""

        try:
            # Embed the query in a worker thread
            query_vector = await asyncio.to_thread(self._embed_single_document_sync, query)
            if query_vector is None or query_vector.size == 0:
                logging.warning("Query vector embedding failed or is empty for RAG.")
                return ""

            similarity_scores = self._calculate_cosine_similarity(self.embeddings, query_vector)
            if similarity_scores.size == 0:
                return ""

            relevant_indices = np.where(similarity_scores >= min_similarity)[0]
            if len(relevant_indices) == 0:
                logging.debug(f"No RAG documents met minimum similarity threshold of {min_similarity} for query: '{query[:50]}...'")
                return ""

            # Sort the relevant documents by score (descending) and keep the top_k
            relevant_scores = similarity_scores[relevant_indices]
            sorted_relevant_indices_of_original = relevant_indices[np.argsort(relevant_scores)[::-1]]
            top_indices = sorted_relevant_indices_of_original[:top_k]

            context_parts = []
            if 'text' in self.documents_df.columns:
                for i in top_indices:
                    if 0 <= i < len(self.documents_df):
                        context_parts.append(self.documents_df.iloc[i]['text'])

            context = "\n\n---\n\n".join(context_parts)
            logging.debug(f"Retrieved RAG context with {len(context_parts)} documents for query: '{query[:50]}...'")
            return context
        except Exception as e:
            logging.error(f"Error during RAG retrieval for query '{query[:50]}...': {e}", exc_info=True)
            return ""
class EmployerBrandingAgent:
    def __init__(self,
                 all_dataframes: Optional[Dict[str, pd.DataFrame]] = None,
                 rag_documents_df: Optional[pd.DataFrame] = None,
                 llm_model_name: str = LLM_MODEL_NAME,
                 embedding_model_name: str = GEMINI_EMBEDDING_MODEL_NAME,
                 generation_config_dict: Optional[Dict] = None,
                 safety_settings_list: Optional[List] = None):  # list of dicts or SafetySetting objects
        self.all_dataframes = {k: v.copy() for k, v in (all_dataframes or {}).items()}  # Defensive copy

        _rag_docs_df = rag_documents_df if rag_documents_df is not None else DEFAULT_RAG_DOCUMENTS.copy()
        self.rag_system = AdvancedRAGSystem(_rag_docs_df, embedding_model_name)

        self.llm_model_name = llm_model_name
        self.embedding_model_name = embedding_model_name  # Stored so get_status() can report it
        self.generation_config_dict = generation_config_dict or GENERATION_CONFIG_PARAMS

        # Normalize safety settings to types.SafetySetting objects where possible
        self.safety_settings_list = []
        if safety_settings_list and GENAI_AVAILABLE and hasattr(types, 'SafetySetting'):
            for ss_dict in safety_settings_list:
                try:
                    # ss_dict is expected to look like {'category': HarmCategory.X, 'threshold': HarmBlockThreshold.Y}
                    self.safety_settings_list.append(types.SafetySetting(category=ss_dict['category'], threshold=ss_dict['threshold']))
                except Exception as e:
                    logging.warning(f"Could not convert safety setting dict to SafetySetting object: {ss_dict} - {e}")
        elif safety_settings_list:  # Otherwise pass through as-is (e.g. for client.models)
            self.safety_settings_list = safety_settings_list

        self.chat_history: List[Dict[str, str]] = []  # Stores {"role": "user"/"model", "content": "..."}
        self.is_ready = False
        self.llm_model_instance = None  # Only used by the google-generativeai fallback path

        if GENAI_AVAILABLE and client is None and GEMINI_API_KEY:
            # Fallback: genai.Client failed or was unavailable; try the older
            # google-generativeai style API (genai.configure / genai.GenerativeModel).
            try:
                genai.configure(api_key=GEMINI_API_KEY)
                self.llm_model_instance = genai.GenerativeModel(self.llm_model_name)
                logging.info(f"Initialized GenerativeModel '{self.llm_model_name}' via google-generativeai.")
            except Exception as e:
                logging.error(f"Failed to initialize google-generativeai.GenerativeModel: {e}")

        logging.info(f"EmployerBrandingAgent initialized. LLM: {self.llm_model_name}. "
                     f"RAG docs: {len(self.rag_system.documents_df)}. DataFrames: {list(self.all_dataframes.keys())}")
    async def initialize(self) -> bool:
        """Initializes asynchronous components of the agent, primarily RAG embeddings."""
        try:
            if not client and not self.llm_model_instance:  # No LLM access configured at all
                logging.error("Cannot initialize agent: GenAI client (genai.Client or google-generativeai) not available/configured.")
                return False

            await self.rag_system.initialize_embeddings()  # Sets rag_system.is_initialized
            self.is_ready = self.rag_system.is_initialized  # Agent is ready if RAG is (even with zero docs)
            logging.info(f"EmployerBrandingAgent.initialize completed. RAG initialized: {self.rag_system.is_initialized}. Agent ready: {self.is_ready}")
            return True
        except Exception as e:
            logging.error(f"Error during EmployerBrandingAgent.initialize: {e}", exc_info=True)
            self.is_ready = False
            return False

    def _get_dataframes_summary(self) -> str:
        return get_all_schemas_representation(self.all_dataframes)
    def _build_system_prompt(self) -> str:
        """
        Builds the system prompt for the Employer Branding AI Agent: tailored to HR
        professionals, emphasizing natural conversation and hiding technical details.
        """
        return textwrap.dedent("""
            You are a friendly and insightful Employer Branding Analyst AI, a dedicated partner in making LinkedIn data analysis accessible, actionable, and easy to understand for HR professionals. Your role is to make data analysis feel like a helpful conversation, not a technical task.

            ## Your Core Responsibilities:
            1. **Translate Data into Business Insights**: Convert complex LinkedIn metrics into clear, actionable employer branding strategies.
            2. **Provide Context**: Always explain what metrics mean in HR terms (e.g., "An engagement rate of 5% means that for every 100 people who saw your post, 5 interacted with it. This is a good indicator of how compelling your content is.").
            3. **Offer Practical Recommendations**: Give specific, implementable actions the HR team can take.
            4. **Educate While Analyzing**: Help users understand LinkedIn analytics concepts as you provide insights, in simple terms.

            ## Communication Style:
            - **Be approachable and conversational**: Think of yourself as a helpful colleague. Your tone should be encouraging and supportive.
            - **Use HR-friendly language**: Avoid technical jargon. If an analytics term is necessary, explain it simply and immediately.
            - **Ask clarifying questions naturally**: If you need more information to fulfill a request, phrase your questions in a business context. *Absolutely do not refer to DataFrame names, column names, or other technical data structures.*
                - Instead of: "Which column has the date?"
                - Ask: "For which period are you interested in seeing this data?" or "Are you looking for trends over specific months, or for the whole year?"
                - Instead of: "Do you want 'follower_count_organic' or 'follower_count_paid'?"
                - Ask: "When you say followers, are you thinking about the growth from our regular content, from specific paid campaigns, or perhaps a combined view?"
            - **Structure responses clearly**: Use headers, bullet points, and numbered lists for easy scanning.
            - **Provide context first**: Start with what the data means in practical terms before diving into recommendations.
            - **Include confidence levels (subtly)**: When making recommendations, indicate certainty with phrases like "Based on the current data, a strong first step would be..." or "It's likely that X will improve Y, but we'd get a clearer picture with more data on Z."
            - **Offer alternatives**: Provide multiple options when possible, explaining the potential upsides or considerations for each in plain language.

            ## When Analyzing Data:
            - **Start with the "So What?"**: Always lead with the business impact or why your findings matter.
            - **Use benchmarks (if available and relevant)**: Compare performance to industry standards if you have access to such benchmarks, explaining their relevance.
            - **Identify trends**: Look for patterns over time and explain their significance for employer branding.
            - **Highlight quick wins**: Point out easy improvements alongside longer-term strategies.
            - **Consider resource constraints**: Acknowledge that HR teams often have limited time and budget when suggesting actions.

            ## When Processing Data Requests:
            - **Work entirely behind the scenes**: You will internally query and analyze the provided data. *Never show or describe any code, internal queries, or technical data processing steps to the user.* Your internal workings should be invisible.
            - **Present only the results**: Show findings, insights, and, if helpful, simple descriptions of visualizations (e.g., "We saw a steady increase in X over the last quarter.").
            - **Infer data needs from natural language**: Use the user's natural language and your understanding of HR goals to determine which data (e.g., from `follower_stats`, `posts`) and which specific fields (e.g., organic vs. paid followers, dates) are relevant for your internal analysis.
            - **Use the exact DataFrame names** (like `follower_stats`, `posts`, `post_stats`, `mentions`) from the 'Available DataFrame Schemas' section for *your internal processing only*. These names are never to be mentioned to the user.
            - **Handle data issues gracefully**: If data is missing, incomplete, or doesn't allow for a specific request, explain the limitations in business terms. For example: "I can show you the follower trends up to March 2025, as that's the latest information available," or "To look at X, I'd typically need Y type of information, which doesn't seem to be in the current data."
            - **Create understandable summaries**: Describe trends and patterns in easy-to-understand formats.
            - **Specific instructions for the `follower_stats` DataFrame (if available) - *for your internal understanding and processing only*:**
                - When the user asks about follower numbers or gains, you'll likely need `follower_stats` for your internal analysis.
                - Date information (formatted as "YYYY-MM-DD" strings) is often in the `category_name` column.
                - To get monthly follower gains, internally filter rows where `follower_count_type` is `"follower_gains_monthly"`.
                - The actual numeric follower count for that period will be in another column (e.g., 'follower_count_organic' or 'follower_count_paid').
                - *When you need to ask the user for clarification related to this data (e.g., about dates or types of followers), use general, HR-friendly questions as per the 'Communication Style' guidelines. For example, instead of mentioning `category_name` or `follower_count_type`, ask: "Are you interested in follower numbers for a specific month, or the overall trend for the year?" or "Are we looking at followers gained from our day-to-day content, or from specific promotional activities?"*

            ## Response Structure Guidelines:
            1. **Friendly Opening & Executive Summary**: Start with a brief, friendly acknowledgement, then 2-3 key takeaways in simple terms.
            2. **Data Insights**: What the numbers tell us (with context and HR relevance).
            3. **Recommendations**: Specific actions to take, prioritized by likely impact or ease of implementation.
            4. **Next Steps / Moving Forward**: Clear, actionable follow-up suggestions, or an invitation for further questions.

            ## When You Can't Help Directly:
            - **Be transparent (but not technical)**: Clearly state what you can and cannot do based on the available data or your capabilities, without blaming the data.
            - **Offer alternatives**: Suggest related analyses you *can* perform or other ways to approach the question.
            - **Educate gently**: Explain (in simple terms) why certain analyses might require different types of information, if it helps the user understand.
            - **Guide next steps**: Help users understand how they might get the information they need if it's outside your current scope.

            ## Key Reminders:
            - **Never fabricate data** or assume information that isn't present in the provided schemas.
            - **Always validate your internal assumptions** against the available data structure.
            - **Focus on actionable insights** over merely impressive-sounding metrics.
            - **Remember your audience**: Explain concepts clearly, assuming no prior analytics expertise.
            - **Prioritize clarity and usefulness** over technical sophistication in your responses.
            - **Always prioritize a helpful, human-like interaction.**

            Your ultimate goal is to be a trusted partner, empowering HR professionals to confidently make data-driven employer branding decisions by providing clear, friendly, and actionable insights, regardless of their technical background.
        """).strip()
    async def _generate_response(self, current_user_query: str) -> str:
        """
        Generates a response from the LLM based on the current query, system prompt,
        data summaries, RAG context, and the agent's chat history.

        Assumes self.chat_history has already been populated by app.py and includes
        current_user_query as the last entry.
        """
        if not self.is_ready:
            return "Agent is not ready. Please initialize."
        if not client and not self.llm_model_instance:
            return "Error: AI service is not available. Check API configuration."

        try:
            system_prompt_text = self._build_system_prompt()
            data_summary_text = self._get_dataframes_summary()
            rag_context_text = await self.rag_system.retrieve_relevant_info(current_user_query, top_k=2, min_similarity=0.25)

            # Construct the messages for the LLM API call
            llm_messages = []

            # 1. System-level instructions and context (sent as a first "user" turn)
            initial_context_prompt = (
                f"{system_prompt_text}\n\n"
                f"## Available Data Overview:\n{data_summary_text}\n\n"
                f"## Relevant Background Information (if any):\n{rag_context_text if rag_context_text else 'No specific background information retrieved for this query.'}\n\n"
                f"Given this context, please respond to the user queries that follow in the chat history."
            )
            llm_messages.append({"role": "user", "parts": [{"text": initial_context_prompt}]})

            # 2. Priming assistant message
            llm_messages.append({"role": "model", "parts": [{"text": "Understood. I have reviewed the context and data overview. I am ready to assist with your Employer Branding analysis based on our conversation."}]})

            # 3. Append the actual conversation history (already includes the current user query)
            for entry in self.chat_history:
                llm_messages.append({"role": entry["role"], "parts": [{"text": entry["content"]}]})
            # --- Make the API call ---
            response_text = ""

            if self.llm_model_instance:  # google-generativeai fallback path
                logging.debug(f"Using google-generativeai GenerativeModel.generate_content_async. History length: {len(llm_messages)}")

                # Prepare generation config and safety settings for google-generativeai
                gen_config_payload = self.generation_config_dict
                safety_settings_payload = self.safety_settings_list
                if GENAI_AVAILABLE and hasattr(types, 'GenerationConfig') and not isinstance(self.generation_config_dict, types.GenerationConfig):
                    try:
                        gen_config_payload = types.GenerationConfig(**self.generation_config_dict)
                    except Exception as e:
                        logging.warning(f"Could not convert generation_config_dict to types.GenerationConfig: {e}")

                api_response = await self.llm_model_instance.generate_content_async(
                    contents=llm_messages,
                    generation_config=gen_config_payload,
                    safety_settings=safety_settings_payload
                )
                response_text = api_response.text

            elif client:  # google-genai client path
                logging.debug(f"Using client.models.generate_content for LLM call. History length: {len(llm_messages)}")

                # client.models.generate_content is called with a simplified contents format:
                # a flat list of strings rather than structured role/parts messages.
                contents = []
                for msg in llm_messages:
                    if msg["role"] == "user":
                        contents.append(msg["parts"][0]["text"])
                    elif msg["role"] == "model":
                        # Model turns are folded in as plain context, prefixed for readability
                        contents.append(f"Assistant: {msg['parts'][0]['text']}")

                # Build a single config object carrying both generation parameters and safety settings
                config_dict = {}
                if self.generation_config_dict:
                    for key, value in self.generation_config_dict.items():
                        config_dict[key] = value
                if self.safety_settings_list:
                    safety_settings = []
                    for ss in self.safety_settings_list:
                        if isinstance(ss, dict):
                            # Convert dicts to types.SafetySetting
                            safety_settings.append(types.SafetySetting(
                                category=ss.get('category'),
                                threshold=ss.get('threshold')
                            ))
                        else:
                            safety_settings.append(ss)
                    config_dict['safety_settings'] = safety_settings

                config = types.GenerateContentConfig(**config_dict)

                model_path = f"models/{self.llm_model_name}" if not self.llm_model_name.startswith("models/") else self.llm_model_name
                api_response = await asyncio.to_thread(
                    client.models.generate_content,
                    model=model_path,
                    contents=contents,
                    config=config  # Single config parameter instead of separate generation_config and safety_settings
                )

                # Parse the response from the client.models structure
                if api_response.candidates and api_response.candidates[0].content and api_response.candidates[0].content.parts:
                    response_text_parts = [part.text for part in api_response.candidates[0].content.parts if hasattr(part, 'text')]
                    response_text = "".join(response_text_parts).strip()
                else:
                    # Handle blocked or empty responses
                    if hasattr(api_response, 'prompt_feedback') and api_response.prompt_feedback and api_response.prompt_feedback.block_reason:
                        logging.warning(f"Prompt blocked by client.models: {api_response.prompt_feedback.block_reason}")
                        return f"I'm sorry, your request was blocked. Reason: {api_response.prompt_feedback.block_reason_message or api_response.prompt_feedback.block_reason}"
                    if api_response.candidates and hasattr(api_response.candidates[0], 'finish_reason'):
                        finish_reason = api_response.candidates[0].finish_reason
                        if hasattr(types.Candidate, 'FinishReason') and finish_reason != types.Candidate.FinishReason.STOP:
                            logging.warning(f"Content generation stopped by client.models due to: {finish_reason}. Safety: {getattr(api_response.candidates[0], 'safety_ratings', 'N/A')}")
                            return f"I couldn't complete the response. Reason: {finish_reason}. Please try rephrasing."
                    return "I apologize, but I couldn't generate a response from client.models."
            else:
                raise ConnectionError("No valid LLM client or model instance available.")

            return response_text.strip()

        except Exception as e:
            error_message = str(e).lower()
            # Detect safety-filter blocks by inspecting the error message
            if any(keyword in error_message for keyword in ['blocked', 'safety', 'filter', 'prohibited']):
                logging.error(f"Blocked prompt from LLM: {e}", exc_info=True)
                return f"I'm sorry, your request was blocked by the safety filter. Please rephrase your query. Details: {e}"
            else:
                logging.error(f"Error in _generate_response: {e}", exc_info=True)
                return f"I encountered an error while processing your request: {type(e).__name__} - {str(e)}"
    def _validate_query(self, query: str) -> bool:
        if not query or not isinstance(query, str) or len(query.strip()) < 3:
            logging.warning(f"Invalid query: too short or not a string. Query: '{query}'")
            return False
        if len(query) > 3000:
            logging.warning(f"Invalid query: too long. Length: {len(query)}")
            return False
        return True
    async def process_query(self, user_query: str) -> str:
        """
        Processes the user's query.

        Relies on self.chat_history being set externally (by app.py) to contain the full
        conversation context, including the current user_query as the last "user" message.
        This method only calls _generate_response to get the AI's reply; it does NOT modify
        self.chat_history itself. app.py is responsible for that, based on Gradio state.
        """
        if not self._validate_query(user_query):
            return "Please provide a valid query (3 to 3000 characters)."

        if not self.is_ready:
            logging.warning("process_query called but agent is not ready. Attempting re-initialization.")
            # Fallback only; ideally initialize() is called once and confirmed by the caller.
            init_success = await self.initialize()
            if not init_success:
                return "The agent is not properly initialized and could not be started. Please check configuration and logs."

        # user_query is the current text from the input box. self.chat_history (set by app.py)
        # should already contain it as the last message; it is passed here mainly so that
        # RAG context can be retrieved for the current turn.
        response_text = await self._generate_response(user_query)
        return response_text
    def update_dataframes(self, new_dataframes: Dict[str, pd.DataFrame]):
        """Updates the agent's DataFrames. Does not automatically re-initialize RAG or the LLM."""
        self.all_dataframes = {k: v.copy() for k, v in new_dataframes.items()}  # Defensive copy
        logging.info(f"Agent DataFrames updated. Keys: {list(self.all_dataframes.keys())}")
        # Note: if RAG documents ever depend on these DataFrames, RAG would need re-initialization.
        # For now, RAG uses a static document set.

    def clear_chat_history(self):
        """Clears the agent's internal chat history. app.py should also clear its Gradio state."""
        self.chat_history = []
        logging.info("EmployerBrandingAgent internal chat history cleared.")

    def get_status(self) -> Dict[str, Any]:
        return {
            "is_ready": self.is_ready,
            "has_api_key": bool(GEMINI_API_KEY),
            "genai_available": GENAI_AVAILABLE,
            "client_type": "genai.Client" if client else ("google-generativeai" if self.llm_model_instance else "None"),
            "rag_initialized": self.rag_system.is_initialized,
            "num_dataframes": len(self.all_dataframes),
            "dataframe_keys": list(self.all_dataframes.keys()),
            "num_rag_documents": len(self.rag_system.documents_df) if self.rag_system.documents_df is not None else 0,
            "llm_model_name": self.llm_model_name,
            "embedding_model_name": self.embedding_model_name,
        }
# --- Helpers for Gradio integration (app.py handles instantiation) ---
def create_agent_instance(dataframes: Optional[Dict[str, pd.DataFrame]] = None,
                          rag_docs: Optional[pd.DataFrame] = None) -> EmployerBrandingAgent:
    logging.info("Creating new EmployerBrandingAgent instance via helper function.")
    return EmployerBrandingAgent(all_dataframes=dataframes, rag_documents_df=rag_docs)


async def initialize_agent_async(agent: EmployerBrandingAgent) -> bool:
    logging.info("Initializing agent via async helper function.")
    return await agent.initialize()
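
# Illustrative wiring from a Gradio callback (hypothetical; the real app.py may differ).
# It shows the contract documented above: the app owns chat_history and appends the
# current user turn before calling process_query.
#   agent = create_agent_instance(dataframes=my_dataframes)            # my_dataframes: app-provided dict
#   ok = await initialize_agent_async(agent)
#   agent.chat_history = previous_turns + [{"role": "user", "content": user_message}]
#   reply = await agent.process_query(user_message)
#   previous_turns += [{"role": "user", "content": user_message}, {"role": "model", "content": reply}]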
if __name__ == "__main__":
    async def test_agent_logic():
        print("--- Testing Employer Branding Agent ---")
        if not GEMINI_API_KEY:
            print("GEMINI_API_KEY not set. Skipping live API tests.")
            return

        sample_dfs = {
            "followers": pd.DataFrame({'date': pd.to_datetime(['2023-01-01']), 'count': [100]}),
            "posts": pd.DataFrame({'title': ['My first post'], 'likes': [10]})
        }
        # Test custom RAG document loading
        custom_rag = pd.DataFrame({'text': ["Custom RAG context about LinkedIn engagement."]})

        agent = EmployerBrandingAgent(
            all_dataframes=sample_dfs,
            rag_documents_df=custom_rag,
            llm_model_name=LLM_MODEL_NAME,
            embedding_model_name=GEMINI_EMBEDDING_MODEL_NAME
        )
        print("Agent Status (pre-init):", agent.get_status())

        init_success = await agent.initialize()
        print(f"Agent Initialization Success: {init_success}")
        print("Agent Status (post-init):", agent.get_status())
        if not init_success:
            print("Agent initialization failed. Cannot proceed with query test.")
            return

        # Simulate app.py setting the history before each turn
        test_query1 = "What are the key columns in my followers data?"
        agent.chat_history = [{"role": "user", "content": test_query1}]
        print(f"\nProcessing Query 1: '{test_query1}'")
        response1 = await agent.process_query(user_query=test_query1)  # Current query also drives RAG retrieval
        print(f"Agent Response 1:\n{response1}")

        # Simulate app.py updating the history for the next turn
        agent.chat_history.append({"role": "model", "content": response1})
        test_query2 = "Generate pandas code to get the total follower count."
        agent.chat_history.append({"role": "user", "content": test_query2})
        print(f"\nProcessing Query 2: '{test_query2}'")
        response2 = await agent.process_query(user_query=test_query2)
        print(f"Agent Response 2:\n{response2}")
        agent.chat_history.append({"role": "model", "content": response2})

        print("\nFinal Agent Chat History (internal):")
        for item in agent.chat_history:
            print(f"- {item['role']}: {item['content'][:100]}...")
        print("\n--- Test Complete ---")

    asyncio.run(test_agent_logic())