Spaces:
Running
Running
import pandas as pd | |
import json | |
import os | |
import asyncio | |
import logging | |
import numpy as np | |
import textwrap | |
from datetime import datetime # Added for date calculations | |
try:
    from google import genai
    from google.genai import types  # For GenerateContentConfig, SafetySetting, HarmCategory, HarmBlockThreshold etc.
except ImportError:
    # `from google import genai` is provided by the `google-genai` SDK. The older
    # `google-generativeai` package exposes `google.generativeai` instead and does
    # not satisfy this import, so point users at the correct distribution.
    logging.error("Google Generative AI library not found. Please install it: pip install google-genai", exc_info=True)

    # Minimal stand-ins so the rest of the module can still be imported.
    # `genai.Client = None` keeps the module's client-initialization check falsy,
    # so no client is created when the SDK is missing.
    class genai:  # type: ignore
        Client = None

    class types:  # type: ignore
        EmbedContentConfig = None
        GenerateContentConfig = None
        SafetySetting = None

        # Inner classes mirror the enum names the real `types` module exposes,
        # using their string values.
        class HarmCategory:  # type: ignore
            HARM_CATEGORY_UNSPECIFIED = "HARM_CATEGORY_UNSPECIFIED"
            HARM_CATEGORY_HATE_SPEECH = "HARM_CATEGORY_HATE_SPEECH"
            HARM_CATEGORY_HARASSMENT = "HARM_CATEGORY_HARASSMENT"
            HARM_CATEGORY_SEXUALLY_EXPLICIT = "HARM_CATEGORY_SEXUALLY_EXPLICIT"
            HARM_CATEGORY_DANGEROUS_CONTENT = "HARM_CATEGORY_DANGEROUS_CONTENT"

        class HarmBlockThreshold:  # type: ignore
            BLOCK_NONE = "BLOCK_NONE"
            BLOCK_LOW_AND_ABOVE = "BLOCK_LOW_AND_ABOVE"
            BLOCK_MEDIUM_AND_ABOVE = "BLOCK_MEDIUM_AND_ABOVE"
            BLOCK_ONLY_HIGH = "BLOCK_ONLY_HIGH"  # Added for completeness, adjust if needed
# --- Custom Exceptions ---
class ValidationError(Exception):
    """Raised when an input handed to the agent does not pass validation."""
class RateLimitError(Exception):
    """Signals that the upstream AI service refused a request due to rate limiting (placeholder)."""
# --- Configuration Constants ---
# API key for the Google GenAI client; an empty string means "not configured".
GEMINI_API_KEY = os.environ.get('GEMINI_API_KEY', "")
if not GEMINI_API_KEY:
    logging.warning("GEMINI_API_KEY environment variable not set. EB Agent will not function.")

# Model identifiers: one for text generation, one for RAG embeddings.
LLM_MODEL_NAME = "gemini-1.5-flash-latest"
GEMINI_EMBEDDING_MODEL_NAME = "text-embedding-004"

# Default generation parameters (sampling temperature, nucleus/top-k sampling,
# maximum output length and number of candidates).
GENERATION_CONFIG_PARAMS = dict(
    temperature=0.7,
    top_p=0.95,
    top_k=40,
    max_output_tokens=8192,
    candidate_count=1,
)

# Deliberately empty: no client-side safety settings are sent (per user request).
DEFAULT_SAFETY_SETTINGS = list()
logging.info("Default safety settings are now empty (no explicit client-side safety settings).")
# Seed corpus for the RAG system: a one-column ('text') DataFrame of short
# employer-branding reference snippets.
df_rag_documents = pd.DataFrame(
    [
        "Employer branding focuses on how an organization is perceived as an employer by potential and current employees.",
        "Key metrics for employer branding include employee engagement, candidate quality, and retention rates.",
        "LinkedIn is a crucial platform for showcasing company culture and attracting talent.",
        "Analyzing follower demographics and post engagement helps refine employer branding strategies.",
    ],
    columns=['text'],
)
# --- Client Initialization ---
# Module-wide GenAI client; stays None when the key is missing, the SDK import
# failed (the fallback sets genai.Client to None), or construction raised.
client = None
if not (GEMINI_API_KEY and genai.Client):
    logging.warning("Google GenAI client could not be initialized (GEMINI_API_KEY missing or library import failed).")
else:
    try:
        client = genai.Client(api_key=GEMINI_API_KEY)
    except Exception as init_error:
        logging.error(f"Failed to initialize Google GenAI client: {init_error}", exc_info=True)
    else:
        logging.info("Google GenAI client initialized successfully.")
class AdvancedRAGSystem:
    """In-memory retrieval over a DataFrame of text snippets using GenAI embeddings.

    Texts from the ``'text'`` column are embedded once by :meth:`initialize_embeddings`.
    :meth:`retrieve_relevant_info` embeds a query and returns the best-matching
    document texts ranked by cosine similarity. All embedding calls go through the
    module-level GenAI ``client``.
    """

    def __init__(self, documents_df: pd.DataFrame, embedding_model_name: str):
        # Copy so later changes to the caller's DataFrame cannot desynchronise
        # the stored texts from their embeddings.
        self.documents_df = documents_df.copy()
        self.embedding_model_name = embedding_model_name
        # (n_embedded, dim) matrix after initialisation; None before it, and an
        # empty array when nothing could be embedded.
        self.embeddings: np.ndarray | None = None
        # iloc position in documents_df of each row of self.embeddings. Documents with
        # invalid text or failed embedding calls are skipped, so embedding row i is
        # not necessarily document row i.
        self._embedded_row_positions: list = []
        logging.info(f"AdvancedRAGSystem initialized with embedding model: {self.embedding_model_name}")

    @staticmethod
    def _embedding_values_from_response(response) -> np.ndarray:
        """Return the first embedding of an ``embed_content`` response as a 1-D float array.

        Per the google-genai SDK reference, ``response.embeddings`` is a list of
        ``ContentEmbedding`` objects whose numbers live in ``.values``. Wrapping that
        list directly in ``np.array`` yields an object array, not a numeric vector.
        Dicts with a ``'values'`` key and plain sequences of numbers are accepted too.

        Raises:
            ValueError: if the response carries no embedding values.
        """
        embeddings = getattr(response, 'embeddings', None)
        if not embeddings:
            raise ValueError("Embedding response contained no embeddings.")
        first = embeddings[0]
        if isinstance(first, dict):
            values = first.get('values')
        else:
            values = getattr(first, 'values', first)
        if values is None:
            raise ValueError("Embedding response contained an embedding without values.")
        return np.asarray(values, dtype=float).ravel()

    def _embed_single_document_sync(self, text: str) -> np.ndarray:
        """Embed one text with a blocking API call (callers run it via ``asyncio.to_thread``).

        Raises:
            ConnectionError: if the module-level GenAI client is not initialised.
            ValueError: if ``text`` is empty/non-string or the response has no values.
        """
        if not client:
            raise ConnectionError("GenAI client not initialized for RAG embedding.")
        if not text or not isinstance(text, str):
            raise ValueError("Cannot embed empty or non-string text.")
        # The import fallback defines EmbedContentConfig = None, so check callability
        # rather than mere attribute presence.
        config_cls = getattr(types, 'EmbedContentConfig', None) if types else None
        embed_config = config_cls(task_type="SEMANTIC_SIMILARITY") if callable(config_cls) else None
        response = client.models.embed_content(
            model=self.embedding_model_name,
            contents=text,
            config=embed_config
        )
        return self._embedding_values_from_response(response)

    async def initialize_embeddings(self):
        """Embed every valid document and store the stacked matrix in ``self.embeddings``.

        Rows whose ``'text'`` is missing or non-string, and rows whose embedding call
        fails, are logged and skipped. The positions of the rows that were embedded
        are recorded so retrieval maps back to the right documents. On total failure
        ``self.embeddings`` is an empty array.
        """
        self._embedded_row_positions = []
        if self.documents_df.empty:
            logging.info("RAG documents DataFrame is empty. No embeddings to initialize.")
            self.embeddings = np.array([])
            return
        if not client:
            logging.error("GenAI client not available for RAG embedding initialization.")
            self.embeddings = np.array([])
            return
        logging.info(f"Starting RAG document embedding for {len(self.documents_df)} documents...")
        embedded_docs_list = []
        embedded_positions = []
        # `position` is the iloc offset; `index` is the label (used only in log messages).
        for position, (index, row) in enumerate(self.documents_df.iterrows()):
            text_to_embed = row.get('text')
            if not text_to_embed or not isinstance(text_to_embed, str):
                logging.warning(f"Skipping document at index {index} due to invalid text: {text_to_embed}")
                continue
            try:
                embedding_array = await asyncio.to_thread(self._embed_single_document_sync, text_to_embed)
            except Exception as e:
                logging.error(f"Error embedding document text (index {index}) '{str(text_to_embed)[:50]}...': {e}", exc_info=False)
                continue
            embedded_docs_list.append(embedding_array)
            embedded_positions.append(position)
        if not embedded_docs_list:
            self.embeddings = np.array([])
            logging.warning("No documents were successfully embedded for RAG.")
            return
        try:
            self.embeddings = np.vstack(embedded_docs_list)
        except ValueError as ve:
            # Raised e.g. when embeddings come back with differing dimensions.
            logging.error(f"Error stacking embeddings: {ve}. Check individual embedding errors.", exc_info=True)
            self.embeddings = np.array([])
            return
        self._embedded_row_positions = embedded_positions
        logging.info(f"Successfully embedded {len(embedded_docs_list)} documents for RAG. Embedding matrix shape: {self.embeddings.shape}")

    def _calculate_cosine_similarity(self, embeddings_matrix: np.ndarray, query_vector: np.ndarray) -> np.ndarray:
        """Return the cosine similarity of each row of ``embeddings_matrix`` with ``query_vector``.

        A small epsilon in the denominators avoids division by zero for zero vectors.
        """
        query_vector = query_vector.flatten()
        norm_matrix = np.linalg.norm(embeddings_matrix, axis=1, keepdims=True)
        normalized_embeddings_matrix = embeddings_matrix / (norm_matrix + 1e-8)
        norm_query = np.linalg.norm(query_vector)
        normalized_query_vector = query_vector / (norm_query + 1e-8)
        return np.dot(normalized_embeddings_matrix, normalized_query_vector)

    def _rank_document_positions(self, similarity_scores: np.ndarray, top_k: int, min_similarity: float) -> list:
        """Return documents_df iloc positions of the best ``top_k`` scores >= ``min_similarity``.

        Positions are ordered best match first and mapped back through
        ``_embedded_row_positions``, so skipped documents cannot shift the lookup.
        """
        scores = np.asarray(similarity_scores).ravel()
        candidate_rows = np.flatnonzero(scores >= min_similarity)
        if candidate_rows.size == 0:
            return []
        best_first = candidate_rows[np.argsort(scores[candidate_rows])[::-1][:top_k]]
        positions = getattr(self, '_embedded_row_positions', None)
        if not positions or len(positions) != scores.size:
            # Embeddings were assigned without initialize_embeddings(); assume row i is document i.
            positions = range(scores.size)
        return [int(positions[row]) for row in best_first]

    async def retrieve_relevant_info(self, query: str, top_k: int = 3, min_similarity: float = 0.3) -> str:
        """Return up to ``top_k`` document texts most similar to ``query``.

        Texts are joined with ``"\\n\\n---\\n\\n"``. Returns ``""`` when there are no
        embeddings, the query is empty or invalid, the client is missing, embedding
        the query fails, or no document reaches ``min_similarity``.
        """
        if self.embeddings is None or self.embeddings.size == 0 or self.documents_df.empty:
            logging.debug("RAG system not initialized or no documents/embeddings available for retrieval.")
            return ""
        if not query or not isinstance(query, str):
            logging.debug("Empty or invalid query for RAG retrieval.")
            return ""
        if not client:
            logging.error("GenAI client not available for RAG query embedding.")
            return ""
        try:
            query_vector = await asyncio.to_thread(self._embed_single_document_sync, query)
        except Exception as e:
            logging.error(f"Error embedding query '{str(query)[:50]}...': {e}", exc_info=False)
            return ""
        if query_vector.ndim == 0 or query_vector.size == 0:
            logging.warning(f"Query vector embedding failed or is empty for query: {str(query)[:50]}")
            return ""
        try:
            similarity_scores = self._calculate_cosine_similarity(self.embeddings, query_vector)
            ranked_positions = self._rank_document_positions(similarity_scores, top_k, min_similarity)
            if not ranked_positions:
                logging.debug(f"No documents met the minimum similarity threshold of {min_similarity} for query: {query[:50]}")
                return ""
            if 'text' not in self.documents_df.columns:
                return ""
            text_column = self.documents_df['text']
            context = "\n\n---\n\n".join(text_column.iloc[pos] for pos in ranked_positions)
            logging.debug(f"Retrieved RAG context for query '{str(query)[:50]}...':\n{context[:200]}...")
            return context
        except Exception as e:
            logging.error(f"Error during RAG retrieval (similarity/sorting): {e}", exc_info=True)
            return ""
class EmployerBrandingAgent:
    """Answer employer-branding questions about LinkedIn DataFrames with a GenAI model.

    On construction the agent copies the supplied DataFrames and precomputes a text
    overview of them that is embedded in every prompt. The overview covers shape,
    date range, data-quality notes and placeholder metric summaries. When the RAG
    system has embeddings, relevant reference snippets and placeholder business
    context are added as well. Model calls go through the module-level GenAI ``client``.
    """

    def __init__(self,
                 all_dataframes: dict,
                 rag_documents_df: pd.DataFrame,
                 llm_model_name: str,
                 embedding_model_name: str,
                 generation_config_dict: dict,
                 safety_settings_list: list,
                 force_sandbox: bool = False):
        """Create the agent.

        Args:
            all_dataframes: Mapping of short names to DataFrames; each one is copied.
                ``None`` is treated as an empty mapping.
            rag_documents_df: Documents for the RAG system (its ``'text'`` column is read).
            llm_model_name: Model used for ``generate_content`` calls.
            embedding_model_name: Model used by the RAG system for embeddings.
            generation_config_dict: Keyword arguments for the generation config.
            safety_settings_list: Items with ``'category'``/``'threshold'`` keys;
                ``None`` or an empty list sends no safety settings.
            force_sandbox: Stored on the instance; not read by this class.
        """
        self.all_dataframes = {k: df.copy() for k, df in (all_dataframes or {}).items()}
        self.schemas_representation = self._get_enhanced_schemas_representation()
        # NOTE(review): sent as prior turns by _core_query_processing but never appended
        # to inside this class — presumably the caller maintains it; confirm.
        self.chat_history = []
        self.llm_model_name = llm_model_name
        self.generation_config_dict = generation_config_dict
        # If an empty list is passed, it means no specific safety settings are enforced by the client.
        self.safety_settings_list = safety_settings_list if safety_settings_list is not None else []
        self.embedding_model_name = embedding_model_name
        self.rag_system = AdvancedRAGSystem(rag_documents_df, self.embedding_model_name)
        self.force_sandbox = force_sandbox
        logging.info(f"EmployerBrandingAgent initialized. LLM: {self.llm_model_name}, Embedding: {self.embedding_model_name}. Safety settings count: {len(self.safety_settings_list)}")

    # --- DataFrame analysis helpers -------------------------------------------------

    def _get_date_range(self, df: pd.DataFrame) -> str:
        """Return "YYYY-MM-DD to YYYY-MM-DD" for the first datetime column with non-null bounds, else "N/A"."""
        for col in df.columns:
            if pd.api.types.is_datetime64_any_dtype(df[col]):
                try:
                    min_date = df[col].min()
                    max_date = df[col].max()
                    if pd.notna(min_date) and pd.notna(max_date):
                        return f"{min_date.strftime('%Y-%m-%d')} to {max_date.strftime('%Y-%m-%d')}"
                except Exception as e:
                    # Best effort: one unusable column must not break the overview.
                    logging.debug(f"Could not compute date range for column {col}: {e}")
        return "N/A"

    def _calculate_growth_rate(self, df: pd.DataFrame) -> str:
        logging.debug("_calculate_growth_rate is a placeholder.")
        return "Growth rate calculation not implemented."

    def _analyze_engagement_trends(self, df: pd.DataFrame) -> str:
        logging.debug("_analyze_engagement_trends is a placeholder.")
        return "Engagement trend analysis not implemented."

    def _analyze_demographics(self, df: pd.DataFrame) -> str:
        logging.debug("_analyze_demographics is a placeholder.")
        return "Demographic analysis not implemented."

    def _analyze_post_performance(self, df: pd.DataFrame) -> str:
        logging.debug("_analyze_post_performance is a placeholder.")
        return "Post performance analysis not implemented."

    def _extract_content_themes(self, df: pd.DataFrame) -> str:
        logging.debug("_extract_content_themes is a placeholder.")
        return "Content theme extraction not implemented."

    def _find_optimal_times(self, df: pd.DataFrame) -> str:
        logging.debug("_find_optimal_times is a placeholder.")
        return "Optimal posting time analysis not implemented."

    def _calculate_key_metrics(self, df: pd.DataFrame, df_type: str) -> dict:
        """Pick metric summaries by keyword in ``df_type`` ('follower', 'post' or 'mention')."""
        metrics = {}
        if 'follower' in df_type.lower():
            metrics.update({'follower_growth_rate': self._calculate_growth_rate(df),
                            'engagement_trends': self._analyze_engagement_trends(df),
                            'demographic_distribution': self._analyze_demographics(df)})
        elif 'post' in df_type.lower():
            metrics.update({'post_performance': self._analyze_post_performance(df),
                            'content_themes': self._extract_content_themes(df),
                            'optimal_posting_times': self._find_optimal_times(df)})
        elif 'mention' in df_type.lower():
            metrics['mention_volume_trend'] = "Mention volume trend not implemented."
            metrics['mention_sentiment_overview'] = "Mention sentiment overview not implemented."
        if not metrics:
            logging.debug(f"No specific key metrics defined for df_type: {df_type}")
            return {"info": "Standard metrics applicable."}
        return metrics

    def _calculate_data_freshness(self, df: pd.DataFrame) -> str:
        """Describe the latest date in the first datetime column and its age in days."""
        for col in df.columns:
            if pd.api.types.is_datetime64_any_dtype(df[col]):
                try:
                    max_date = df[col].max()
                    if pd.notna(max_date):
                        # Match the column's timezone so aware/naive values are not mixed.
                        days_diff = (datetime.now(max_date.tzinfo if max_date.tzinfo else None) - max_date).days
                        return f"Data up to {max_date.strftime('%Y-%m-%d')} ({days_diff} days old)"
                except Exception as e:
                    logging.debug(f"Could not compute freshness for column {col}: {e}")
        return "Freshness N/A (no clear date column)"

    def _check_data_consistency(self, df: pd.DataFrame) -> str:
        logging.debug("_check_data_consistency is a placeholder.")
        return "Consistency checks not implemented."

    def _identify_accuracy_issues(self, df: pd.DataFrame) -> str:
        logging.debug("_identify_accuracy_issues is a placeholder.")
        return "Accuracy issue identification not implemented."

    def _assess_data_quality(self, df: pd.DataFrame) -> dict:
        """Summarise completeness (share of non-null cells), freshness and sample size."""
        completeness = (1 - (df.isnull().sum().sum() / (len(df) * len(df.columns)))) if len(df) > 0 and len(df.columns) > 0 else 0
        return {'completeness_score': f"{completeness:.2%}",
                'freshness_info': self._calculate_data_freshness(df),
                'consistency_check': self._check_data_consistency(df),
                'accuracy_flags_summary': self._identify_accuracy_issues(df),
                'sample_size_notes': f"{len(df)} records. {'Adequate for basic analysis.' if len(df) >= 100 else 'Limited sample size; insights may be indicative.'}"}

    def _identify_patterns(self, df: pd.DataFrame, key: str) -> str:
        logging.debug(f"_identify_patterns for {key} is a placeholder.")
        return "Pattern identification not implemented."

    def _format_df_analysis(self, df_key: str, analysis: dict) -> str:
        """Render one DataFrame's analysis dict as indented text lines."""
        formatted_parts = [f"\n--- DataFrame: df_{df_key} ---", f" Shape: {analysis['shape']}", f" Date Range: {analysis['date_range']}", " Key Metrics:"]
        for metric, value in analysis['key_metrics'].items():
            formatted_parts.append(f" - {metric.replace('_', ' ').title()}: {value}")
        formatted_parts.append(" Data Quality Assessment:")
        for aspect, value in analysis['data_quality'].items():
            formatted_parts.append(f" - {aspect.replace('_', ' ').title()}: {value}")
        formatted_parts.append(f" Notable Patterns: {analysis['notable_patterns']}")
        return "\n".join(formatted_parts)

    def _get_enhanced_schemas_representation(self) -> str:
        """Build the data overview text for every DataFrame in ``self.all_dataframes``."""
        schema_descriptions = ["=== DETAILED LINKEDIN DATA OVERVIEW ==="]
        if not self.all_dataframes:
            schema_descriptions.append("No dataframes available for analysis.")
            return "\n".join(schema_descriptions)
        for key, df in self.all_dataframes.items():
            if df.empty:
                schema_descriptions.append(f"\n--- DataFrame: df_{key} ---\nStatus: Empty. No analysis possible.")
                continue
            analysis = {'shape': df.shape,
                        'date_range': self._get_date_range(df),
                        'key_metrics': self._calculate_key_metrics(df, key),
                        'data_quality': self._assess_data_quality(df),
                        'notable_patterns': self._identify_patterns(df, key)}
            schema_descriptions.append(self._format_df_analysis(key, analysis))
        return "\n".join(schema_descriptions)

    # --- Prompt construction ----------------------------------------------------------

    def _extract_query_intent(self, query: str) -> str:
        logging.debug("_extract_query_intent is a placeholder.")
        if "compare" in query.lower() or "benchmark" in query.lower():
            return "comparison"
        if "trend" in query.lower():
            return "trend_analysis"
        return "general"

    async def _get_business_context(self, intent: str) -> str:
        logging.debug("_get_business_context is a placeholder.")
        if intent == "comparison":
            return "Company is focused on outperforming competitors in tech hiring."
        return "Company aims to improve overall employer brand perception."

    async def _get_industry_benchmarks(self, intent: str) -> str:
        logging.debug("_get_industry_benchmarks is a placeholder.")
        if intent == "trend_analysis":
            return "Typical follower growth in this sector is 5-10% MoM."
        return "Average engagement rate for similar companies is 2-3%."

    async def _enhance_rag_context(self, query: str, base_context: str) -> str:
        """Append placeholder business focus and benchmark notes to the retrieved context."""
        intent = self._extract_query_intent(query)
        business_context_val = await self._get_business_context(intent)
        benchmarks_val = await self._get_industry_benchmarks(intent)
        return (f"{base_context}\n"
                "--- ADDITIONAL CONTEXT FOR YOUR ANALYSIS ---\n"
                f"Business Focus: {business_context_val}\n"
                f"Relevant Benchmarks: {benchmarks_val}")

    async def _build_prompt_for_current_turn(self, raw_user_query: str) -> str:
        """Compose the prompt: role line, data overview, optional RAG context, then the user query."""
        prompt_parts = ["You are an expert Employer Branding Analyst...", "--- DETAILED DATA OVERVIEW ---", self.schemas_representation]
        if self.rag_system.embeddings is not None and self.rag_system.embeddings.size > 0:
            base_rag_context = await self.rag_system.retrieve_relevant_info(raw_user_query)
            if base_rag_context:
                enhanced_rag_context = await self._enhance_rag_context(raw_user_query, base_rag_context)
                prompt_parts.extend(["--- RELEVANT CONTEXTUAL INFORMATION (from documents & business knowledge) ---", enhanced_rag_context])
        prompt_parts.extend(["--- USER REQUEST ---", f"Based on all the information above, please respond to the following user query:\n{raw_user_query}"])
        final_prompt = "\n".join(prompt_parts)
        logging.debug(f"Built prompt for current turn (first 300 chars): {final_prompt[:300]}")
        return final_prompt

    async def _process_structured_query(self, prompt: str) -> dict:
        logging.debug("_process_structured_query is a placeholder.")
        return {"Key Findings": ["Placeholder finding 1"], "Performance Metrics": ["Placeholder metric"], "Actionable Recommendations": {"Immediate Actions (0-30 days)": ["Placeholder action"]}, "Risk Assessment": ["Placeholder risk"], "Success Metrics to Track": ["Placeholder KPI"]}

    # --- GenAI request/response helpers (shared by all model calls) -------------------

    def _build_api_safety_settings(self, log_context: str) -> list:
        """Convert ``self.safety_settings_list`` into items for the ``safety_settings`` config field.

        Each item is read via its ``'category'`` and ``'threshold'`` keys. When
        ``types.SafetySetting`` is not callable, or an item cannot be converted, the
        raw item is passed through. ``log_context`` is appended to the warning text.
        """
        if not self.safety_settings_list:
            return []
        safety_setting_cls = getattr(types, 'SafetySetting', None) if types else None
        if not callable(safety_setting_cls):
            return list(self.safety_settings_list)
        converted = []
        for ss_item in self.safety_settings_list:
            try:
                converted.append(safety_setting_cls(category=ss_item['category'], threshold=ss_item['threshold']))
            except Exception as e_ss:
                logging.warning(f"Could not create SafetySetting object from {ss_item} {log_context}: {e_ss}. Using raw item.")
                converted.append(ss_item)
        return converted

    def _build_generation_config(self, log_context: str):
        """Build the ``config`` argument for ``client.models.generate_content``.

        The converted safety settings are merged into ``self.generation_config_dict``
        (and override any ``safety_settings`` key already present there). Returns a
        ``types.GenerateContentConfig`` when that class is available, otherwise the
        plain dict.
        """
        config_kwargs = {**self.generation_config_dict,
                         "safety_settings": self._build_api_safety_settings(log_context)}
        config_cls = getattr(types, 'GenerateContentConfig', None) if types else None
        if callable(config_cls):
            return config_cls(**config_kwargs)
        logging.error("GenerateContentConfig type not available. API call might fail.")
        return config_kwargs

    @staticmethod
    def _response_text(response):
        """Return the response text stripped of surrounding whitespace, or None if there is none.

        With the google-genai SDK ``response.text`` is None when the first candidate
        carries no text parts, so ``.strip()`` cannot be applied blindly.
        """
        text = getattr(response, 'text', None)
        return text.strip() if isinstance(text, str) else None

    @staticmethod
    def _describe_finish_reason(response) -> str:
        """Return the first candidate's ``finish_reason`` as text, or "Unknown"."""
        candidates = getattr(response, 'candidates', None) or []
        reason = getattr(candidates[0], 'finish_reason', None) if candidates else None
        return f"{reason}" if reason is not None else "Unknown"

    async def _generate_hr_insights(self, query: str, context: str) -> str:
        """Ask the model for structured HR insights about ``query`` given ``context``.

        Returns the model text, or a human-readable error string on failure.
        """
        insight_prompt = f"As an expert HR analytics consultant...\n{context}\nUser Query: {query}\nPlease provide insights in this structured format:\n## Key Findings\n- ...\n..."
        if not client:
            return "Error: AI client not configured for generating HR insights."
        api_call_contents = [{"role": "user", "parts": [{"text": insight_prompt}]}]
        api_generation_config_obj = self._build_generation_config("for HR insights")
        try:
            response = await asyncio.to_thread(client.models.generate_content, model=self.llm_model_name, contents=api_call_contents, config=api_generation_config_obj)
            if not response.candidates:
                return "HR insights generation failed: No response from AI."
            text = self._response_text(response)
            if not text:
                return f"HR insights generation failed: AI returned no text (finish reason: {self._describe_finish_reason(response)})."
            return text
        except Exception as e:
            logging.error(f"Error generating HR insights: {e}", exc_info=True)
            return f"Error generating HR insights: {str(e)}"

    # --- Query handling ---------------------------------------------------------------

    def _validate_query(self, query: str) -> bool:
        """Reject queries under 3 characters after stripping; only warn when no HR keyword appears."""
        if not query or len(query.strip()) < 3:
            logging.warning(f"Query too short: '{query}'")
            return False
        hr_keywords = ['employee', 'talent', 'hiring', 'culture', 'brand', 'engagement', 'retention', 'follower', 'post', 'mention', 'linkedin']
        if not any(keyword in query.lower() for keyword in hr_keywords):
            logging.warning(f"Query may not be HR/LinkedIn-relevant: {query[:50]}")
        return True

    def _get_query_help_message(self) -> str:
        return "I'm here to help with Employer Branding analysis... Example: 'What are the top industries of my followers?'"

    async def _check_system_readiness(self) -> dict:
        logging.debug("_check_system_readiness is a placeholder.")
        if not client:
            return {'ready': False, 'reason': 'AI Client not initialized.'}
        if self.rag_system.embeddings is None:
            logging.warning("RAG embeddings not yet initialized.")
        return {'ready': True, 'reason': 'System appears ready.'}

    def _get_fallback_response(self, query: str) -> str:
        logging.error(f"Executing fallback response for query: {query[:50]}")
        return "I encountered an unexpected issue..."

    async def _core_query_processing(self, raw_user_query_this_turn: str) -> str:
        """Send chat history plus the augmented prompt to the model and return its text.

        Returns an explanatory string (rather than raising) when the prompt was
        blocked or the model produced no text. API errors propagate to the caller.
        """
        if not client:
            return "Error: The AI Agent is not available..."
        augmented_current_user_prompt_text = await self._build_prompt_for_current_turn(raw_user_query_this_turn)
        api_call_contents = list(self.chat_history)
        api_call_contents.append({"role": "user", "parts": [{"text": augmented_current_user_prompt_text}]})
        logging.debug(f"Sending to GenAI. Total turns in content: {len(api_call_contents)}")
        api_generation_config_obj = self._build_generation_config("in core")
        response = await asyncio.to_thread(client.models.generate_content, model=self.llm_model_name, contents=api_call_contents, config=api_generation_config_obj)
        if not response.candidates:
            block_reason = response.prompt_feedback.block_reason if response.prompt_feedback else "Unknown"
            block_message = response.prompt_feedback.block_reason_message if response.prompt_feedback else ""
            error_message = f"The AI's response was blocked. Reason: {block_reason}." + (f" Details: {block_message}" if block_message else "")
            return error_message
        text = self._response_text(response)
        if not text:
            return f"The AI returned an empty response (finish reason: {self._describe_finish_reason(response)})."
        return text

    async def _process_query_with_timeout(self, raw_user_query_this_turn: str, timeout_seconds: int = 60) -> str:
        """Run :meth:`_core_query_processing` with a timeout; a timeout yields a message, not an exception."""
        try:
            return await asyncio.wait_for(self._core_query_processing(raw_user_query_this_turn), timeout=timeout_seconds)
        except asyncio.TimeoutError:
            logging.error(f"Query processing timed out for {timeout_seconds} seconds...")
            return "I'm sorry, but your request took too long..."

    async def process_query(self, raw_user_query_this_turn: str) -> str:
        """Public entry point: validate, check readiness, then query the model with retries.

        Rate-limit errors are retried with exponential backoff (1s, 2s). Other errors
        are logged and retried immediately. After the last attempt a fallback message
        is returned.
        """
        if not client:
            return "Error: The AI Agent is not available..."
        if not self._validate_query(raw_user_query_this_turn):
            return self._get_query_help_message()
        readiness_check = await self._check_system_readiness()
        if not readiness_check['ready']:
            return f"System not ready: {readiness_check['reason']}"
        max_retries = 2
        for attempt in range(max_retries + 1):
            try:
                response_text = await self._process_query_with_timeout(raw_user_query_this_turn)
            except RateLimitError:
                if attempt == max_retries:
                    return "The AI service is currently busy..."
                await asyncio.sleep(2 ** attempt)
                continue
            except ValidationError as ve:
                return f"Query validation failed: {str(ve)}"
            except Exception as e:
                # Previously swallowed silently; keep the cause visible in logs.
                logging.warning(f"Query processing failed (attempt {attempt + 1}/{max_retries + 1}): {e}", exc_info=True)
                if attempt == max_retries:
                    return self._get_fallback_response(raw_user_query_this_turn)
                continue
            if "The AI's response was blocked" in response_text:
                return response_text
            logging.info(f"Successfully received AI response (attempt {attempt+1}): {response_text[:100]}")
            return response_text
        return self._get_fallback_response(raw_user_query_this_turn)

    def _classify_query_type(self, query: str) -> str:
        """Map a query to a coarse analysis category by keyword (first matching rule wins)."""
        query_lower = query.lower()
        if any(word in query_lower for word in ['trend', 'growth', 'change', 'time']):
            return 'trend_analysis'
        elif any(word in query_lower for word in ['compare', 'benchmark', 'versus']):
            return 'comparative_analysis'
        elif any(word in query_lower for word in ['predict', 'forecast', 'future']):
            return 'predictive_analysis'
        elif any(word in query_lower for word in ['recommend', 'suggest', 'improve', 'advice', 'help me with']):
            return 'recommendation_engine'
        elif any(word in query_lower for word in ['what is', 'explain', 'define']):
            return 'definition_explanation'
        else:
            return 'general_inquiry'

    def clear_chat_history(self):
        self.chat_history = []
        logging.info("EmployerBrandingAgent chat history cleared by request.")
def get_all_schemas_representation(all_dataframes: dict) -> str:
    """Describe every loaded DataFrame as Markdown-friendly text.

    Each ``key -> df`` entry produces a section containing:
    - the name ``df_<key>``;
    - the shape;
    - the comma-separated column labels;
    - for non-empty frames, the first two rows inside a collapsible ``<details>`` block.

    Samples are rendered with ``DataFrame.to_markdown``, which needs the optional
    ``tabulate`` package, and fall back to ``to_string`` if that fails.

    Args:
        all_dataframes: Mapping of short names to DataFrames. ``None`` or an empty
            mapping yields a fixed "nothing loaded" message.

    Returns:
        A newline-joined description of all DataFrames.
    """
    if not all_dataframes:
        return "No DataFrames are currently loaded."
    schema_descriptions = ["DataFrames currently available in the application state:"]
    for key, df in all_dataframes.items():
        df_name = f"df_{key}"
        # Column labels are not always strings (ints, tuples, ...); str.join would
        # raise TypeError on them, so stringify each label first.
        columns = ", ".join(str(col) for col in df.columns)
        shape = df.shape
        if df.empty:
            schema = f"\n--- DataFrame: {df_name} ---\nStatus: Empty\nShape: {shape}\nColumns: {columns}"
        else:
            try:
                sample_data_str = df.head(2).to_markdown(index=False)
            except ImportError:
                logging.warning("`tabulate` library not found. Falling back to `to_string()` for schema representation.")
                sample_data_str = df.head(2).to_string(index=False)
            except Exception as e:
                logging.error(f"Error formatting DataFrame sample for {df_name} with to_markdown: {e}. Falling back to to_string().")
                sample_data_str = df.head(2).to_string(index=False)
            schema = (f"\n--- DataFrame: {df_name} ---\nShape: {shape}\nColumns: {columns}\n\n<details><summary>Sample Data (first 2 rows of {df_name}):</summary>\n\n```text\n{sample_data_str}\n```\n\n</details>")
        schema_descriptions.append(schema)
    return "\n".join(schema_descriptions)
async def test_rag_retrieval_accuracy():
    """Smoke-test AdvancedRAGSystem retrieval against a tiny fixed corpus.

    Embeds four HR-themed snippets, then runs four queries with ``top_k=1`` and a
    low similarity floor (0.1). For each query it logs PASSED or FAILED depending
    on whether the retrieved text contains the expected keyword. Outcomes are only
    logged; nothing is returned or raised. It stops early, after logging an error,
    when the GenAI client is missing or no embeddings could be built.
    """
    logging.info("Running RAG retrieval accuracy test...")
    model_name = GEMINI_EMBEDDING_MODEL_NAME
    if not client:
        logging.error("Cannot run RAG test: GenAI client not initialized.")
        return

    corpus = pd.DataFrame({'text': [
        'Strategies for improving employee engagement include regular feedback and recognition programs.',
        'Effective talent acquisition requires a strong employer brand and a streamlined hiring process.',
        'Company culture is a key driver of employee satisfaction and retention.',
        'Analyzing LinkedIn post performance can reveal insights into content effectiveness.',
    ]})
    rag = AdvancedRAGSystem(corpus, model_name)
    logging.info("Test RAG: Initializing embeddings...")
    await rag.initialize_embeddings()
    if rag.embeddings is None or rag.embeddings.size == 0:
        logging.error("Test RAG: Embeddings not initialized properly.")
        return

    # (query, keyword the top retrieved document must contain)
    expectations = (
        ("employee engagement", "engagement"),
        ("hiring talent", "acquisition"),
        ("company culture", "culture"),
        ("linkedin posts", "linkedin"),
    )
    failure_count = 0
    for query, keyword in expectations:
        logging.info(f"Test RAG: Retrieving for query: '{query}'")
        result = await rag.retrieve_relevant_info(query, top_k=1, min_similarity=0.1)
        found = bool(result) and keyword.lower() in result.lower()
        if not found:
            logging.error(f"Test RAG: FAILED for query '{query}'. Expected keyword '{keyword}', got: {result[:100]}...")
            failure_count += 1
            continue
        logging.info(f"Test RAG: PASSED for query '{query}'. Found relevant doc.")

    if failure_count:
        logging.error("Some RAG retrieval accuracy tests FAILED.")
    else:
        logging.info("All RAG retrieval accuracy tests passed.")