Spaces:
Running
Running
import pandas as pd
import json
import os
import asyncio
import logging
import numpy as np
import textwrap
from datetime import datetime
from typing import Dict, List, Optional, Union, Any, Tuple
import traceback
import pandasai as pai
from pandasai_litellm import LiteLLM
# Add this early, before matplotlib.pyplot is imported directly or by pandasai
# NOTE(review): `import pandasai` above already runs before this point; if it
# imports pyplot at import time, `matplotlib.use` below switches the backend
# after the fact. Consider moving these lines above `import pandasai` — confirm
# against the installed pandasai version.
import matplotlib
matplotlib.use('Agg') # Use a non-interactive backend for Matplotlib
import matplotlib.pyplot as plt
# Configure logging
# Root logger at INFO so this module's logging.info(...) messages are emitted.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(module)s - %(message)s')
try:
    from google import genai
    from google.genai import types
    from google.genai import errors
    GENAI_AVAILABLE = True
    logging.info("Google Generative AI library imported successfully.")
except ImportError:
    # `google.genai` is shipped by the `google-genai` package; the older
    # `google-generativeai` package only provides `google.generativeai`.
    logging.warning("Google Generative AI library not found. Please install it: pip install google-genai")
    GENAI_AVAILABLE = False

    # Dummy classes for graceful degradation: they keep references to `genai`,
    # `types` and `errors` from raising NameError when the SDK is missing.
    # Code that actually calls the SDK checks GENAI_AVAILABLE first.
    class genai:
        Client = None

    class types:
        EmbedContentConfig = None
        GenerateContentConfig = None
        GenerationConfig = None
        SafetySetting = None
        Candidate = type('Candidate', (), {'FinishReason': type('FinishReason', (), {'STOP': 'STOP'})})

        # NOTE(review): nested under `types` to mirror `google.genai.types`
        # (the success path only imports `types`, never a bare HarmCategory);
        # the scraped source lost its indentation, so confirm the nesting.
        class HarmCategory:
            HARM_CATEGORY_UNSPECIFIED = "HARM_CATEGORY_UNSPECIFIED"
            HARM_CATEGORY_HATE_SPEECH = "HARM_CATEGORY_HATE_SPEECH"
            HARM_CATEGORY_HARASSMENT = "HARM_CATEGORY_HARASSMENT"
            HARM_CATEGORY_SEXUALLY_EXPLICIT = "HARM_CATEGORY_SEXUALLY_EXPLICIT"
            HARM_CATEGORY_DANGEROUS_CONTENT = "HARM_CATEGORY_DANGEROUS_CONTENT"

        class HarmBlockThreshold:
            BLOCK_NONE = "BLOCK_NONE"
            BLOCK_LOW_AND_ABOVE = "BLOCK_LOW_AND_ABOVE"
            BLOCK_MEDIUM_AND_ABOVE = "BLOCK_MEDIUM_AND_ABOVE"
            BLOCK_ONLY_HIGH = "BLOCK_ONLY_HIGH"

    class errors:
        """Stand-in for ``google.genai.errors`` so ``except errors.APIError`` still resolves."""

        class APIError(Exception):
            """Placeholder for ``google.genai.errors.APIError``."""

        class ClientError(APIError):
            """Placeholder for ``google.genai.errors.ClientError``."""

        class ServerError(APIError):
            """Placeholder for ``google.genai.errors.ServerError``."""
# --- Custom Exceptions ---
class ValidationError(Exception):
    """Raised when input handed to the agent does not pass validation."""


class RateLimitError(Exception):
    """Stand-in error type reserved for upstream rate limiting."""


class AgentNotReadyError(Exception):
    """Raised when the agent is used before its initialisation has completed."""
# --- Configuration Constants ---
# Read once at import time; an empty string means no key is configured.
GEMINI_API_KEY = os.getenv('GEMINI_API_KEY', "")
LLM_MODEL_NAME = "gemini-2.5-flash-preview-05-20"
GEMINI_EMBEDDING_MODEL_NAME = "gemini-embedding-exp-03-07"

# Default keyword arguments for the text-generation request config.
GENERATION_CONFIG_PARAMS = dict(
    temperature=0.7,
    top_p=0.95,
    top_k=40,
    max_output_tokens=8192,
    candidate_count=1,
)
DEFAULT_SAFETY_SETTINGS = []

# Default RAG documents: a small built-in employer-branding corpus, used by the
# RAG system/agent when no documents are supplied.
_DEFAULT_RAG_TEXTS = (
    "Employer branding focuses on how an organization is perceived as an employer by potential and current employees.",
    "Key metrics for employer branding include employee engagement, candidate quality, and retention rates.",
    "LinkedIn is a crucial platform for showcasing company culture and attracting talent.",
    "Analyzing follower demographics and post engagement helps refine employer branding strategies.",
    "Content strategy should align with company values to attract the right talent.",
    "Employee advocacy programs can significantly boost employer brand reach and authenticity.",
)
DEFAULT_RAG_DOCUMENTS = pd.DataFrame({'text': list(_DEFAULT_RAG_TEXTS)})
# --- Client Initialization ---
# Module-level GenAI client shared by the RAG system and the agent. It stays
# None when the API key or the library is missing, or when construction fails.
client = None
if not (GEMINI_API_KEY and GENAI_AVAILABLE):
    if not GEMINI_API_KEY:
        logging.warning("GEMINI_API_KEY environment variable not set.")
    if not GENAI_AVAILABLE:
        logging.warning("Google GenAI library not available.")
else:
    try:
        client = genai.Client(api_key=GEMINI_API_KEY)
        logging.info("Google GenAI client initialized successfully.")
    except Exception as init_error:
        logging.error(f"Failed to initialize Google GenAI client: {init_error}")
        client = None
# --- Utility function to get DataFrame schema representation ---
def _count_unique_values(series: pd.Series) -> int:
    """Return the number of distinct non-null values in *series*.

    Cells holding unhashable objects (lists, dicts) make ``Series.nunique``
    raise TypeError; such columns are counted by their string form instead.
    """
    try:
        return series.nunique()
    except TypeError:
        return series.dropna().astype(str).nunique()


def get_df_schema_representation(df: pd.DataFrame, df_name: str) -> str:
    """Describe a DataFrame's schema plus a two-row sample as plain text.

    The system columns 'Created Date', 'Modified Date' and '_id' are left out of
    the column list and the sample, and are named in a trailing note instead.

    Args:
        df: DataFrame to describe; any other value yields a one-line notice.
        df_name: Label used in the output.

    Returns:
        A newline-terminated, human-readable description.
    """
    if not isinstance(df, pd.DataFrame):
        return f"Item '{df_name}' is not a DataFrame.\n"
    if df.empty:
        return f"DataFrame '{df_name}': Empty\n"

    # Define system columns to exclude from schema representation
    system_columns = ('Created Date', 'Modified Date', '_id')
    is_system = np.array([col in system_columns for col in df.columns], dtype=bool)
    # Select by boolean mask / position rather than by label: df[label] returns
    # a DataFrame (which has no .dtype) when a column label is duplicated.
    business_df = df.loc[:, ~is_system]
    excluded_columns = [str(col) for col in df.columns[is_system]]

    schema_parts = [f"DataFrame '{df_name}':", f" Shape: {df.shape}", " Columns:"]
    for position, col in enumerate(business_df.columns):
        series = business_df.iloc[:, position]
        schema_parts.append(
            f" - {col} (Type: {series.dtype}, Nulls: {series.isnull().sum()}/{len(df)}, "
            f"Uniques: {_count_unique_values(series)})"
        )

    # Add note if system columns were excluded
    if excluded_columns:
        schema_parts.append(f" Note: System columns excluded from display: {', '.join(excluded_columns)}")

    if business_df.shape[1]:
        schema_parts.append(" Sample Data (first 2 rows):")
        try:
            sample_text = business_df.head(2).to_string(index=True, max_colwidth=50)
            schema_parts.append("\n".join(" " + line for line in sample_text.split('\n')))
        except Exception as e:
            schema_parts.append(f" Could not generate sample data: {e}")
    else:
        schema_parts.append(" Sample Data: Only system columns present, no business data to display")
    return "\n".join(schema_parts) + "\n"
def get_all_schemas_representation(dataframes: Dict[str, pd.DataFrame]) -> str:
    """Join the schema description of every named DataFrame under one header.

    Args:
        dataframes: Mapping of display name to DataFrame; may be empty or None.

    Returns:
        The combined text, or a fixed notice when there is nothing to describe.
    """
    if not dataframes:
        return "No DataFrames available to the agent."
    sections = [get_df_schema_representation(frame, label) for label, frame in dataframes.items()]
    return "\n".join(["=== Available DataFrame Schemas for Analysis ===", *sections])
class AdvancedRAGSystem:
    """In-memory retrieval over a small DataFrame of text snippets.

    Each non-empty string in the ``text`` column is embedded with the
    module-level GenAI ``client``. Queries are embedded the same way and are
    matched to documents by cosine similarity.
    """

    def __init__(self, documents_df: pd.DataFrame, embedding_model_name: str):
        """Store the corpus and model name; embeddings are built later.

        Args:
            documents_df: Documents with a ``text`` column. ``None`` or an empty
                frame falls back to ``DEFAULT_RAG_DOCUMENTS``.
            embedding_model_name: Embedding model id, with or without the
                ``models/`` prefix.
        """
        if documents_df is None or documents_df.empty:
            self.documents_df = DEFAULT_RAG_DOCUMENTS.copy()
        else:
            self.documents_df = documents_df.copy()
        # Ensure 'text' column exists
        if 'text' not in self.documents_df.columns and not self.documents_df.empty:
            logging.warning("'text' column not found in RAG documents. RAG might not work.")
            self.documents_df['text'] = ""
        self.embedding_model_name = embedding_model_name
        self.embeddings: Optional[np.ndarray] = None
        # Positional (iloc) index into documents_df for each row of
        # self.embeddings. Skipped or failed documents leave gaps, so row i of
        # the embedding matrix is not necessarily document i.
        self.embedded_doc_positions: List[int] = []
        self.is_initialized = False
        logging.info(f"AdvancedRAGSystem initialized with {len(self.documents_df)} documents. Model: {self.embedding_model_name}")

    def _embed_single_document_sync(self, text: str, task_type: str = "RETRIEVAL_DOCUMENT") -> Optional[np.ndarray]:
        """Embed one string with a blocking API call (callers use asyncio.to_thread).

        Args:
            text: Text to embed.
            task_type: Embedding task hint: "RETRIEVAL_DOCUMENT" for corpus
                entries, "RETRIEVAL_QUERY" for search queries.

        Returns:
            A float vector, or None for empty/non-string text or an
            unrecognised response shape.

        Raises:
            ConnectionError: If the module-level client is not set.
            Exception: API errors are logged and re-raised.
        """
        if not client:
            raise ConnectionError("GenAI client not initialized for RAG embedding.")
        if not text or not isinstance(text, str):
            logging.warning("Cannot embed empty or non-string text for RAG.")
            return None
        try:
            embed_config_payload = None
            if GENAI_AVAILABLE and hasattr(types, 'EmbedContentConfig'):
                embed_config_payload = types.EmbedContentConfig(task_type=task_type)
            if self.embedding_model_name.startswith("models/"):
                model_path = self.embedding_model_name
            else:
                model_path = f"models/{self.embedding_model_name}"
            response = client.models.embed_content(
                model=model_path,
                contents=text,
                config=embed_config_payload
            )
            embeddings = getattr(response, 'embeddings', None)
            if not isinstance(embeddings, list) or not embeddings:
                logging.error("Unexpected response structure")
                return None
            first = embeddings[0]
            # The SDK returns ContentEmbedding objects whose vector lives in
            # `.values`; np.array() on the object itself yields a 0-d object
            # array that can neither be stacked nor compared numerically.
            if isinstance(first, dict):
                values = first.get('values')
            else:
                values = getattr(first, 'values', first)
            if values is None:
                logging.error("Unexpected response structure")
                return None
            return np.asarray(values, dtype=float)
        except Exception as e:
            logging.error(f"Error in _embed_single_document_sync for text '{text[:50]}...': {e}", exc_info=True)
            raise

    async def initialize_embeddings(self):
        """Embed every valid document and stack the vectors into ``self.embeddings``.

        Documents with missing/non-string text, or whose embedding fails, are
        skipped; ``self.embedded_doc_positions`` records which document each
        surviving row belongs to. ``is_initialized`` becomes True except when
        no client is available or the vectors have inconsistent shapes.
        """
        self.embedded_doc_positions = []
        if self.documents_df.empty or 'text' not in self.documents_df.columns:
            logging.warning("RAG documents DataFrame is empty or lacks 'text' column. Skipping embedding.")
            self.embeddings = np.array([])
            self.is_initialized = True
            return
        if not client and not (GENAI_AVAILABLE and os.getenv('GEMINI_API_KEY')):
            logging.error("GenAI client not available for RAG embedding initialization.")
            self.embeddings = np.array([])
            return
        logging.info(f"Starting RAG document embedding for {len(self.documents_df)} documents...")
        vectors: List[np.ndarray] = []
        positions: List[int] = []
        for position, (index, text_to_embed) in enumerate(self.documents_df['text'].items()):
            # Type check first: bool(pd.NA) raises, so test truthiness only for str.
            if not isinstance(text_to_embed, str) or not text_to_embed:
                logging.warning(f"Skipping RAG document at index {index} due to invalid/empty text.")
                continue
            try:
                embedding_array = await asyncio.to_thread(self._embed_single_document_sync, text_to_embed)
            except Exception as e:
                logging.error(f"Error embedding RAG document at index {index}: {e}")
                continue
            if embedding_array is not None and embedding_array.size > 0:
                vectors.append(embedding_array)
                positions.append(position)
            else:
                logging.warning(f"Empty or failed embedding for RAG document at index {index}.")
        if not vectors:
            self.embeddings = np.array([])
            logging.warning("No RAG documents were successfully embedded.")
        else:
            first_shape = vectors[0].shape
            if any(vector.shape != first_shape for vector in vectors):
                logging.error("Inconsistent embedding shapes found. Cannot stack for RAG.")
                self.embeddings = np.array([])
                return
            try:
                self.embeddings = np.vstack(vectors)
                self.embedded_doc_positions = positions
                logging.info(f"Successfully embedded {len(vectors)} RAG documents. Embeddings shape: {self.embeddings.shape}")
            except ValueError as ve:
                logging.error(f"Error stacking embeddings: {ve}")
                self.embeddings = np.array([])
        self.is_initialized = True

    def _calculate_cosine_similarity(self, embeddings_matrix: np.ndarray, query_vector: np.ndarray) -> np.ndarray:
        """Return the cosine similarity of each matrix row with the query vector.

        Zero-length vectors get similarity 0. An empty input gives an empty array.
        """
        matrix = np.asarray(embeddings_matrix, dtype=float)
        query = np.asarray(query_vector, dtype=float)
        if matrix.ndim == 1:
            matrix = matrix.reshape(1, -1)
        if query.ndim == 1:
            query = query.reshape(1, -1)
        if matrix.size == 0 or query.size == 0:
            return np.array([])
        matrix_norms = np.linalg.norm(matrix, axis=1, keepdims=True)
        query_norms = np.linalg.norm(query, axis=1, keepdims=True)
        # np.divide with `where=` leaves masked slots untouched, so `out` must be
        # pre-filled with zeros or zero-norm rows would contain garbage.
        unit_matrix = np.divide(matrix, matrix_norms, out=np.zeros_like(matrix), where=matrix_norms != 0)
        unit_query = np.divide(query, query_norms, out=np.zeros_like(query), where=query_norms != 0)
        return (unit_matrix @ unit_query.T).ravel()

    async def retrieve_relevant_info(self, query: str, top_k: int = 3, min_similarity: float = 0.3) -> str:
        """Return up to ``top_k`` document texts most similar to ``query``.

        Only documents scoring at least ``min_similarity`` are returned, best
        first, joined by ``"\\n\\n---\\n\\n"``. Returns "" when the system is not
        ready, the query is invalid, nothing qualifies, or an error occurs.
        """
        if not self.is_initialized:
            logging.debug("RAG system not initialized. Cannot retrieve info.")
            return ""
        if self.embeddings is None or self.embeddings.size == 0:
            logging.debug("RAG embeddings not available. Cannot retrieve info.")
            return ""
        if not query or not isinstance(query, str):
            logging.debug("Empty or invalid query for RAG retrieval.")
            return ""
        if not client and not (GENAI_AVAILABLE and os.getenv('GEMINI_API_KEY')):
            logging.error("GenAI client not available for RAG query embedding.")
            return ""
        try:
            query_vector = await asyncio.to_thread(
                self._embed_single_document_sync, query, task_type="RETRIEVAL_QUERY"
            )
            if query_vector is None or query_vector.size == 0:
                logging.warning("Query vector embedding failed or is empty for RAG.")
                return ""
            similarity_scores = self._calculate_cosine_similarity(self.embeddings, query_vector)
            if similarity_scores.size == 0:
                return ""
            relevant_rows = np.flatnonzero(similarity_scores >= min_similarity)
            if relevant_rows.size == 0:
                logging.debug(f"No RAG documents met minimum similarity threshold of {min_similarity} for query: '{query[:50]}...'")
                return ""
            # Highest score first, then keep the best top_k rows.
            ranked_rows = relevant_rows[np.argsort(similarity_scores[relevant_rows])[::-1]][:top_k]
            positions = getattr(self, 'embedded_doc_positions', None) or []
            if len(positions) != len(similarity_scores):
                # Embeddings set without position bookkeeping: assume rows map 1:1.
                positions = list(range(len(similarity_scores)))
            context_parts = []
            if 'text' in self.documents_df.columns:
                for row in ranked_rows:
                    doc_position = positions[row]
                    if 0 <= doc_position < len(self.documents_df):
                        context_parts.append(self.documents_df.iloc[doc_position]['text'])
            context = "\n\n---\n\n".join(context_parts)
            logging.debug(f"Retrieved RAG context with {len(context_parts)} documents for query: '{query[:50]}...'")
            return context
        except Exception as e:
            logging.error(f"Error during RAG retrieval for query '{query[:50]}...': {e}", exc_info=True)
            return ""
| class EmployerBrandingAgent: | |
def __init__(self,
             all_dataframes: Optional[Dict[str, pd.DataFrame]] = None,
             rag_documents_df: Optional[pd.DataFrame] = None,
             llm_model_name: str = LLM_MODEL_NAME,
             embedding_model_name: str = GEMINI_EMBEDDING_MODEL_NAME,
             generation_config_dict: Optional[Dict] = None,
             safety_settings_list: Optional[List] = None):
    """Set up the agent's data, RAG system, model settings and PandasAI.

    Args:
        all_dataframes: Named DataFrames to analyse; each is copied.
        rag_documents_df: RAG corpus; ``DEFAULT_RAG_DOCUMENTS`` when None.
        llm_model_name: Generation model id.
        embedding_model_name: Embedding model id for the RAG system.
        generation_config_dict: Generation parameters; defaults to
            ``GENERATION_CONFIG_PARAMS``. Stored as a private copy.
        safety_settings_list: Safety settings; defaults to
            ``DEFAULT_SAFETY_SETTINGS``. Stored as a private copy.
    """
    self.all_dataframes = {k: v.copy() for k, v in (all_dataframes or {}).items()}
    _rag_docs_df = rag_documents_df if rag_documents_df is not None else DEFAULT_RAG_DOCUMENTS.copy()
    self.rag_system = AdvancedRAGSystem(_rag_docs_df, embedding_model_name)
    self.llm_model_name = llm_model_name
    # Copy so per-agent tweaks never leak into the shared module-level defaults.
    self.generation_config_dict = dict(generation_config_dict or GENERATION_CONFIG_PARAMS)
    self.safety_settings_list = list(safety_settings_list or DEFAULT_SAFETY_SETTINGS)
    self.chat_history: List[Dict[str, str]] = []
    self.is_ready = False
    # Create charts directory
    self.charts_dir = "./charts"
    os.makedirs(self.charts_dir, exist_ok=True)
    # Always define both attributes, even if PandasAI setup is skipped or fails.
    self.pandas_agent = None
    self.pandas_dfs = {}
    self._initialize_pandas_agent()
    logging.info(f"EnhancedEmployerBrandingAgent initialized. LLM: {self.llm_model_name}. RAG docs: {len(self.rag_system.documents_df)}. DataFrames: {list(self.all_dataframes.keys())}")
def _initialize_pandas_agent(self):
    """Configure PandasAI (through LiteLLM) and wrap each DataFrame for chat queries.

    Does nothing when there are no DataFrames or no ``GEMINI_API_KEY``. On
    success it sets ``self.pandas_agent = True`` and fills ``self.pandas_dfs``.
    On failure it logs the error and resets both attributes.
    """
    if not self.all_dataframes or not GEMINI_API_KEY:
        logging.warning("Cannot initialize PandasAI agent: missing dataframes or API key")
        return
    self._preprocess_dataframes_for_pandas_ai()
    try:
        # Reuse the agent's configured model rather than a second hard-coded
        # copy of it. LiteLLM selects the Gemini API via a "gemini/" prefix;
        # ids that already carry a provider prefix are passed through unchanged.
        model_id = self.llm_model_name
        if model_id.startswith("models/"):
            model_id = model_id[len("models/"):]
        if "/" not in model_id:
            model_id = f"gemini/{model_id}"
        llm = LiteLLM(
            model=model_id,
            api_key=GEMINI_API_KEY
        )
        generation_config = getattr(self, "generation_config_dict", None) or {}
        charts_dir = getattr(self, "charts_dir", "./charts")
        # Set PandasAI configuration
        pai.config.set({
            "llm": llm,
            "temperature": generation_config.get("temperature", 0.7),
            "verbose": True,
            "enable_cache": True,
            "save_charts": True,  # Enable chart saving
            "save_charts_path": charts_dir,  # Same directory the agent scans for charts
            "open_charts": False,  # Don't auto-open charts in browser
            "custom_whitelisted_dependencies": ["matplotlib", "seaborn", "plotly"]  # Allow plotting libraries
        })
        # Build into a local dict so a mid-loop failure leaves no partial state.
        pandas_dfs = {}
        for name, df in self.all_dataframes.items():
            df_description = self._generate_dataframe_description(name, df)
            pandas_dfs[name] = pai.DataFrame(df, description=df_description)
        self.pandas_dfs = pandas_dfs
        self.pandas_agent = True  # Flag to indicate PandasAI is ready
        logging.info(f"PandasAI initialized successfully with {len(self.pandas_dfs)} DataFrames")
    except Exception as e:
        logging.error(f"Failed to initialize PandasAI: {e}", exc_info=True)
        self.pandas_agent = None
        self.pandas_dfs = {}
def _generate_dataframe_description(self, name: str, df: pd.DataFrame) -> str:
    """Build a short natural-language description of a table for PandasAI.

    Columns are tagged by keywords in their (lower-cased) labels. The first
    matching rule wins, in this order: date, count/number, rate/percentage,
    follower, engagement, post. Known table names get an extra sentence.

    Args:
        name: Table name; compared case-insensitively with the known names.
        df: The table; only its length and column labels are used.

    Returns:
        The description sentences joined by single spaces.
    """
    description_parts = [f"This is the '{name}' dataset containing {len(df)} records."]
    # (keywords, description) pairs, checked in priority order.
    column_rules = (
        (('date',), "contains date/time information"),
        (('count', 'number'), "contains numerical count data"),
        (('rate', 'percentage'), "contains rate/percentage metrics"),
        (('follower',), "contains LinkedIn follower data"),
        (('engagement',), "contains engagement metrics"),
        (('post',), "contains post-related information"),
    )
    column_descriptions = []
    for col in df.columns:
        # str(): labels need not be strings (e.g. integer column names).
        col_lower = str(col).lower()
        for keywords, meaning in column_rules:
            if any(keyword in col_lower for keyword in keywords):
                column_descriptions.append(f"'{col}' {meaning}")
                break
    if column_descriptions:
        description_parts.append("Key columns: " + "; ".join(column_descriptions))

    # Add specific context for employer branding
    table_key = str(name).lower()
    if table_key in ('follower_stats', 'followers'):
        description_parts.append("This data tracks LinkedIn company page follower growth and demographics. For monthly growth data, use the 'extracted_date' column for date-based queries instead of trying to cast 'category_name' as a date.")
        if 'extracted_date' in df.columns:
            description_parts.append("The 'extracted_date' column contains properly formatted dates (YYYY-MM-DD) extracted from category_name for follower_gains_monthly records.")
    elif table_key in ('posts', 'post_stats'):
        description_parts.append("This data contains LinkedIn post performance metrics for employer branding content analysis.")
    elif table_key in ('mentions', 'brand_mentions'):
        description_parts.append("This data tracks brand mentions and sentiment for employer branding reputation analysis.")
    return " ".join(description_parts)
async def initialize(self) -> bool:
    """Run the agent's asynchronous start-up and report whether it is usable.

    It requires the module-level GenAI client, builds the RAG embeddings, and
    retries PandasAI setup once if it is not configured yet.

    Returns:
        ``self.is_ready``: True only when both the RAG system and PandasAI are
        ready. Any exception is logged and reported as False.
    """
    try:
        if not client:
            logging.error("Cannot initialize agent: GenAI client not available/configured.")
            return False

        await self.rag_system.initialize_embeddings()

        if self.pandas_agent is None:
            logging.warning("PandasAI agent not initialized, attempting re-initialization")
            self._initialize_pandas_agent()
        pandas_ready = self.pandas_agent is not None
        rag_ready = self.rag_system.is_initialized

        self.is_ready = rag_ready and pandas_ready
        logging.info(
            f"EnhancedEmployerBrandingAgent.initialize completed. RAG: {rag_ready}, "
            f"PandasAI: {pandas_ready}, Agent ready: {self.is_ready}"
        )
        return self.is_ready
    except Exception as e:
        logging.error(f"Error during EnhancedEmployerBrandingAgent.initialize: {e}", exc_info=True)
        self.is_ready = False
        return False
def _get_dataframes_summary(self) -> str:
    """Return the schema overview text for every DataFrame the agent holds."""
    frames = self.all_dataframes
    summary = get_all_schemas_representation(frames)
    return summary
def _preprocess_dataframes_for_pandas_ai(self):
    """Add an ``extracted_date`` column to follower-stats tables.

    Only tables named 'follower_stats' or 'followers' (case-insensitive) that
    have both 'category_name' and 'follower_count_type' are touched. For rows
    whose type is 'follower_gains_monthly' and whose category_name is exactly
    a YYYY-MM-DD string, 'extracted_date' holds that string; every other row
    gets a missing value. The augmented copy replaces the table in
    ``self.all_dataframes``; the caller's original DataFrame is not mutated.
    """
    if not self.all_dataframes:
        return
    for name, df in list(self.all_dataframes.items()):
        if name.lower() not in ('follower_stats', 'followers'):
            continue
        if 'category_name' not in df.columns or 'follower_count_type' not in df.columns:
            continue
        # Create a copy to avoid modifying original data
        df_copy = df.copy()
        category = df_copy['category_name'].astype(str)
        is_monthly = (df_copy['follower_count_type'] == 'follower_gains_monthly').to_numpy(dtype=bool)
        # fullmatch, unlike re.match with '$', rejects a trailing newline.
        is_iso_date = category.str.fullmatch(r'\d{4}-\d{2}-\d{2}').to_numpy(dtype=bool)
        # Positional np.where avoids index alignment (safe with duplicate labels).
        df_copy['extracted_date'] = np.where(
            is_monthly & is_iso_date, category.to_numpy(dtype=object), None
        )
        # Update the dataframe in our collection
        self.all_dataframes[name] = df_copy
        logging.info(f"Preprocessed {name} dataframe for date handling")
def _build_system_prompt(self) -> str:
    """Return the fixed system prompt placed at the top of every model request.

    The prompt is constant: it takes no inputs and is identical on every call.
    ``textwrap.dedent`` removes any indentation common to the literal's lines,
    and ``strip()`` drops the leading and trailing blank lines.

    Returns:
        The persona, communication rules, analysis approach and response
        structure the model is instructed to follow.
    """
    return textwrap.dedent("""
You are a friendly and insightful Employer Branding Analyst AI, working as a dedicated partner for HR professionals to make LinkedIn data analysis accessible, actionable, and easy to understand.
## Your Enhanced Capabilities:
You now have advanced data analysis capabilities through PandasAI integration, allowing you to:
- Directly query and analyze DataFrames with natural language
- Generate charts and visualizations automatically (ALWAYS create charts when data visualization would be helpful)
- Perform complex statistical analysis on LinkedIn employer branding data
- Handle multi-DataFrame queries and joins seamlessly
## Core Responsibilities:
1. **Intelligent Data Analysis**: Use your PandasAI integration to answer data questions directly and accurately
2. **Business Context Translation**: Convert technical analysis results into HR-friendly insights
3. **Actionable Recommendations**: Provide specific, implementable strategies based on data findings
4. **Educational Guidance**: Help users understand both the data insights and the LinkedIn analytics concepts
## CRITICAL COMMUNICATION RULES:
- **NEVER show code, technical commands, or programming syntax**
- **NEVER mention dataset names, column names, or technical data structure details**
- **NEVER reference DataFrames, schemas, or database terminology**
- **Always speak in business terms**: refer to "your LinkedIn data", "follower metrics", "engagement data", etc.
- **Focus on insights, not methods**: explain what the data shows, not how it was analyzed
## Communication Style:
- **Natural and Conversational**: Maintain a warm, supportive tone as a helpful colleague
- **HR-Focused Language**: Avoid technical jargon; explain analytics terms in business context
- **Context-Rich Responses**: Always explain what metrics mean for employer branding strategy
- **Structured Insights**: Use clear formatting with headers, bullets, and logical flow
## Data Analysis Approach:
When users ask data questions, you will:
1. **Leverage PandasAI**: Use your integrated data analysis capabilities to query the data directly
2. **Interpret Results**: Translate technical findings into business insights
3. **Add Context**: Combine data results with your RAG knowledge base for comprehensive answers
4. **Provide Recommendations**: Suggest specific actions based on the analysis
## Response Structure:
1. **Executive Summary**: Key findings in business terms
2. **Data Insights**: What the analysis reveals (charts/visualizations when helpful)
3. **Business Impact**: What this means for employer branding strategy
4. **Recommendations**: Specific, prioritized action items
5. **Next Steps**: Follow-up suggestions or questions
## Key Behaviors:
- **Data-Driven**: Always ground insights in actual data analysis when possible
- **Visual When Helpful**: Suggest or create charts that make data more understandable
- **Proactive**: Identify related insights the user might find valuable
- **Honest About Limitations**: Clearly state when data doesn't support certain analyses
## Example Language Patterns:
- Instead of "DataFrame shows" → "Your LinkedIn data reveals"
- Instead of "follower_count column" → "follower growth metrics"
- Instead of "engagement_rate variable" → "post engagement performance"
- Instead of "dataset analysis" → "performance review"
Your goal remains to be a trusted partner, but now with powerful data analysis capabilities that enable deeper, more accurate insights for data-driven employer branding decisions.
""").strip()
def _classify_query_type(self, query: str) -> str:
    """Classify a query as "data", "advice", "hybrid" or "general" by keywords.

    Keywords must start at a word boundary, so inflections still match
    ("analyzed", "recommendations"). Keywords buried inside other words no
    longer count: "rate" does not fire inside "strategy" or "generate".

    Args:
        query: The user's question.

    Returns:
        "hybrid" if both data and advice keywords appear, "data" or "advice"
        if only one kind appears, otherwise "general".
    """
    import re

    data_keywords = [
        'show', 'analyze', 'chart', 'graph', 'data', 'numbers', 'count', 'total',
        'average', 'trend', 'compare', 'statistics', 'performance', 'metrics',
        'followers', 'engagement', 'posts', 'growth', 'rate', 'percentage'
    ]
    advice_keywords = [
        'recommend', 'suggest', 'advice', 'strategy', 'improve', 'optimize',
        'best practice', 'should', 'how to', 'what to do', 'tips'
    ]
    query_lower = query.lower()

    def mentions_any(keywords):
        # Leading \b only: anchor at a word start but allow suffixes.
        pattern = r"\b(?:" + "|".join(re.escape(keyword) for keyword in keywords) + ")"
        return re.search(pattern, query_lower) is not None

    has_data_request = mentions_any(data_keywords)
    has_advice_request = mentions_any(advice_keywords)
    if has_data_request and has_advice_request:
        return "hybrid"
    if has_data_request:
        return "data"
    if has_advice_request:
        return "advice"
    return "general"
async def _generate_pandas_response(self, query: str) -> tuple[str, bool]:
    """Answer a data question through PandasAI and note any chart it produced.

    With one registered DataFrame the question goes to that frame's ``chat``;
    with several it goes to ``pai.chat`` with all of them. A chart is
    reported if it is returned as a path, exposed as ``plot_path``, or written
    or updated in ``self.charts_dir`` during this call.

    Args:
        query: The user's natural-language question.

    Returns:
        ``(text, True)`` on success; ``(message, False)`` when PandasAI is not
        set up or the query raised.
    """
    if not self.pandas_agent or not getattr(self, 'pandas_dfs', None):
        return "Data analysis not available - PandasAI not initialized.", False

    chart_extensions = ('.png', '.jpg', '.jpeg', '.svg')
    charts_dir = getattr(self, 'charts_dir', None)

    def snapshot_charts() -> dict:
        """Map chart file name -> modification time for files in charts_dir."""
        snapshot = {}
        if not charts_dir or not os.path.isdir(charts_dir):
            return snapshot
        for file_name in os.listdir(charts_dir):
            if not file_name.lower().endswith(chart_extensions):
                continue
            try:
                snapshot[file_name] = os.path.getmtime(os.path.join(charts_dir, file_name))
            except OSError:
                continue  # removed between listdir() and getmtime()
        return snapshot

    try:
        logging.info(f"Processing data query with PandasAI: {query[:100]}...")

        # Close leftover figures so they do not bleed into this query's chart.
        # Best effort: a failure here must not abort the data query itself.
        try:
            import matplotlib.pyplot as plt
            plt.close('all')
        except Exception as plot_error:
            logging.debug(f"Could not reset matplotlib figures: {plot_error}")

        charts_before = snapshot_charts()
        if len(self.pandas_dfs) == 1:
            df = next(iter(self.pandas_dfs.values()))
            frame = getattr(df, 'df', df)
            logging.info(f"Using single DataFrame for query with shape: {getattr(frame, 'shape', 'unknown')}")
            pandas_response = df.chat(query)
        else:
            # For multiple dataframes, use pai.chat with all dfs
            pandas_response = pai.chat(query, *self.pandas_dfs.values())

        chart_path = None
        if isinstance(pandas_response, str) and pandas_response.lower().endswith(chart_extensions):
            # Response is a chart path
            chart_path = pandas_response
            response_text = "Analysis completed with visualization"
        elif getattr(pandas_response, 'plot_path', None):
            # Response object has plot path
            chart_path = pandas_response.plot_path
            response_text = getattr(pandas_response, 'text', str(pandas_response))
        else:
            # Count only charts written or rewritten during this call; comparing
            # mtimes (instead of a time window) avoids claiming another query's chart.
            charts_after = snapshot_charts()
            changed = [file_name for file_name, mtime in charts_after.items()
                       if charts_before.get(file_name) != mtime]
            if changed:
                chart_path = os.path.join(charts_dir, max(changed, key=charts_after.get))
            # `is None`, not truthiness: a DataFrame has no truth value, and a
            # legitimate answer of 0 or False must not be discarded.
            text = "" if pandas_response is None else str(pandas_response).strip()
            response_text = text or "Analysis completed"

        chart_info = ""
        if chart_path:
            chart_info = f"\n\n📊 **Chart Generated**: {os.path.basename(chart_path)}\nChart saved at: {chart_path}"
            logging.info(f"Chart generated: {chart_path}")
        return response_text + chart_info, True
    except Exception as e:
        logging.error(f"Error in PandasAI processing: {e}", exc_info=True)
        # Try to provide a more helpful error message
        if "Invalid output" in str(e) and "plot save path" in str(e):
            return "I tried to create a visualization but encountered a formatting issue. Please try rephrasing your request or ask for specific data without requesting a chart.", False
        return f"Error processing data query: {str(e)}", False
async def _generate_enhanced_response(self, query: str, pandas_result: str = "", query_type: str = "general") -> str:
    """Generate the final answer by combining data-tool output with RAG context.

    The prompt is assembled from ``self._build_system_prompt()``,
    ``self._get_dataframes_summary()``, up to two knowledge-base snippets from
    ``self.rag_system.retrieve_relevant_info`` and, for "data"/"hybrid"
    queries with a non-empty ``pandas_result``, the data-analysis output. It is
    sent to the module-level ``client`` via ``client.models.generate_content``
    on a worker thread so the event loop is not blocked.

    Args:
        query: The user's question.
        pandas_result: Text from the data-analysis step (a result, an
            error/notice message, or empty).
        query_type: Query classification; "data" and "hybrid" select the
            data-analysis prompt variant when ``pandas_result`` is non-empty.

    Returns:
        The model's text, or a user-facing message when the agent is not
        ready, the AI client is missing, the prompt was blocked, no text was
        produced, or an exception occurred (exceptions are logged, not raised).
    """
    if not self.is_ready:
        return "Agent is not ready. Please initialize."
    if not client:
        return "Error: AI service is not available. Check API configuration."
    try:
        system_prompt = self._build_system_prompt()
        data_summary = self._get_dataframes_summary()
        rag_context = await self.rag_system.retrieve_relevant_info(query, top_k=2, min_similarity=0.25)
        # "hybrid" queries also run PandasAI upstream, so their analysis output
        # must reach the LLM as well; previously only "data" queries used it.
        if query_type in ("data", "hybrid") and pandas_result:
            enhanced_prompt = f"""
{system_prompt}
## Data Analysis Context:
{data_summary}
## PandasAI Analysis Result:
{pandas_result}
## Additional Knowledge Context:
{rag_context if rag_context else 'No additional context retrieved.'}
## User Query:
{query}
Please interpret the data analysis result above and provide business insights in a friendly, HR-focused manner.
Explain what the findings mean for employer branding strategy and provide actionable recommendations.
"""
        else:
            enhanced_prompt = f"""
{system_prompt}
## Available Data Context:
{data_summary}
## Knowledge Base Context:
{rag_context if rag_context else 'No specific background information retrieved.'}
## User Query:
{query}
Please provide helpful insights and recommendations for this employer branding query.
"""
        logging.debug("Using genai.Client for enhanced response generation")

        # Copy so the shared generation config is never mutated per request.
        config_dict = dict(self.generation_config_dict) if self.generation_config_dict else {}
        if self.safety_settings_list:
            # Typed SafetySetting objects are only buildable with the real SDK;
            # otherwise (or for non-dict entries) pass the entry through as-is.
            can_build_typed = GENAI_AVAILABLE and hasattr(types, 'SafetySetting')
            safety_settings = []
            for ss in self.safety_settings_list:
                if isinstance(ss, dict) and can_build_typed:
                    safety_settings.append(types.SafetySetting(
                        category=ss.get('category'),
                        threshold=ss.get('threshold')
                    ))
                else:
                    safety_settings.append(ss)
            config_dict['safety_settings'] = safety_settings
        if GENAI_AVAILABLE and hasattr(types, 'GenerateContentConfig'):
            config = types.GenerateContentConfig(**config_dict)
        else:
            config = config_dict

        model_path = self.llm_model_name if self.llm_model_name.startswith("models/") else f"models/{self.llm_model_name}"
        # generate_content is synchronous; run it off the event loop thread.
        api_response = await asyncio.to_thread(
            client.models.generate_content,
            model=model_path,
            contents=enhanced_prompt,
            config=config
        )

        response_text = ""
        candidates = getattr(api_response, 'candidates', None)
        if candidates:
            content = getattr(candidates[0], 'content', None)
            if content:
                parts = getattr(content, 'parts', None)
                if parts:
                    # Parts without text (e.g. function calls) carry text=None;
                    # skip them instead of letting "".join() raise TypeError.
                    response_text = "".join(
                        part.text for part in parts
                        if isinstance(getattr(part, 'text', None), str)
                    ).strip()
                else:
                    response_text = str(content).strip()

        if not response_text:
            # Distinguish an explicitly blocked prompt from a plain empty reply.
            feedback = getattr(api_response, 'prompt_feedback', None)
            if feedback and getattr(feedback, 'block_reason', None):
                logging.warning(f"Prompt blocked: {feedback.block_reason}")
                return "I'm sorry, your request was blocked. Please try rephrasing your query."
            return "I couldn't generate a response. Please try rephrasing your query."
        return response_text
    except Exception as e:
        # Top-level boundary for this request: report instead of raising.
        error_message = str(e).lower()
        if any(keyword in error_message for keyword in ['blocked', 'safety', 'filter', 'prohibited']):
            logging.error(f"Blocked prompt: {e}")
            return "I'm sorry, your request was blocked by the safety filter. Please rephrase your query."
        logging.error(f"Error in _generate_enhanced_response: {e}", exc_info=True)
        return f"I encountered an error while processing your request: {str(e)}"
| def _validate_query(self, query: str) -> bool: | |
| """Validate user query input""" | |
| if not query or not isinstance(query, str) or len(query.strip()) < 3: | |
| logging.warning(f"Invalid query: too short or not a string. Query: '{query}'") | |
| return False | |
| if len(query) > 3000: | |
| logging.warning(f"Invalid query: too long. Length: {len(query)}") | |
| return False | |
| return True | |
async def process_query(self, user_query: str) -> Dict[str, Optional[str]]:
    """
    Main method to process user queries.

    Validates the query, re-initializes the agent if it is not ready, runs
    PandasAI for queries classified as "data" or "hybrid" (when a PandasAI
    agent exists), and always asks the LLM via ``_generate_enhanced_response``
    to formulate the final answer.

    Returns a dictionary: {"text": llm_response_string, "image_path": path_to_chart_or_none}.
    "image_path" is only set when PandasAI succeeded and its output contained
    "Chart Generated" plus a non-empty "Chart saved at:" line. Errors are
    reported in "text" rather than raised.
    """
    if not self._validate_query(user_query):
        return {"text": "Please provide a valid query (3 to 3000 characters).", "image_path": None}
    if not self.is_ready:
        logging.warning("process_query called but agent is not ready. Attempting re-initialization.")
        init_success = await self.initialize()
        if not init_success:
            return {"text": "The agent is not properly initialized and could not be started. Please check configuration and logs.", "image_path": None}
    try:
        query_type = self._classify_query_type(user_query)
        logging.info(f"Query classified as: {query_type}")
        pandas_text_output: Optional[str] = None
        pandas_chart_path: Optional[str] = None
        pandas_success = False  # Tracks whether PandasAI ran successfully
        # For data-related queries, try PandasAI first
        if query_type in ["data", "hybrid"] and self.pandas_agent:
            logging.info("Attempting PandasAI analysis...")
            pandas_text_output, pandas_success = await self._generate_pandas_response(user_query)
            if pandas_success:
                logging.info(f"PandasAI analysis successful. Text: '{str(pandas_text_output)[:100]}...'")
                # Guard the substring checks: a "successful" run may still yield None.
                if isinstance(pandas_text_output, str) and "Chart Generated" in pandas_text_output:
                    for line in pandas_text_output.split('\n'):
                        _, marker, remainder = line.partition("Chart saved at:")
                        if marker:
                            # partition() tolerates a missing space after the colon,
                            # where split("Chart saved at: ")[1] raised IndexError.
                            pandas_chart_path = remainder.strip() or None
                            break
            else:
                # pandas_text_output might contain the error message from PandasAI
                logging.warning(f"PandasAI analysis failed or returned no specific result. Message from PandasAI: {pandas_text_output}")
        # Prepare the context from PandasAI for the LLM
        llm_context_from_pandas = ""
        if pandas_text_output:  # Either a success message or an error message from PandasAI
            llm_context_from_pandas += f"Data Analysis Tool Output: {pandas_text_output}\n"
            if pandas_chart_path and pandas_success:  # Only mention the chart if PandasAI succeeded
                llm_context_from_pandas += f"[A chart has been generated by the data tool and saved at '{pandas_chart_path}'. You should refer to this chart in your explanation if it's relevant to the user's query.]\n"
        elif query_type in ["data", "hybrid"] and not self.pandas_agent:
            llm_context_from_pandas += "Note: The data analysis tool is currently unavailable.\n"
        # Always call the LLM to formulate the final response
        final_llm_response = await self._generate_enhanced_response(
            query=user_query,
            pandas_result=llm_context_from_pandas,
            query_type=query_type
        )
        # The chart path is only meaningful when PandasAI succeeded.
        return {"text": final_llm_response, "image_path": pandas_chart_path if pandas_success else None}
    except Exception as e:
        logging.error(f"Critical error in process_query: {e}", exc_info=True)
        return {"text": f"I encountered a critical error while processing your request: {type(e).__name__}. Please check the logs.", "image_path": None}
def update_dataframes(self, new_dataframes: Dict[str, pd.DataFrame]):
    """Replace the agent's DataFrames with copies and rebuild the PandasAI agent.

    Each incoming DataFrame is copied, so later mutations by the caller do not
    leak into the agent.
    """
    copied_frames: Dict[str, pd.DataFrame] = {}
    for frame_name, frame in new_dataframes.items():
        copied_frames[frame_name] = frame.copy()
    self.all_dataframes = copied_frames
    logging.info(f"Agent DataFrames updated. Keys: {list(self.all_dataframes.keys())}")
    # The PandasAI agent is bound to the old data, so rebuild it.
    self._initialize_pandas_agent()
    # Note: RAG system uses static documents and doesn't need reinitialization
def update_rag_documents(self, new_rag_df: pd.DataFrame):
    """Store a copy of *new_rag_df* as the RAG system's document table.

    Only the documents are swapped here; embeddings are not rebuilt, so call
    ``initialize()`` afterwards to refresh them.
    """
    documents_copy = new_rag_df.copy()
    self.rag_system.documents_df = documents_copy
    logging.info(f"RAG documents updated. Count: {len(new_rag_df)}")
def clear_chat_history(self):
    """Drop the agent's in-memory conversation by binding a fresh empty list."""
    self.chat_history = list()
    logging.info("EmployerBrandingAgent internal chat history cleared.")
def get_status(self) -> Dict[str, Any]:
    """Summarize the agent's readiness and configuration as a flat dict.

    Covers readiness, API-key / library / client availability, RAG and
    PandasAI state, the DataFrame inventory, model names, chat-history length,
    and the PandasAI chart save path ("PandasAI not configured" when
    ``pai.config.llm`` is falsy).
    """
    status: Dict[str, Any] = {}
    status["is_ready"] = self.is_ready
    status["has_api_key"] = bool(GEMINI_API_KEY)
    status["genai_available"] = GENAI_AVAILABLE
    status["client_type"] = "genai.Client" if client else "None"
    status["rag_initialized"] = self.rag_system.is_initialized
    status["pandas_agent_ready"] = self.pandas_agent is not None
    status["num_dataframes"] = len(self.all_dataframes)
    status["dataframe_keys"] = list(self.all_dataframes.keys())
    rag_documents = self.rag_system.documents_df
    status["num_rag_documents"] = 0 if rag_documents is None else len(rag_documents)
    status["llm_model_name"] = self.llm_model_name
    status["embedding_model_name"] = self.rag_system.embedding_model_name
    status["chat_history_length"] = len(self.chat_history)
    # NOTE(review): relies on the installed pandasai exposing
    # `pai.config.llm` and `pai.config.save_charts_path`; confirm per version.
    if pai.config.llm:
        status["charts_save_path_pandasai"] = pai.config.save_charts_path
    else:
        status["charts_save_path_pandasai"] = "PandasAI not configured"
    return status
def get_available_analyses(self, limit: int = 10) -> List[str]:
    """Return suggested analysis prompts based on the loaded DataFrames.

    Each DataFrame whose name contains "follower", "post" or "mention"
    (checked in that order, case-insensitively) contributes three suggestions.
    Four general employer-branding questions are then appended. Repeated
    suggestions (e.g. from two follower DataFrames) are dropped, keeping the
    first occurrence, before truncation. This stops them from crowding other
    suggestions out of the limit.

    Args:
        limit: Maximum number of suggestions to return (default 10).

    Returns:
        Up to ``limit`` unique suggestions, or
        ``["No data available for analysis"]`` when no DataFrames are loaded.
    """
    if not self.all_dataframes:
        return ["No data available for analysis"]
    suggestions: List[str] = []
    for df_name in self.all_dataframes:
        lowered = df_name.lower()
        if 'follower' in lowered:
            suggestions.extend([
                f"Show follower growth trends from {df_name}",
                f"Analyze follower demographics in {df_name}",
                "Compare follower engagement rates",
            ])
        elif 'post' in lowered:
            suggestions.extend([
                f"Analyze post performance metrics from {df_name}",
                "Show best performing content types",
                "Compare engagement across post categories",
            ])
        elif 'mention' in lowered:
            suggestions.extend([
                f"Analyze brand mention sentiment from {df_name}",
                "Show mention volume trends",
                "Identify top mention sources",
            ])
    # General suggestions, independent of the loaded data
    suggestions.extend([
        "What are the key employer branding trends?",
        "How can I improve our LinkedIn presence?",
        "What content strategy should we adopt?",
        "How do we measure employer branding success?",
    ])
    # dict.fromkeys preserves insertion order while removing duplicates.
    return list(dict.fromkeys(suggestions))[:limit]
| # --- Helper Functions for External Integration --- | |
def create_agent_instance(dataframes: Optional[Dict[str, pd.DataFrame]] = None,
                          rag_docs: Optional[pd.DataFrame] = None) -> EmployerBrandingAgent:
    """Construct a new ``EmployerBrandingAgent`` from optional data inputs.

    This only calls the constructor; initialization is a separate step
    (see ``initialize_agent_async``).
    """
    logging.info("Creating new EnhancedEmployerBrandingAgent instance via helper function.")
    new_agent = EmployerBrandingAgent(all_dataframes=dataframes, rag_documents_df=rag_docs)
    return new_agent
async def initialize_agent_async(agent: EmployerBrandingAgent) -> bool:
    """Await ``agent.initialize()`` and hand back its success flag."""
    logging.info("Initializing agent via async helper function.")
    succeeded = await agent.initialize()
    return succeeded
def validate_dataframes(dataframes: Dict[str, pd.DataFrame]) -> Dict[str, List[str]]:
    """Inspect each DataFrame for common problems and report them by name.

    Checks per DataFrame:
    * whether it is empty;
    * expected columns chosen from the (case-insensitive) name: a name
      containing "follower" needs ``date``/``follower_count``; otherwise a name
      containing "post" needs ``date``/``engagement``;
    * for non-empty frames, columns whose null percentage (rounded to two
      decimals) exceeds 50.

    Returns:
        A dict mapping every input name to its list of issue messages
        (empty when nothing was found).
    """
    # Ordered: the first keyword found in the name decides the column check.
    expected_columns_by_kind = (
        ('follower', ['date', 'follower_count']),
        ('post', ['date', 'engagement']),
    )
    report: Dict[str, List[str]] = {}
    for name, df in dataframes.items():
        problems: List[str] = ["DataFrame is empty"] if df.empty else []

        lowered_name = name.lower()
        for kind, required in expected_columns_by_kind:
            if kind not in lowered_name:
                continue
            absent = [col for col in required if col not in df.columns]
            if absent:
                problems.append(f"Missing expected columns for {kind} data: {absent}")
            break

        if not df.empty:
            null_pct = (df.isnull().sum() / len(df) * 100).round(2)
            mostly_null = null_pct[null_pct > 50].index.tolist()
            if mostly_null:
                problems.append(f"Columns with >50% null values: {mostly_null}")

        report[name] = problems
    return report