Spaces:
Running
Running
import pandas as pd | |
import json | |
import os | |
import asyncio | |
import logging | |
import numpy as np | |
import textwrap | |
from datetime import datetime | |
from typing import Dict, List, Optional, Union, Any | |
import traceback | |
from pandasai import Agent, SmartDataframe | |
from pandasai.llm import GoogleGemini | |
from pandasai.responses.response_parser import ResponseParser | |
from pandasai.middlewares.base import BaseMiddleware | |
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(module)s - %(message)s')

# Optional dependency: the Google GenAI SDK. When it cannot be imported the
# module still loads, GENAI_AVAILABLE is False, and inert stand-ins named
# `genai` and `types` keep attribute lookups elsewhere in the file from failing.
try:
    from google import genai
    from google.genai import types
    from google.genai import errors
    GENAI_AVAILABLE = True
    logging.info("Google Generative AI library imported successfully.")
except ImportError:
    # `from google import genai` is shipped by the `google-genai` distribution,
    # so that is the package the hint has to name.
    logging.warning("Google Generative AI library not found. Please install it: pip install google-genai")
    GENAI_AVAILABLE = False

    # Dummy classes for graceful degradation
    class genai:
        Client = None

    class types:
        EmbedContentConfig = None
        GenerationConfig = None
        SafetySetting = None
        Candidate = type('Candidate', (), {'FinishReason': type('FinishReason', (), {'STOP': 'STOP'})})

        # NOTE(review): nested under `types` to mirror how the real SDK exposes
        # these enums (types.HarmCategory / types.HarmBlockThreshold) — confirm
        # against the original layout.
        class HarmCategory:
            HARM_CATEGORY_UNSPECIFIED = "HARM_CATEGORY_UNSPECIFIED"
            HARM_CATEGORY_HATE_SPEECH = "HARM_CATEGORY_HATE_SPEECH"
            HARM_CATEGORY_HARASSMENT = "HARM_CATEGORY_HARASSMENT"
            HARM_CATEGORY_SEXUALLY_EXPLICIT = "HARM_CATEGORY_SEXUALLY_EXPLICIT"
            HARM_CATEGORY_DANGEROUS_CONTENT = "HARM_CATEGORY_DANGEROUS_CONTENT"

        class HarmBlockThreshold:
            BLOCK_NONE = "BLOCK_NONE"
            BLOCK_LOW_AND_ABOVE = "BLOCK_LOW_AND_ABOVE"
            BLOCK_MEDIUM_AND_ABOVE = "BLOCK_MEDIUM_AND_ABOVE"
            BLOCK_ONLY_HIGH = "BLOCK_ONLY_HIGH"
# --- Custom Exceptions --- | |
class ValidationError(Exception):
    """Raised when input handed to the agent fails validation."""
class RateLimitError(Exception):
    """Placeholder exception type reserved for rate-limit failures."""
class AgentNotReadyError(Exception):
    """Raised when the agent is used before it has been properly initialized."""
# --- Configuration Constants ---
# The API key is read from the environment; an empty string means "not set".
GEMINI_API_KEY = os.environ.get('GEMINI_API_KEY', "")

# Model identifiers for text generation and for RAG embeddings.
LLM_MODEL_NAME = "gemini-2.5-flash-preview-05-20"
GEMINI_EMBEDDING_MODEL_NAME = "gemini-embedding-exp-03-07"

# Default generation parameters used when the caller supplies none.
GENERATION_CONFIG_PARAMS = dict(
    temperature=0.7,
    top_p=0.95,
    top_k=40,
    max_output_tokens=8192,
    candidate_count=1,
)

# No safety overrides by default.
DEFAULT_SAFETY_SETTINGS = []

# Default RAG documents: a one-column ('text') knowledge base used whenever
# the caller provides no documents of their own.
DEFAULT_RAG_DOCUMENTS = pd.DataFrame(
    [
        "Employer branding focuses on how an organization is perceived as an employer by potential and current employees.",
        "Key metrics for employer branding include employee engagement, candidate quality, and retention rates.",
        "LinkedIn is a crucial platform for showcasing company culture and attracting talent.",
        "Analyzing follower demographics and post engagement helps refine employer branding strategies.",
        "Content strategy should align with company values to attract the right talent.",
        "Employee advocacy programs can significantly boost employer brand reach and authenticity.",
    ],
    columns=['text'],
)
# --- Client Initialization ---
# Module-level GenAI client shared by the RAG system and the agent. It stays
# None unless both the API key and the SDK are present and construction works.
client = None
if not GEMINI_API_KEY:
    logging.warning("GEMINI_API_KEY environment variable not set.")
if not GENAI_AVAILABLE:
    logging.warning("Google GenAI library not available.")
if GEMINI_API_KEY and GENAI_AVAILABLE:
    try:
        client = genai.Client(api_key=GEMINI_API_KEY)
    except Exception as e:
        logging.error(f"Failed to initialize Google GenAI client: {e}")
        client = None
    else:
        logging.info("Google GenAI client initialized successfully.")
# --- Custom PandasAI Middleware for Better Integration --- | |
class EmployerBrandingMiddleware(BaseMiddleware):
    """PandasAI middleware that prefixes generated code with an HR-context banner."""

    def run(self, code: str, **kwargs) -> str:
        """Return ``code`` wrapped in two explanatory comment lines.

        The result starts with a newline, then the two banner comments, then
        the code itself, and ends with a trailing newline.
        """
        banner_lines = (
            "# HR Analytics Query Processing",
            "# This code analyzes your LinkedIn employer branding data",
        )
        # Empty first/last items produce the leading and trailing newline.
        return "\n".join(["", *banner_lines, f"{code}", ""])
# --- Utility function to get DataFrame schema representation --- | |
def get_df_schema_representation(df: pd.DataFrame, df_name: str) -> str:
    """Describe one DataFrame as text: shape, per-column stats and a two-row sample.

    Args:
        df: Object to describe. Anything that is not a DataFrame yields a short
            "not a DataFrame" line instead of raising.
        df_name: Label used for the DataFrame in the output.

    Returns:
        A newline-terminated, human-readable description. The system columns
        'Created Date', 'Modified Date' and '_id' are left out of the column
        list and the sample, and are only mentioned in a closing note.
    """
    if not isinstance(df, pd.DataFrame):
        return f"Item '{df_name}' is not a DataFrame.\n"
    if df.empty:
        return f"DataFrame '{df_name}': Empty\n"

    # Bookkeeping columns that carry no business meaning for the reader.
    system_columns = ('Created Date', 'Modified Date', '_id')
    business_columns = [col for col in df.columns if col not in system_columns]
    excluded_columns = [col for col in df.columns if col in system_columns]
    row_count = len(df)

    schema_parts = [
        f"DataFrame '{df_name}':",
        f" Shape: {df.shape}",
        " Columns:",
    ]
    for col in business_columns:
        series = df[col]
        try:
            unique_count = series.nunique()
        except TypeError:
            # Cells holding lists/dicts are unhashable; report that rather
            # than letting one column abort the whole schema description.
            unique_count = "n/a"
        schema_parts.append(
            f" - {col} (Type: {series.dtype}, Nulls: {series.isnull().sum()}/{row_count}, Uniques: {unique_count})"
        )

    if excluded_columns:
        schema_parts.append(f" Note: System columns excluded from display: {', '.join(excluded_columns)}")

    # df is known to be non-empty here, so only the column filter matters.
    if business_columns:
        schema_parts.append(" Sample Data (first 2 rows):")
        try:
            sample_text = df[business_columns].head(2).to_string(index=True, max_colwidth=50)
            schema_parts.append("\n".join(" " + line for line in sample_text.split('\n')))
        except Exception as e:
            schema_parts.append(f" Could not generate sample data: {e}")
    else:
        schema_parts.append(" Sample Data: Only system columns present, no business data to display")

    return "\n".join(schema_parts) + "\n"
def get_all_schemas_representation(dataframes: Dict[str, pd.DataFrame]) -> str:
    """Concatenate the schema descriptions of every DataFrame under one header.

    Returns a fixed notice when ``dataframes`` is empty or None.
    """
    if not dataframes:
        return "No DataFrames available to the agent."

    header = "=== Available DataFrame Schemas for Analysis ==="
    sections = [
        get_df_schema_representation(frame, label)
        for label, frame in dataframes.items()
    ]
    return "\n".join([header, *sections])
class AdvancedRAGSystem:
    """In-memory retrieval helper over a small DataFrame of text documents.

    Documents are embedded through the module-level GenAI ``client`` and ranked
    against a query by cosine similarity. Embedding is explicit: call
    ``initialize_embeddings()`` before ``retrieve_relevant_info()``.
    """

    def __init__(self, documents_df: pd.DataFrame, embedding_model_name: str):
        """Store a copy of the documents (or the defaults when empty).

        Args:
            documents_df: Documents to index; the text is read from its 'text' column.
            embedding_model_name: Embedding model id, with or without a "models/" prefix.
        """
        self.documents_df = documents_df.copy() if not documents_df.empty else DEFAULT_RAG_DOCUMENTS.copy()
        # Ensure 'text' column exists
        if 'text' not in self.documents_df.columns and not self.documents_df.empty:
            logging.warning("'text' column not found in RAG documents. RAG might not work.")
            self.documents_df['text'] = ""
        self.embedding_model_name = embedding_model_name
        self.embeddings: Optional[np.ndarray] = None
        # Row position in documents_df for each row of self.embeddings. Needed
        # because documents that are skipped or fail to embed leave no row.
        self._embedded_positions: List[int] = []
        self.is_initialized = False
        logging.info(f"AdvancedRAGSystem initialized with {len(self.documents_df)} documents. Model: {self.embedding_model_name}")

    def _embed_single_document_sync(self, text: str, task_type: str = "RETRIEVAL_DOCUMENT") -> Optional[np.ndarray]:
        """Embed one string with the configured model (blocking call).

        Args:
            text: Text to embed.
            task_type: Task type forwarded to the embedding config; defaults to
                the document-side value the original code always used.

        Returns:
            A 1-D float array, or None when the text is unusable or the
            response carries no embedding values.

        Raises:
            ConnectionError: If the module-level client is not available.
            Exception: Any error raised by the embedding call is logged and re-raised.
        """
        if not client:
            raise ConnectionError("GenAI client not initialized for RAG embedding.")
        if not text or not isinstance(text, str):
            logging.warning("Cannot embed empty or non-string text for RAG.")
            return None
        try:
            embed_config_payload = None
            if GENAI_AVAILABLE and hasattr(types, 'EmbedContentConfig'):
                embed_config_payload = types.EmbedContentConfig(task_type=task_type)

            model_path = self.embedding_model_name
            if not model_path.startswith("models/"):
                model_path = f"models/{model_path}"

            response = client.models.embed_content(
                model=model_path,
                contents=text,
                config=embed_config_payload
            )

            embeddings = getattr(response, 'embeddings', None)
            if isinstance(embeddings, list) and embeddings:
                first = embeddings[0]
                # Each entry is an embedding object whose vector lives in
                # `.values`; converting the object itself would not give a
                # numeric vector. Plain dicts and bare sequences are accepted too.
                if isinstance(first, dict):
                    values = first.get('values')
                else:
                    values = getattr(first, 'values', first)
                if values is None:
                    logging.error("Unexpected response structure")
                    return None
                return np.asarray(values, dtype=float)

            logging.error("Unexpected response structure")
            return None
        except Exception as e:
            logging.error(f"Error in _embed_single_document_sync for text '{text[:50]}...': {e}", exc_info=True)
            raise

    async def initialize_embeddings(self):
        """Embed every document and store the stacked matrix in ``self.embeddings``.

        Documents with empty/non-string text, or whose embedding fails, are
        skipped; their row positions are recorded so retrieval can still map a
        similarity score back to the right document.
        """
        self._embedded_positions = []
        if self.documents_df.empty or 'text' not in self.documents_df.columns:
            logging.warning("RAG documents DataFrame is empty or lacks 'text' column. Skipping embedding.")
            self.embeddings = np.array([])
            self.is_initialized = True
            return
        if not client and not (GENAI_AVAILABLE and os.getenv('GEMINI_API_KEY')):
            logging.error("GenAI client not available for RAG embedding initialization.")
            self.embeddings = np.array([])
            return

        logging.info(f"Starting RAG document embedding for {len(self.documents_df)} documents...")
        embedded_docs_list = []
        embedded_positions: List[int] = []
        for position, (index, row) in enumerate(self.documents_df.iterrows()):
            text_to_embed = row.get('text', '')
            if not text_to_embed or not isinstance(text_to_embed, str):
                logging.warning(f"Skipping RAG document at index {index} due to invalid/empty text.")
                continue
            try:
                # The SDK call blocks, so run it off the event loop.
                embedding_array = await asyncio.to_thread(self._embed_single_document_sync, text_to_embed)
            except Exception as e:
                logging.error(f"Error embedding RAG document at index {index}: {e}")
                continue
            if embedding_array is not None and embedding_array.size > 0:
                embedded_docs_list.append(embedding_array)
                embedded_positions.append(position)
            else:
                logging.warning(f"Empty or failed embedding for RAG document at index {index}.")

        if not embedded_docs_list:
            self.embeddings = np.array([])
            logging.warning("No RAG documents were successfully embedded.")
        else:
            try:
                first_shape = embedded_docs_list[0].shape
                if not all(emb.shape == first_shape for emb in embedded_docs_list):
                    logging.error("Inconsistent embedding shapes found. Cannot stack for RAG.")
                    self.embeddings = np.array([])
                    # As before, this path leaves is_initialized untouched.
                    return
                self.embeddings = np.vstack(embedded_docs_list)
                self._embedded_positions = embedded_positions
                logging.info(f"Successfully embedded {len(embedded_docs_list)} RAG documents. Embeddings shape: {self.embeddings.shape}")
            except ValueError as ve:
                logging.error(f"Error stacking embeddings: {ve}")
                self.embeddings = np.array([])
        self.is_initialized = True

    def _calculate_cosine_similarity(self, embeddings_matrix: np.ndarray, query_vector: np.ndarray) -> np.ndarray:
        """Return the cosine similarity of each matrix row to the query vector.

        1-D inputs are treated as a single row. Zero-norm rows (or a zero-norm
        query) yield a similarity of 0. An empty input yields an empty array.
        """
        if embeddings_matrix.ndim == 1:
            embeddings_matrix = embeddings_matrix.reshape(1, -1)
        if query_vector.ndim == 1:
            query_vector = query_vector.reshape(1, -1)
        if embeddings_matrix.size == 0 or query_vector.size == 0:
            return np.array([])

        norm_matrix = np.linalg.norm(embeddings_matrix, axis=1, keepdims=True)
        norm_query = np.linalg.norm(query_vector, axis=1, keepdims=True)
        # `where=` without `out=` leaves the masked-out entries uninitialised,
        # so divide into zero-filled buffers to make zero-norm rows come out as 0.
        normalized_matrix = np.divide(
            embeddings_matrix, norm_matrix,
            out=np.zeros_like(embeddings_matrix, dtype=float),
            where=norm_matrix != 0,
        )
        normalized_query = np.divide(
            query_vector, norm_query,
            out=np.zeros_like(query_vector, dtype=float),
            where=norm_query != 0,
        )
        return np.dot(normalized_matrix, normalized_query.T).flatten()

    async def retrieve_relevant_info(self, query: str, top_k: int = 3, min_similarity: float = 0.3) -> str:
        """Return the texts of the documents most similar to ``query``.

        Args:
            query: Search text.
            top_k: Maximum number of documents to return.
            min_similarity: Minimum cosine similarity a document must reach.

        Returns:
            The selected document texts joined by a "---" separator, best match
            first, or "" when nothing qualifies or retrieval cannot be performed.
        """
        if not self.is_initialized:
            logging.debug("RAG system not initialized. Cannot retrieve info.")
            return ""
        if self.embeddings is None or self.embeddings.size == 0:
            logging.debug("RAG embeddings not available. Cannot retrieve info.")
            return ""
        if not query or not isinstance(query, str):
            logging.debug("Empty or invalid query for RAG retrieval.")
            return ""
        if not client and not (GENAI_AVAILABLE and os.getenv('GEMINI_API_KEY')):
            logging.error("GenAI client not available for RAG query embedding.")
            return ""
        try:
            # Queries use the query-side task type; documents keep the default.
            query_vector = await asyncio.to_thread(self._embed_single_document_sync, query, "RETRIEVAL_QUERY")
            if query_vector is None or query_vector.size == 0:
                logging.warning("Query vector embedding failed or is empty for RAG.")
                return ""

            similarity_scores = self._calculate_cosine_similarity(self.embeddings, query_vector)
            if similarity_scores.size == 0:
                return ""

            relevant_indices = np.where(similarity_scores >= min_similarity)[0]
            if len(relevant_indices) == 0:
                logging.debug(f"No RAG documents met minimum similarity threshold of {min_similarity} for query: '{query[:50]}...'")
                return ""

            relevant_scores = similarity_scores[relevant_indices]
            ranked_indices = relevant_indices[np.argsort(relevant_scores)[::-1]]
            top_indices = ranked_indices[:top_k]

            context_parts = []
            if 'text' in self.documents_df.columns:
                positions = self._embedded_positions
                for emb_row in top_indices:
                    emb_row = int(emb_row)
                    # Translate the embedding row into the document row; fall
                    # back to the identity mapping when no positions were recorded.
                    doc_pos = positions[emb_row] if emb_row < len(positions) else emb_row
                    if 0 <= doc_pos < len(self.documents_df):
                        context_parts.append(self.documents_df.iloc[doc_pos]['text'])

            context = "\n\n---\n\n".join(context_parts)
            logging.debug(f"Retrieved RAG context with {len(context_parts)} documents for query: '{query[:50]}...'")
            return context
        except Exception as e:
            logging.error(f"Error during RAG retrieval for query '{query[:50]}...': {e}", exc_info=True)
            return ""
class EnhancedEmployerBrandingAgent: | |
def __init__(self,
             all_dataframes: Optional[Dict[str, pd.DataFrame]] = None,
             rag_documents_df: Optional[pd.DataFrame] = None,
             llm_model_name: str = LLM_MODEL_NAME,
             embedding_model_name: str = GEMINI_EMBEDDING_MODEL_NAME,
             generation_config_dict: Optional[Dict] = None,
             safety_settings_list: Optional[List] = None):
    """Set up the agent's data, RAG system and PandasAI agent.

    Only synchronous setup happens here; ``is_ready`` starts out False and
    embeddings are not computed by the constructor.

    Args:
        all_dataframes: Named DataFrames to analyse; each one is copied.
        rag_documents_df: Knowledge-base documents; the defaults are used when None.
        llm_model_name: Generation model id.
        embedding_model_name: Embedding model id for the RAG system.
        generation_config_dict: Generation parameters; module defaults when falsy.
        safety_settings_list: Safety settings; module defaults when falsy.
    """
    source_frames = all_dataframes or {}
    self.all_dataframes = {label: frame.copy() for label, frame in source_frames.items()}

    if rag_documents_df is None:
        knowledge_df = DEFAULT_RAG_DOCUMENTS.copy()
    else:
        knowledge_df = rag_documents_df
    self.rag_system = AdvancedRAGSystem(knowledge_df, embedding_model_name)

    self.llm_model_name = llm_model_name
    self.generation_config_dict = generation_config_dict or GENERATION_CONFIG_PARAMS
    self.safety_settings_list = safety_settings_list or DEFAULT_SAFETY_SETTINGS
    self.chat_history: List[Dict[str, str]] = []
    self.is_ready = False

    # Initialize PandasAI Agent
    self.pandas_agent = None
    self._initialize_pandas_agent()

    frame_names = list(self.all_dataframes.keys())
    logging.info(f"EnhancedEmployerBrandingAgent initialized. LLM: {self.llm_model_name}. RAG docs: {len(self.rag_system.documents_df)}. DataFrames: {frame_names}")
def _initialize_pandas_agent(self):
    """(Re)build the PandasAI agent over the current DataFrames.

    On success ``self.pandas_agent`` holds a fresh agent. In every other case
    (no DataFrames, no API key, or a construction error) it is left as None,
    so an agent built over earlier data can never survive a failed rebuild.
    """
    # Drop any previous agent first: without this, the early return below
    # would keep an agent bound to DataFrames that have since been replaced.
    self.pandas_agent = None

    if not self.all_dataframes or not GEMINI_API_KEY:
        logging.warning("Cannot initialize PandasAI agent: missing dataframes or API key")
        return

    try:
        # Convert DataFrames to SmartDataframes with descriptive names
        smart_dfs = []
        for name, df in self.all_dataframes.items():
            # Add metadata to help PandasAI understand the data better
            df_description = self._generate_dataframe_description(name, df)
            smart_dfs.append(
                SmartDataframe(
                    df,
                    name=name,
                    description=df_description
                )
            )

        # Configure PandasAI with Gemini
        # NOTE(review): some PandasAI releases name the key parameter `api_key`
        # rather than `api_token` — confirm against the pinned version.
        pandas_llm = GoogleGemini(
            api_token=GEMINI_API_KEY,
            model=self.llm_model_name,
            temperature=0.7,
            top_p=0.95,
            top_k=40,
            max_output_tokens=4096
        )

        # Create agent with enhanced configuration
        self.pandas_agent = Agent(
            dfs=smart_dfs,
            config={
                "llm": pandas_llm,
                "verbose": True,
                "enable_cache": True,
                "save_charts": True,
                "save_charts_path": "charts/",
                "custom_whitelisted_dependencies": ["matplotlib", "seaborn", "plotly"],
                "middlewares": [EmployerBrandingMiddleware()],
                "response_parser": ResponseParser,
                "max_retries": 3,
                "conversational": True
            }
        )
        logging.info(f"PandasAI agent initialized successfully with {len(smart_dfs)} DataFrames")
    except Exception as e:
        logging.error(f"Failed to initialize PandasAI agent: {e}", exc_info=True)
        self.pandas_agent = None
def _generate_dataframe_description(self, name: str, df: pd.DataFrame) -> str:
    """Build a short natural-language description of a DataFrame for PandasAI.

    Args:
        name: Dataset name; also selects an optional employer-branding blurb.
        df: The DataFrame being described.

    Returns:
        One string made of a record-count sentence, an optional "Key columns"
        list inferred from column-name keywords, and an optional
        dataset-specific sentence, joined by single spaces.
    """
    description_parts = [f"This is the '{name}' dataset containing {len(df)} records."]

    # Ordered keyword rules; the first rule that matches a column name wins.
    column_rules = (
        (('date',), "contains date/time information"),
        (('count', 'number'), "contains numerical count data"),
        (('rate', 'percentage'), "contains rate/percentage metrics"),
        (('follower',), "contains LinkedIn follower data"),
        (('engagement',), "contains engagement metrics"),
        (('post',), "contains post-related information"),
    )
    column_descriptions = []
    for col in df.columns:
        # Column labels are not guaranteed to be strings (e.g. integer labels).
        col_lower = str(col).lower()
        for keywords, meaning in column_rules:
            if any(keyword in col_lower for keyword in keywords):
                column_descriptions.append(f"'{col}' {meaning}")
                break
    if column_descriptions:
        description_parts.append("Key columns: " + "; ".join(column_descriptions))

    # Add specific context for employer branding
    dataset_key = name.lower()
    if dataset_key in ('follower_stats', 'followers'):
        description_parts.append("This data tracks LinkedIn company page follower growth and demographics for employer branding analysis.")
    elif dataset_key in ('posts', 'post_stats'):
        description_parts.append("This data contains LinkedIn post performance metrics for employer branding content analysis.")
    elif dataset_key in ('mentions', 'brand_mentions'):
        description_parts.append("This data tracks brand mentions and sentiment for employer branding reputation analysis.")

    return " ".join(description_parts)
async def initialize(self) -> bool:
    """Run the asynchronous part of agent start-up.

    Embeds the RAG documents and makes one more attempt to build the PandasAI
    agent if it is missing. Returns True only when both the RAG system and the
    PandasAI agent are ready; any exception is logged and reported as False.
    """
    try:
        if not client:
            logging.error("Cannot initialize agent: GenAI client not available/configured.")
            return False

        await self.rag_system.initialize_embeddings()

        # Verify PandasAI agent is ready, retrying the build once if not.
        if self.pandas_agent is None:
            logging.warning("PandasAI agent not initialized, attempting re-initialization")
            self._initialize_pandas_agent()
        pandas_ready = self.pandas_agent is not None

        self.is_ready = self.rag_system.is_initialized and pandas_ready
        logging.info(f"EnhancedEmployerBrandingAgent.initialize completed. RAG: {self.rag_system.is_initialized}, PandasAI: {pandas_ready}, Agent ready: {self.is_ready}")
        return self.is_ready
    except Exception as e:
        logging.error(f"Error during EnhancedEmployerBrandingAgent.initialize: {e}", exc_info=True)
        self.is_ready = False
        return False
def _get_dataframes_summary(self) -> str:
    """Return the combined schema description of the agent's DataFrames."""
    schema_overview = get_all_schemas_representation(self.all_dataframes)
    return schema_overview
def _build_system_prompt(self) -> str:
    """Return the fixed system prompt that frames the LLM as an HR analytics partner.

    The text is a constant; indentation inside the literal is removed with
    ``textwrap.dedent`` and surrounding whitespace is stripped.
    """
    prompt_template = """
        You are a friendly and insightful Employer Branding Analyst AI, working as a dedicated partner for HR professionals to make LinkedIn data analysis accessible, actionable, and easy to understand.
        ## Your Enhanced Capabilities:
        You now have advanced data analysis capabilities through PandasAI integration, allowing you to:
        - Directly query and analyze DataFrames with natural language
        - Generate charts and visualizations automatically
        - Perform complex statistical analysis on LinkedIn employer branding data
        - Handle multi-DataFrame queries and joins seamlessly
        ## Core Responsibilities:
        1. **Intelligent Data Analysis**: Use your PandasAI integration to answer data questions directly and accurately
        2. **Business Context Translation**: Convert technical analysis results into HR-friendly insights
        3. **Actionable Recommendations**: Provide specific, implementable strategies based on data findings
        4. **Educational Guidance**: Help users understand both the data insights and the LinkedIn analytics concepts
        ## Communication Style:
        - **Natural and Conversational**: Maintain a warm, supportive tone as a helpful colleague
        - **HR-Focused Language**: Avoid technical jargon; explain analytics terms in business context
        - **Context-Rich Responses**: Always explain what metrics mean for employer branding strategy
        - **Structured Insights**: Use clear formatting with headers, bullets, and logical flow
        ## Data Analysis Approach:
        When users ask data questions, you will:
        1. **Leverage PandasAI**: Use your integrated data analysis capabilities to query the data directly
        2. **Interpret Results**: Translate technical findings into business insights
        3. **Add Context**: Combine data results with your RAG knowledge base for comprehensive answers
        4. **Provide Recommendations**: Suggest specific actions based on the analysis
        ## Response Structure:
        1. **Executive Summary**: Key findings in business terms
        2. **Data Insights**: What the analysis reveals (charts/visualizations when helpful)
        3. **Business Impact**: What this means for employer branding strategy
        4. **Recommendations**: Specific, prioritized action items
        5. **Next Steps**: Follow-up suggestions or questions
        ## Key Behaviors:
        - **Data-Driven**: Always ground insights in actual data analysis when possible
        - **Visual When Helpful**: Suggest or create charts that make data more understandable
        - **Proactive**: Identify related insights the user might find valuable
        - **Honest About Limitations**: Clearly state when data doesn't support certain analyses
        Your goal remains to be a trusted partner, but now with powerful data analysis capabilities that enable deeper, more accurate insights for data-driven employer branding decisions.
    """
    cleaned_prompt = textwrap.dedent(prompt_template)
    return cleaned_prompt.strip()
def _classify_query_type(self, query: str) -> str:
    """Classify a query as "data", "advice", "hybrid" or "general".

    Keywords are matched at the start of a word (so "trend" still matches
    "trends"), not anywhere inside one. Plain substring matching treated
    "strategy" as a data request because it contains "rate", which sent pure
    advice questions through the data-analysis path.

    Args:
        query: The user's question.

    Returns:
        "hybrid" when both data and advice keywords occur, "data" or "advice"
        when only one kind occurs, otherwise "general".
    """
    import re  # local import: `re` is not among the file's top-level imports

    data_keywords = (
        'show', 'analyze', 'chart', 'graph', 'data', 'numbers', 'count', 'total',
        'average', 'trend', 'compare', 'statistics', 'performance', 'metrics',
        'followers', 'engagement', 'posts', 'growth', 'rate', 'percentage',
        # Previously matched only by accident through the "graph" substring;
        # listed explicitly so demographic questions still route to data.
        'demographic',
    )
    advice_keywords = (
        'recommend', 'suggest', 'advice', 'strategy', 'improve', 'optimize',
        'best practice', 'should', 'how to', 'what to do', 'tips',
    )

    query_lower = query.lower()

    def mentions_any(keywords) -> bool:
        # Leading word boundary only: suffixes such as plurals still match.
        return any(re.search(r"\b" + re.escape(keyword), query_lower) for keyword in keywords)

    has_data_request = mentions_any(data_keywords)
    has_advice_request = mentions_any(advice_keywords)

    if has_data_request and has_advice_request:
        return "hybrid"
    if has_data_request:
        return "data"
    if has_advice_request:
        return "advice"
    return "general"
async def _generate_pandas_response(self, query: str) -> tuple[str, bool]:
    """Answer a data question through the PandasAI agent.

    Args:
        query: Natural-language data question.

    Returns:
        ``(text, success)``. ``success`` is False when no agent is available,
        when the agent returns nothing printable, or when it raises; ``text``
        then carries a user-facing explanation.
    """
    if not self.pandas_agent:
        return "Data analysis not available - PandasAI agent not initialized.", False

    try:
        logging.info(f"Processing data query with PandasAI: {query[:100]}...")
        # chat() is a blocking call; run it in a worker thread like the other
        # blocking calls in this file so the event loop stays responsive.
        pandas_response = await asyncio.to_thread(self.pandas_agent.chat, query)

        # Judge the result by its text. Truth-testing the object itself raises
        # for DataFrames and would discard legitimate answers such as 0 or False.
        response_text = "" if pandas_response is None else str(pandas_response).strip()
        if response_text:
            return str(pandas_response), True
        return "I couldn't generate a meaningful analysis for this query.", False
    except Exception as e:
        logging.error(f"Error in PandasAI processing: {e}", exc_info=True)
        return f"I encountered an error while analyzing the data: {str(e)}", False
async def _generate_enhanced_response(self, query: str, pandas_result: str = "", query_type: str = "general") -> str:
    """Ask the LLM for the final answer, combining data results with RAG context.

    Args:
        query: The user's question.
        pandas_result: Output of the PandasAI analysis, or "" when there is none.
        query_type: Classification of the query ("data", "hybrid", "advice", "general").

    Returns:
        The model's answer, or a user-facing message when the agent is not
        ready, the service is unavailable, the prompt is blocked, nothing was
        generated, or an error occurred. This method does not raise.
    """
    if not self.is_ready:
        return "Agent is not ready. Please initialize."
    if not client:
        return "Error: AI service is not available. Check API configuration."

    try:
        system_prompt = self._build_system_prompt()
        data_summary = self._get_dataframes_summary()
        rag_context = await self.rag_system.retrieve_relevant_info(query, top_k=2, min_similarity=0.25)

        # Include the analysis whenever one exists. Gating on
        # query_type == "data" silently dropped the result for "hybrid"
        # queries even though the caller passes it in for them.
        if pandas_result:
            prompt_sections = [
                "",
                system_prompt,
                "## Data Analysis Context:",
                data_summary,
                "## PandasAI Analysis Result:",
                pandas_result,
                "## Additional Knowledge Context:",
                rag_context if rag_context else 'No additional context retrieved.',
                "## User Query:",
                query,
                "Please interpret the data analysis result above and provide business insights in a friendly, HR-focused manner.",
                "Explain what the findings mean for employer branding strategy and provide actionable recommendations.",
                "",
            ]
        else:
            prompt_sections = [
                "",
                system_prompt,
                "## Available Data Context:",
                data_summary,
                "## Knowledge Base Context:",
                rag_context if rag_context else 'No specific background information retrieved.',
                "## User Query:",
                query,
                "Please provide helpful insights and recommendations for this employer branding query.",
                "",
            ]
        enhanced_prompt = "\n".join(prompt_sections)

        logging.debug("Using genai.Client for enhanced response generation")

        # Prepare config; copy so the shared defaults are never mutated.
        config_dict = dict(self.generation_config_dict) if self.generation_config_dict else {}
        if self.safety_settings_list:
            can_wrap = GENAI_AVAILABLE and hasattr(types, 'SafetySetting')
            safety_settings = []
            for ss in self.safety_settings_list:
                if isinstance(ss, dict) and can_wrap:
                    safety_settings.append(types.SafetySetting(
                        category=ss.get('category'),
                        threshold=ss.get('threshold')
                    ))
                else:
                    safety_settings.append(ss)
            config_dict['safety_settings'] = safety_settings

        if GENAI_AVAILABLE and hasattr(types, 'GenerateContentConfig'):
            config = types.GenerateContentConfig(**config_dict)
        else:
            config = config_dict

        model_path = self.llm_model_name
        if not model_path.startswith("models/"):
            model_path = f"models/{model_path}"

        # The SDK call is synchronous, so run it in a worker thread.
        api_response = await asyncio.to_thread(
            client.models.generate_content,
            model=model_path,
            contents=enhanced_prompt,
            config=config
        )

        # Collect the text of the first candidate. Parts without text are
        # skipped (their `text` may be None, which would break the join), and
        # a content object without parts counts as "no text" instead of
        # being stringified into the user-facing answer.
        response_text = ""
        candidates = getattr(api_response, 'candidates', None)
        if candidates:
            content = getattr(candidates[0], 'content', None)
            if isinstance(content, str):
                response_text = content.strip()
            else:
                parts = getattr(content, 'parts', None) if content else None
                if parts:
                    response_text = "".join(
                        part.text for part in parts if getattr(part, 'text', None)
                    ).strip()

        if not response_text:
            # Handle blocked or empty responses
            prompt_feedback = getattr(api_response, 'prompt_feedback', None)
            block_reason = getattr(prompt_feedback, 'block_reason', None) if prompt_feedback else None
            if block_reason:
                logging.warning(f"Prompt blocked: {block_reason}")
                return "I'm sorry, your request was blocked. Please try rephrasing your query."
            return "I couldn't generate a response. Please try rephrasing your query."

        return response_text
    except Exception as e:
        error_message = str(e).lower()
        if any(keyword in error_message for keyword in ['blocked', 'safety', 'filter', 'prohibited']):
            logging.error(f"Blocked prompt: {e}")
            return "I'm sorry, your request was blocked by the safety filter. Please rephrase your query."
        logging.error(f"Error in _generate_enhanced_response: {e}", exc_info=True)
        return f"I encountered an error while processing your request: {str(e)}"
def _validate_query(self, query: str) -> bool:
    """Return True when ``query`` is a string of acceptable length.

    A query is rejected (with a warning) if it is empty, not a string, shorter
    than 3 characters after stripping whitespace, or longer than 3000
    characters in total.
    """
    long_enough = isinstance(query, str) and bool(query) and len(query.strip()) >= 3
    if not long_enough:
        logging.warning(f"Invalid query: too short or not a string. Query: '{query}'")
        return False

    query_length = len(query)
    if query_length > 3000:
        logging.warning(f"Invalid query: too long. Length: {query_length}")
        return False

    return True
async def process_query(self, user_query: str) -> str:
    """Answer a user query, using PandasAI for data and the LLM for interpretation.

    Flow:
        1. Validate the query and, if needed, try to initialize the agent.
        2. Classify the query (data / advice / hybrid / general).
        3. For data and hybrid queries, run PandasAI when an agent exists.
        4. Hand the query, plus the analysis result if it succeeded, to the
           enhanced LLM response.

    Returns the final answer or a user-facing error message; does not raise.
    """
    if not self._validate_query(user_query):
        return "Please provide a valid query (3 to 3000 characters)."

    if not self.is_ready:
        logging.warning("process_query called but agent is not ready. Attempting re-initialization.")
        if not await self.initialize():
            return "The agent is not properly initialized and could not be started. Please check configuration and logs."

    try:
        query_type = self._classify_query_type(user_query)
        logging.info(f"Query classified as: {query_type}")

        analysis_text, analysis_ok = "", False
        if query_type in ("data", "hybrid") and self.pandas_agent:
            logging.info("Attempting PandasAI analysis...")
            analysis_text, analysis_ok = await self._generate_pandas_response(user_query)
            if analysis_ok:
                logging.info("PandasAI analysis successful")
            else:
                logging.warning("PandasAI analysis failed, falling back to general response")

        # A successful analysis is forwarded; otherwise the LLM answers from
        # the general context alone (advice queries and fallbacks).
        analysis_for_llm = analysis_text if analysis_ok else ""
        return await self._generate_enhanced_response(user_query, analysis_for_llm, query_type)
    except Exception as e:
        logging.error(f"Error in process_query: {e}", exc_info=True)
        return f"I encountered an error while processing your request: {str(e)}"
def update_dataframes(self, new_dataframes: Dict[str, pd.DataFrame]):
    """Replace the agent's DataFrames with copies and rebuild the PandasAI agent.

    The RAG system is not touched here; it works from its own documents.
    """
    refreshed = {}
    for label, frame in new_dataframes.items():
        refreshed[label] = frame.copy()
    self.all_dataframes = refreshed
    logging.info(f"Agent DataFrames updated. Keys: {list(self.all_dataframes.keys())}")

    # The PandasAI agent holds the old data, so it has to be recreated.
    self._initialize_pandas_agent()
def update_rag_documents(self, new_rag_df: pd.DataFrame):
    """Replace the RAG documents and invalidate everything derived from the old ones.

    The previous embeddings describe the previous documents, so they are
    discarded and both the RAG system and the agent are marked as not ready.
    Embeddings are recomputed on the next ``initialize()`` call, which
    ``process_query`` triggers by itself when the agent is not ready.

    Args:
        new_rag_df: New knowledge-base documents; a copy is stored.
    """
    self.rag_system.documents_df = new_rag_df.copy()

    # Without this reset, retrieval would keep scoring against the old
    # embeddings and return rows of the new DataFrame that merely share a position.
    self.rag_system.embeddings = None
    self.rag_system.is_initialized = False
    self.is_ready = False

    logging.info(f"RAG documents updated. Count: {len(new_rag_df)}")
def clear_chat_history(self):
    """Reset the agent's internal chat history to an empty list.

    The attribute is rebound to a new list rather than cleared in place, so
    any outside reference to the previous history list is left untouched.
    """
    self.chat_history = list()
    logging.info("EmployerBrandingAgent internal chat history cleared.")
def get_status(self) -> Dict[str, Any]:
    """Collect a snapshot of the agent's state as a plain dictionary.

    Returns:
        A dict with readiness flags, API/client availability, RAG and
        PandasAI state, DataFrame names and counts, the configured model
        names, and the current chat-history length.
    """
    status: Dict[str, Any] = {}

    status["is_ready"] = self.is_ready
    status["has_api_key"] = bool(GEMINI_API_KEY)
    status["genai_available"] = GENAI_AVAILABLE

    # NOTE(review): `client` is not defined in this method; presumably a
    # module-level genai client -- verify it exists at module scope.
    if client:
        status["client_type"] = "genai.Client"
    else:
        status["client_type"] = "None"

    status["rag_initialized"] = self.rag_system.is_initialized
    status["pandas_agent_ready"] = self.pandas_agent is not None
    status["num_dataframes"] = len(self.all_dataframes)
    status["dataframe_keys"] = list(self.all_dataframes.keys())

    # The RAG document table may be unset; report zero documents then.
    if self.rag_system.documents_df is not None:
        status["num_rag_documents"] = len(self.rag_system.documents_df)
    else:
        status["num_rag_documents"] = 0

    status["llm_model_name"] = self.llm_model_name
    status["embedding_model_name"] = self.rag_system.embedding_model_name
    status["chat_history_length"] = len(self.chat_history)
    return status
def get_available_analyses(self) -> List[str]:
    """Return up to ten suggested analyses based on the available data.

    DataFrame names are matched case-insensitively against the keywords
    ``follower``, ``post`` and ``mention`` (first match wins) and contribute
    keyword-specific suggestions. A fixed set of general employer-branding
    questions is appended afterwards.

    Suggestions are de-duplicated (keeping first occurrence) before the
    ten-item cap is applied, so several DataFrames of the same kind do not
    repeat the name-independent suggestions or crowd out the general ones.

    Returns:
        ``["No data available for analysis"]`` when the agent holds no
        DataFrames, otherwise at most ten unique suggestion strings.
    """
    if not self.all_dataframes:
        return ["No data available for analysis"]

    suggestions = []
    # Only the names drive the suggestions; the frames themselves are unused.
    for df_name in self.all_dataframes:
        lowered_name = df_name.lower()
        if 'follower' in lowered_name:
            suggestions.extend([
                f"Show follower growth trends from {df_name}",
                f"Analyze follower demographics in {df_name}",
                "Compare follower engagement rates",
            ])
        elif 'post' in lowered_name:
            suggestions.extend([
                f"Analyze post performance metrics from {df_name}",
                "Show best performing content types",
                "Compare engagement across post categories",
            ])
        elif 'mention' in lowered_name:
            suggestions.extend([
                f"Analyze brand mention sentiment from {df_name}",
                "Show mention volume trends",
                "Identify top mention sources",
            ])

    # Add general suggestions
    suggestions.extend([
        "What are the key employer branding trends?",
        "How can I improve our LinkedIn presence?",
        "What content strategy should we adopt?",
        "How do we measure employer branding success?",
    ])

    # dict.fromkeys drops duplicates while preserving first-seen order.
    unique_suggestions = list(dict.fromkeys(suggestions))
    return unique_suggestions[:10]  # Limit to top 10 suggestions
# --- Helper Functions for External Integration --- | |
def create_agent_instance(dataframes: Optional[Dict[str, pd.DataFrame]] = None,
                          rag_docs: Optional[pd.DataFrame] = None) -> EnhancedEmployerBrandingAgent:
    """Build and return a new ``EnhancedEmployerBrandingAgent``.

    Args:
        dataframes: Optional mapping of name -> DataFrame, forwarded to the
            agent as ``all_dataframes``.
        rag_docs: Optional DataFrame, forwarded to the agent as
            ``rag_documents_df``.

    Returns:
        The newly constructed agent. This helper does not call
        ``initialize()`` on it.
    """
    logging.info("Creating new EnhancedEmployerBrandingAgent instance via helper function.")
    new_agent = EnhancedEmployerBrandingAgent(
        all_dataframes=dataframes,
        rag_documents_df=rag_docs,
    )
    return new_agent
async def initialize_agent_async(agent: EnhancedEmployerBrandingAgent) -> bool:
    """Await ``agent.initialize()`` and hand back its result.

    Args:
        agent: The agent instance to initialize.

    Returns:
        Whatever ``agent.initialize()`` resolves to (annotated as ``bool``).
    """
    logging.info("Initializing agent via async helper function.")
    init_result = await agent.initialize()
    return init_result
def validate_dataframes(dataframes: Dict[str, pd.DataFrame]) -> Dict[str, List[str]]:
    """Validate dataframes for common issues and return a validation report.

    Checks performed per entry:
      * the entry is not ``None`` (reported instead of raising),
      * the DataFrame is not empty,
      * names containing ``follower`` or ``post`` (case-insensitive, first
        match wins) have the columns expected for that kind of data,
      * no column of a non-empty DataFrame is more than 50% null.

    Args:
        dataframes: Mapping of name -> DataFrame to inspect.

    Returns:
        Mapping of each name to its list of issue descriptions; an empty
        list means no issue was detected for that DataFrame.
    """
    # (keyword looked up in the name, columns expected for that data kind)
    expected_columns = (
        ('follower', ['date', 'follower_count']),
        ('post', ['date', 'engagement']),
    )

    validation_report: Dict[str, List[str]] = {}
    for name, df in dataframes.items():
        # A missing frame is a finding, not a reason to abort the report.
        if df is None:
            validation_report[name] = ["DataFrame is None"]
            continue

        issues: List[str] = []
        if df.empty:
            issues.append("DataFrame is empty")

        # Check for required columns based on data type
        lowered_name = name.lower()
        for keyword, required_cols in expected_columns:
            if keyword in lowered_name:
                missing_cols = [col for col in required_cols if col not in df.columns]
                if missing_cols:
                    issues.append(f"Missing expected columns for {keyword} data: {missing_cols}")
                break  # first matching kind wins, as in an if/elif chain

        # Check for data quality issues
        if not df.empty:
            # Compare the exact null fraction: rounding a percentage first
            # would let values just above 50% (e.g. 50.002%) slip through.
            null_fraction = df.isnull().mean()
            high_null_cols = null_fraction[null_fraction > 0.5].index.tolist()
            if high_null_cols:
                issues.append(f"Columns with >50% null values: {high_null_cols}")

        validation_report[name] = issues
    return validation_report