GuglielmoTor committed on
Commit
606d7ff
·
verified ·
1 Parent(s): 97bdf15

Update eb_agent_module.py

Browse files
Files changed (1) hide show
  1. eb_agent_module.py +80 -201
eb_agent_module.py CHANGED
@@ -20,14 +20,16 @@ except ImportError:
20
  SafetySetting = None
21
  # Define HarmCategory and HarmBlockThreshold as inner classes or attributes for the dummy types
22
  class HarmCategory: # type: ignore
 
23
  HARM_CATEGORY_HATE_SPEECH = "HARM_CATEGORY_HATE_SPEECH"
24
  HARM_CATEGORY_HARASSMENT = "HARM_CATEGORY_HARASSMENT"
25
  HARM_CATEGORY_SEXUALLY_EXPLICIT = "HARM_CATEGORY_SEXUALLY_EXPLICIT"
26
  HARM_CATEGORY_DANGEROUS_CONTENT = "HARM_CATEGORY_DANGEROUS_CONTENT"
27
  class HarmBlockThreshold: # type: ignore
 
28
  BLOCK_LOW_AND_ABOVE = "BLOCK_LOW_AND_ABOVE"
29
  BLOCK_MEDIUM_AND_ABOVE = "BLOCK_MEDIUM_AND_ABOVE"
30
- BLOCK_NONE = "BLOCK_NONE"
31
 
32
  # --- Custom Exceptions ---
33
  class ValidationError(Exception):
@@ -54,17 +56,27 @@ GENERATION_CONFIG_PARAMS = {
54
  "candidate_count": 1,
55
  }
56
 
57
- # Updated to use types.HarmCategory and types.HarmBlockThreshold
58
- DEFAULT_SAFETY_SETTINGS = [
59
- {"category": types.HarmCategory.HARM_CATEGORY_HATE_SPEECH if types and hasattr(types, 'HarmCategory') else "HARM_CATEGORY_HATE_SPEECH",
60
- "threshold": types.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE if types and hasattr(types, 'HarmBlockThreshold') else "BLOCK_MEDIUM_AND_ABOVE"},
61
- {"category": types.HarmCategory.HARM_CATEGORY_HARASSMENT if types and hasattr(types, 'HarmCategory') else "HARM_CATEGORY_HARASSMENT",
62
- "threshold": types.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE if types and hasattr(types, 'HarmBlockThreshold') else "BLOCK_MEDIUM_AND_ABOVE"},
63
- {"category": types.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT if types and hasattr(types, 'HarmCategory') else "HARM_CATEGORY_SEXUALLY_EXPLICIT",
64
- "threshold": types.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE if types and hasattr(types, 'HarmBlockThreshold') else "BLOCK_MEDIUM_AND_ABOVE"},
65
- {"category": types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT if types and hasattr(types, 'HarmCategory') else "HARM_CATEGORY_DANGEROUS_CONTENT",
66
- "threshold": types.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE if types and hasattr(types, 'HarmBlockThreshold') else "BLOCK_MEDIUM_AND_ABOVE"},
67
- ]
 
 
 
 
 
 
 
 
 
 
68
 
69
 
70
  df_rag_documents = pd.DataFrame({
@@ -101,7 +113,6 @@ class AdvancedRAGSystem:
101
  if not text or not isinstance(text, str):
102
  raise ValueError("Cannot embed empty or non-string text.")
103
 
104
- # Ensure types.EmbedContentConfig is available before using it
105
  embed_config = None
106
  if types and hasattr(types, 'EmbedContentConfig'):
107
  embed_config = types.EmbedContentConfig(task_type="SEMANTIC_SIMILARITY")
@@ -148,7 +159,6 @@ class AdvancedRAGSystem:
148
  self.embeddings = np.array([])
149
 
150
  def _calculate_cosine_similarity(self, embeddings_matrix: np.ndarray, query_vector: np.ndarray) -> np.ndarray:
151
- """Calculate normalized cosine similarity between a matrix of embeddings and a query vector."""
152
  query_vector = query_vector.flatten()
153
  norm_matrix = np.linalg.norm(embeddings_matrix, axis=1, keepdims=True)
154
  normalized_embeddings_matrix = embeddings_matrix / (norm_matrix + 1e-8)
@@ -179,21 +189,15 @@ class AdvancedRAGSystem:
179
 
180
  try:
181
  similarity_scores = self._calculate_cosine_similarity(self.embeddings, query_vector)
182
-
183
- if similarity_scores.size == 0:
184
- return ""
185
-
186
  relevant_indices_after_threshold = np.where(similarity_scores >= min_similarity)[0]
187
  if len(relevant_indices_after_threshold) == 0:
188
  logging.debug(f"No documents met the minimum similarity threshold of {min_similarity} for query: {query[:50]}")
189
  return ""
190
-
191
  relevant_scores = similarity_scores[relevant_indices_after_threshold]
192
  sorted_relevant_indices_local = np.argsort(relevant_scores)[::-1]
193
  top_original_indices = relevant_indices_after_threshold[sorted_relevant_indices_local[:top_k]]
194
-
195
  if len(top_original_indices) == 0: return ""
196
-
197
  context_parts = [self.documents_df.iloc[i]['text'] for i in top_original_indices if 'text' in self.documents_df.columns]
198
  context = "\n\n---\n\n".join(context_parts)
199
  logging.debug(f"Retrieved RAG context for query '{str(query)[:50]}...':\n{context[:200]}...")
@@ -210,15 +214,14 @@ class EmployerBrandingAgent:
210
  llm_model_name: str,
211
  embedding_model_name: str,
212
  generation_config_dict: dict,
213
- safety_settings_list_of_dicts: list,
214
  force_sandbox: bool = False):
215
  self.all_dataframes = {k: df.copy() for k, df in all_dataframes.items()}
216
  self.schemas_representation = self._get_enhanced_schemas_representation()
217
-
218
  self.chat_history = []
219
  self.llm_model_name = llm_model_name
220
  self.generation_config_dict = generation_config_dict
221
- self.safety_settings_list_of_dicts = safety_settings_list_of_dicts # These are dicts
222
  self.embedding_model_name = embedding_model_name
223
  self.rag_system = AdvancedRAGSystem(rag_documents_df, self.embedding_model_name)
224
  self.force_sandbox = force_sandbox
@@ -236,7 +239,7 @@ class EmployerBrandingAgent:
236
  return "N/A"
237
 
238
  def _calculate_growth_rate(self, df: pd.DataFrame) -> str:
239
- logging.debug("_calculate_growth_rate is a placeholder.") # Changed to debug
240
  return "Growth rate calculation not implemented."
241
  def _analyze_engagement_trends(self, df: pd.DataFrame) -> str:
242
  logging.debug("_analyze_engagement_trends is a placeholder.")
@@ -257,21 +260,12 @@ class EmployerBrandingAgent:
257
  def _calculate_key_metrics(self, df: pd.DataFrame, df_type: str) -> dict:
258
  metrics = {}
259
  if 'follower' in df_type.lower():
260
- metrics.update({
261
- 'follower_growth_rate': self._calculate_growth_rate(df),
262
- 'engagement_trends': self._analyze_engagement_trends(df),
263
- 'demographic_distribution': self._analyze_demographics(df)
264
- })
265
  elif 'post' in df_type.lower():
266
- metrics.update({
267
- 'post_performance': self._analyze_post_performance(df),
268
- 'content_themes': self._extract_content_themes(df),
269
- 'optimal_posting_times': self._find_optimal_times(df)
270
- })
271
  elif 'mention' in df_type.lower():
272
  metrics['mention_volume_trend'] = "Mention volume trend not implemented."
273
  metrics['mention_sentiment_overview'] = "Mention sentiment overview not implemented."
274
-
275
  if not metrics:
276
  logging.debug(f"No specific key metrics defined for df_type: {df_type}")
277
  return {"info": "Standard metrics applicable."}
@@ -283,7 +277,7 @@ class EmployerBrandingAgent:
283
  try:
284
  max_date = df[col].max()
285
  if pd.notna(max_date):
286
- days_diff = (datetime.now(max_date.tzinfo) - max_date).days # tz aware
287
  return f"Data up to {max_date.strftime('%Y-%m-%d')} ({days_diff} days old)"
288
  except Exception: pass
289
  return "Freshness N/A (no clear date column)"
@@ -296,28 +290,17 @@ class EmployerBrandingAgent:
296
 
297
  def _assess_data_quality(self, df: pd.DataFrame) -> dict:
298
  completeness = (1 - (df.isnull().sum().sum() / (len(df) * len(df.columns)))) if len(df) > 0 and len(df.columns) > 0 else 0
299
- return {
300
- 'completeness_score': f"{completeness:.2%}",
301
- 'freshness_info': self._calculate_data_freshness(df),
302
- 'consistency_check': self._check_data_consistency(df),
303
- 'accuracy_flags_summary': self._identify_accuracy_issues(df),
304
- 'sample_size_notes': f"{len(df)} records. {'Adequate for basic analysis.' if len(df) >= 100 else 'Limited sample size; insights may be indicative.'}"
305
- }
306
 
307
  def _identify_patterns(self, df: pd.DataFrame, key: str) -> str:
308
  logging.debug(f"_identify_patterns for {key} is a placeholder.")
309
  return "Pattern identification not implemented."
310
 
311
  def _format_df_analysis(self, df_key: str, analysis: dict) -> str:
312
- formatted_parts = [f"\n--- DataFrame: df_{df_key} ---"]
313
- formatted_parts.append(f" Shape: {analysis['shape']}")
314
- formatted_parts.append(f" Date Range: {analysis['date_range']}")
315
- formatted_parts.append(" Key Metrics:")
316
- for metric, value in analysis['key_metrics'].items():
317
- formatted_parts.append(f" - {metric.replace('_', ' ').title()}: {value}")
318
  formatted_parts.append(" Data Quality Assessment:")
319
- for aspect, value in analysis['data_quality'].items():
320
- formatted_parts.append(f" - {aspect.replace('_', ' ').title()}: {value}")
321
  formatted_parts.append(f" Notable Patterns: {analysis['notable_patterns']}")
322
  return "\n".join(formatted_parts)
323
 
@@ -330,13 +313,7 @@ class EmployerBrandingAgent:
330
  if df.empty:
331
  schema_descriptions.append(f"\n--- DataFrame: df_{key} ---\nStatus: Empty. No analysis possible.")
332
  continue
333
- analysis = {
334
- 'shape': df.shape,
335
- 'date_range': self._get_date_range(df),
336
- 'key_metrics': self._calculate_key_metrics(df, key),
337
- 'data_quality': self._assess_data_quality(df),
338
- 'notable_patterns': self._identify_patterns(df, key)
339
- }
340
  schema_descriptions.append(self._format_df_analysis(key, analysis))
341
  return "\n".join(schema_descriptions)
342
 
@@ -367,108 +344,48 @@ Relevant Benchmarks: {benchmarks_val}"""
367
  return enhanced_context
368
 
369
  async def _build_prompt_for_current_turn(self, raw_user_query: str) -> str:
370
- prompt_parts = [
371
- "You are an expert Employer Branding Analyst and a helpful AI assistant. "
372
- "Your goal is to provide insightful analysis based on the provided LinkedIn data. "
373
- "When asked to generate Pandas code, ensure it is correct, runnable, and clearly explained. "
374
- "When providing insights, be specific and refer to the data where possible. "
375
- "Use the detailed data overview and any contextual information provided."
376
- ]
377
- prompt_parts.append("\n\n--- DETAILED DATA OVERVIEW ---")
378
- prompt_parts.append(self.schemas_representation)
379
-
380
  if self.rag_system.embeddings is not None and self.rag_system.embeddings.size > 0:
381
- logging.debug(f"Retrieving base RAG context for query: {raw_user_query[:50]}...")
382
  base_rag_context = await self.rag_system.retrieve_relevant_info(raw_user_query)
383
  if base_rag_context:
384
- logging.debug(f"Enhancing RAG context for query: {raw_user_query[:50]}...")
385
  enhanced_rag_context = await self._enhance_rag_context(raw_user_query, base_rag_context)
386
- prompt_parts.append("\n\n--- RELEVANT CONTEXTUAL INFORMATION (from documents & business knowledge) ---")
387
- prompt_parts.append(enhanced_rag_context)
388
- else: logging.debug("No base RAG context found.")
389
- else: logging.debug("RAG system not initialized or embeddings not available, skipping RAG context retrieval.")
390
-
391
- prompt_parts.append("\n\n--- USER REQUEST ---")
392
- prompt_parts.append(f"Based on all the information above, please respond to the following user query:\n{raw_user_query}")
393
  final_prompt = "\n".join(prompt_parts)
394
  logging.debug(f"Built prompt for current turn (first 300 chars): {final_prompt[:300]}")
395
  return final_prompt
396
 
397
  async def _process_structured_query(self, prompt: str) -> dict:
398
- logging.debug("_process_structured_query is a placeholder. Returning dummy structure.")
399
- return {
400
- "Key Findings": ["Placeholder finding 1", "Placeholder finding 2"],
401
- "Performance Metrics": ["Placeholder metric performance"],
402
- "Actionable Recommendations": {
403
- "Immediate Actions (0-30 days)": ["Placeholder immediate action"],
404
- "Short-term Strategy (1-3 months)": ["Placeholder short-term strategy"],
405
- "Long-term Vision (3-12 months)": ["Placeholder long-term vision"]
406
- },
407
- "Risk Assessment": ["Placeholder risk"],
408
- "Success Metrics to Track": ["Placeholder KPI"]
409
- }
410
 
411
  async def _generate_hr_insights(self, query: str, context: str) -> str:
412
- insight_prompt = f"""
413
- As an expert HR analytics consultant, analyze the following LinkedIn employer branding data:
414
- {context}
415
- User Query: {query}
416
- Please provide insights in this structured format:
417
- ## Key Findings
418
- - [3-5 bullet points of main discoveries]
419
- ## Performance Metrics
420
- - Current performance vs industry benchmarks
421
- - Trend analysis (improving/declining/stable)
422
- ## Actionable Recommendations
423
- 1. **Immediate Actions (0-30 days)**
424
- - [Specific, measurable actions]
425
- 2. **Short-term Strategy (1-3 months)**
426
- - [Strategic initiatives]
427
- 3. **Long-term Vision (3-12 months)**
428
- - [Comprehensive improvements]
429
- ## Risk Assessment
430
- - Potential challenges or red flags
431
- - Mitigation strategies
432
- ## Success Metrics to Track
433
- - KPIs to monitor progress
434
- - Reporting frequency recommendations
435
- """
436
  if not client: return "Error: AI client not configured for generating HR insights."
437
  api_call_contents = [{"role": "user", "parts": [{"text": insight_prompt}]}]
438
 
439
- # Construct SafetySetting objects if types.SafetySetting is available
440
  api_safety_settings_objects = []
441
  if types and hasattr(types, 'SafetySetting'):
442
  for ss_dict in self.safety_settings_list_of_dicts:
443
  try:
444
- # Use types.HarmCategory and types.HarmBlockThreshold directly
445
- category = getattr(types.HarmCategory, ss_dict['category'].split('.')[-1] if isinstance(ss_dict['category'], str) else ss_dict['category'].name, types.HarmCategory.HARM_CATEGORY_UNSPECIFIED)
446
- threshold = getattr(types.HarmBlockThreshold, ss_dict['threshold'].split('.')[-1] if isinstance(ss_dict['threshold'], str) else ss_dict['threshold'].name, types.HarmBlockThreshold.BLOCK_NONE)
447
- api_safety_settings_objects.append(types.SafetySetting(category=category, threshold=threshold))
448
- except Exception as e_ss:
449
- logging.warning(f"Could not create SafetySetting object from {ss_dict}: {e_ss}. Using dict.")
450
- api_safety_settings_objects.append(ss_dict) # Fallback to dict if creation fails
451
- else: # Fallback if types.SafetySetting is not available
452
  api_safety_settings_objects = self.safety_settings_list_of_dicts
453
 
454
  api_generation_config_obj = None
455
  if types and hasattr(types, 'GenerateContentConfig'):
456
- api_generation_config_obj = types.GenerateContentConfig(
457
- **self.generation_config_dict,
458
- safety_settings=api_safety_settings_objects
459
- )
460
- else: # Fallback if types.GenerateContentConfig is not available
461
- config_dict_for_api = {**self.generation_config_dict, "safety_settings": api_safety_settings_objects}
462
- api_generation_config_obj = config_dict_for_api
463
-
464
 
465
  try:
466
- response = await asyncio.to_thread(
467
- client.models.generate_content,
468
- model=self.llm_model_name,
469
- contents=api_call_contents,
470
- config=api_generation_config_obj
471
- )
472
  if not response.candidates: return "HR insights generation failed: No response from AI."
473
  return response.text.strip()
474
  except Exception as e:
@@ -476,29 +393,23 @@ Please provide insights in this structured format:
476
  return f"Error generating HR insights: {str(e)}"
477
 
478
  def _validate_query(self, query: str) -> bool:
479
- if not query or len(query.strip()) < 3:
480
- logging.warning(f"Query too short: '{query}'")
481
- return False
482
  hr_keywords = ['employee', 'talent', 'hiring', 'culture', 'brand', 'engagement', 'retention', 'follower', 'post', 'mention', 'linkedin']
483
- if not any(keyword in query.lower() for keyword in hr_keywords):
484
- logging.warning(f"Query may not be HR/LinkedIn-relevant: {query[:50]}")
485
  return True
486
 
487
  def _get_query_help_message(self) -> str:
488
- return ("I'm here to help with Employer Branding analysis on LinkedIn data. "
489
- "Please ask specific questions about your followers, posts, or mentions. "
490
- "For example: 'What are the top industries of my followers?' or 'Analyze the engagement trend of my recent posts.'")
491
 
492
  async def _check_system_readiness(self) -> dict:
493
  logging.debug("_check_system_readiness is a placeholder.")
494
  if not client: return {'ready': False, 'reason': 'AI Client not initialized.'}
495
- if self.rag_system.embeddings is None:
496
- logging.warning("RAG embeddings not yet initialized. Proceeding, but RAG context will be unavailable.")
497
  return {'ready': True, 'reason': 'System appears ready.'}
498
 
499
  def _get_fallback_response(self, query: str) -> str:
500
  logging.error(f"Executing fallback response for query: {query[:50]}")
501
- return "I encountered an unexpected issue while processing your request. Please try rephrasing your query or try again later."
502
 
503
  async def _core_query_processing(self, raw_user_query_this_turn: str) -> str:
504
  augmented_current_user_prompt_text = await self._build_prompt_for_current_turn(raw_user_query_this_turn)
@@ -510,69 +421,40 @@ Please provide insights in this structured format:
510
  if types and hasattr(types, 'SafetySetting'):
511
  for ss_dict in self.safety_settings_list_of_dicts:
512
  try:
513
- category_enum_val = ss_dict['category']
514
- threshold_enum_val = ss_dict['threshold']
515
- # If they are already enum members, use them directly
516
- if not isinstance(category_enum_val, str): # Assumes it's an enum member
517
- category = category_enum_val
518
- else: # If string, try to get from types.HarmCategory
519
- category = getattr(types.HarmCategory, category_enum_val.split('.')[-1], types.HarmCategory.HARM_CATEGORY_UNSPECIFIED)
520
-
521
- if not isinstance(threshold_enum_val, str): # Assumes it's an enum member
522
- threshold = threshold_enum_val
523
- else: # If string, try to get from types.HarmBlockThreshold
524
- threshold = getattr(types.HarmBlockThreshold, threshold_enum_val.split('.')[-1], types.HarmBlockThreshold.BLOCK_NONE)
525
-
526
- api_safety_settings_objects.append(types.SafetySetting(category=category, threshold=threshold))
527
  except Exception as e_ss_core:
528
- logging.warning(f"Could not create SafetySetting object from {ss_dict} in core: {e_ss_core}. Using dict.")
529
- api_safety_settings_objects.append(ss_dict) # Fallback
530
- else:
531
  api_safety_settings_objects = self.safety_settings_list_of_dicts
532
 
533
-
534
  api_generation_config_obj = None
535
  if types and hasattr(types, 'GenerateContentConfig'):
536
- api_generation_config_obj = types.GenerateContentConfig(
537
- **self.generation_config_dict,
538
- safety_settings=api_safety_settings_objects
539
- )
540
  else:
541
  logging.error("GenerateContentConfig type not available. API call might fail.")
542
- config_dict_for_api = {**self.generation_config_dict, "safety_settings": api_safety_settings_objects}
543
- api_generation_config_obj = config_dict_for_api
544
-
545
- response = await asyncio.to_thread(
546
- client.models.generate_content,
547
- model=self.llm_model_name,
548
- contents=api_call_contents,
549
- config=api_generation_config_obj
550
- )
551
 
 
552
  if not response.candidates:
553
  block_reason = response.prompt_feedback.block_reason if response.prompt_feedback else "Unknown"
554
  block_message = response.prompt_feedback.block_reason_message if response.prompt_feedback else ""
555
- logging.warning(f"AI response blocked or empty. Reason: {block_reason}, Msg: {block_message}")
556
- error_message = f"The AI's response was blocked. Reason: {block_reason}."
557
- if block_message: error_message += f" Details: {block_message}"
558
  return error_message
559
  return response.text.strip()
560
 
561
  async def _process_query_with_timeout(self, raw_user_query_this_turn: str, timeout_seconds: int = 60) -> str:
562
- try:
563
- return await asyncio.wait_for(self._core_query_processing(raw_user_query_this_turn), timeout=timeout_seconds)
564
  except asyncio.TimeoutError:
565
- logging.error(f"Query processing timed out after {timeout_seconds} seconds for query: {raw_user_query_this_turn[:50]}")
566
- return "I'm sorry, but your request took too long to process. Please try a simpler query or try again later."
567
 
568
  async def process_query(self, raw_user_query_this_turn: str) -> str:
569
- if not client:
570
- logging.error("GenAI client not initialized. Cannot process query.")
571
- return "Error: The AI Agent is not available due to a configuration issue with the AI service."
572
  if not self._validate_query(raw_user_query_this_turn): return self._get_query_help_message()
573
  readiness_check = await self._check_system_readiness()
574
  if not readiness_check['ready']: return f"System not ready: {readiness_check['reason']}"
575
-
576
  max_retries = 2
577
  for attempt in range(max_retries + 1):
578
  try:
@@ -581,22 +463,15 @@ Please provide insights in this structured format:
581
  logging.info(f"Successfully received AI response (attempt {attempt+1}): {response_text[:100]}")
582
  return response_text
583
  except RateLimitError as rle:
584
- logging.warning(f"Rate limit encountered on attempt {attempt + 1}: {rle}. Retrying after backoff...")
585
- if attempt == max_retries:
586
- logging.error("Max retries reached due to rate limiting.")
587
- return "The AI service is currently busy. Please try again in a few moments."
588
  await asyncio.sleep(2 ** attempt)
589
- except ValidationError as ve:
590
- logging.warning(f"Validation error during processing: {ve}")
591
- return f"Query validation failed: {str(ve)}"
592
  except Exception as e:
593
- logging.error(f"Error during GenAI call on attempt {attempt + 1}: {e}", exc_info=True)
594
- if attempt == max_retries:
595
- logging.error("Max retries reached due to general errors.")
596
- return self._get_fallback_response(raw_user_query_this_turn)
597
  return self._get_fallback_response(raw_user_query_this_turn)
598
 
599
  def _classify_query_type(self, query: str) -> str:
 
600
  query_lower = query.lower()
601
  if any(word in query_lower for word in ['trend', 'growth', 'change', 'time']): return 'trend_analysis'
602
  elif any(word in query_lower for word in ['compare', 'benchmark', 'versus']): return 'comparative_analysis'
@@ -605,11 +480,13 @@ Please provide insights in this structured format:
605
  elif any(word in query_lower for word in ['what is', 'explain', 'define']): return 'definition_explanation'
606
  else: return 'general_inquiry'
607
 
 
608
  def clear_chat_history(self):
609
  self.chat_history = []
610
  logging.info("EmployerBrandingAgent chat history cleared by request.")
611
 
612
  def get_all_schemas_representation(all_dataframes: dict) -> str:
 
613
  if not all_dataframes: return "No DataFrames are currently loaded."
614
  schema_descriptions = ["DataFrames currently available in the application state:"]
615
  for key, df in all_dataframes.items():
@@ -624,7 +501,9 @@ def get_all_schemas_representation(all_dataframes: dict) -> str:
624
  schema_descriptions.append(schema)
625
  return "\n".join(schema_descriptions)
626
 
 
627
  async def test_rag_retrieval_accuracy():
 
628
  logging.info("Running RAG retrieval accuracy test...")
629
  test_embedding_model = GEMINI_EMBEDDING_MODEL_NAME
630
  if not client:
 
20
  SafetySetting = None
21
  # Define HarmCategory and HarmBlockThreshold as inner classes or attributes for the dummy types
22
  class HarmCategory: # type: ignore
23
+ HARM_CATEGORY_UNSPECIFIED = "HARM_CATEGORY_UNSPECIFIED"
24
  HARM_CATEGORY_HATE_SPEECH = "HARM_CATEGORY_HATE_SPEECH"
25
  HARM_CATEGORY_HARASSMENT = "HARM_CATEGORY_HARASSMENT"
26
  HARM_CATEGORY_SEXUALLY_EXPLICIT = "HARM_CATEGORY_SEXUALLY_EXPLICIT"
27
  HARM_CATEGORY_DANGEROUS_CONTENT = "HARM_CATEGORY_DANGEROUS_CONTENT"
28
  class HarmBlockThreshold: # type: ignore
29
+ BLOCK_NONE = "BLOCK_NONE"
30
  BLOCK_LOW_AND_ABOVE = "BLOCK_LOW_AND_ABOVE"
31
  BLOCK_MEDIUM_AND_ABOVE = "BLOCK_MEDIUM_AND_ABOVE"
32
+ BLOCK_ONLY_HIGH = "BLOCK_ONLY_HIGH" # Added for completeness, adjust if needed
33
 
34
  # --- Custom Exceptions ---
35
  class ValidationError(Exception):
 
56
  "candidate_count": 1,
57
  }
58
 
59
+ # Corrected to use direct enum members when available
60
+ DEFAULT_SAFETY_SETTINGS = []
61
+ if types and hasattr(types, 'HarmCategory') and hasattr(types, 'HarmBlockThreshold'):
62
+ DEFAULT_SAFETY_SETTINGS = [
63
+ {"category": types.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
64
+ "threshold": types.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE},
65
+ {"category": types.HarmCategory.HARM_CATEGORY_HARASSMENT,
66
+ "threshold": types.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE},
67
+ {"category": types.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
68
+ "threshold": types.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE},
69
+ {"category": types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
70
+ "threshold": types.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE},
71
+ ]
72
+ else: # Fallback to strings if types or enums are not properly imported
73
+ logging.warning("Falling back to string representations for DEFAULT_SAFETY_SETTINGS due to missing types.HarmCategory or types.HarmBlockThreshold.")
74
+ DEFAULT_SAFETY_SETTINGS = [
75
+ {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
76
+ {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
77
+ {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
78
+ {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
79
+ ]
80
 
81
 
82
  df_rag_documents = pd.DataFrame({
 
113
  if not text or not isinstance(text, str):
114
  raise ValueError("Cannot embed empty or non-string text.")
115
 
 
116
  embed_config = None
117
  if types and hasattr(types, 'EmbedContentConfig'):
118
  embed_config = types.EmbedContentConfig(task_type="SEMANTIC_SIMILARITY")
 
159
  self.embeddings = np.array([])
160
 
161
  def _calculate_cosine_similarity(self, embeddings_matrix: np.ndarray, query_vector: np.ndarray) -> np.ndarray:
 
162
  query_vector = query_vector.flatten()
163
  norm_matrix = np.linalg.norm(embeddings_matrix, axis=1, keepdims=True)
164
  normalized_embeddings_matrix = embeddings_matrix / (norm_matrix + 1e-8)
 
189
 
190
  try:
191
  similarity_scores = self._calculate_cosine_similarity(self.embeddings, query_vector)
192
+ if similarity_scores.size == 0: return ""
 
 
 
193
  relevant_indices_after_threshold = np.where(similarity_scores >= min_similarity)[0]
194
  if len(relevant_indices_after_threshold) == 0:
195
  logging.debug(f"No documents met the minimum similarity threshold of {min_similarity} for query: {query[:50]}")
196
  return ""
 
197
  relevant_scores = similarity_scores[relevant_indices_after_threshold]
198
  sorted_relevant_indices_local = np.argsort(relevant_scores)[::-1]
199
  top_original_indices = relevant_indices_after_threshold[sorted_relevant_indices_local[:top_k]]
 
200
  if len(top_original_indices) == 0: return ""
 
201
  context_parts = [self.documents_df.iloc[i]['text'] for i in top_original_indices if 'text' in self.documents_df.columns]
202
  context = "\n\n---\n\n".join(context_parts)
203
  logging.debug(f"Retrieved RAG context for query '{str(query)[:50]}...':\n{context[:200]}...")
 
214
  llm_model_name: str,
215
  embedding_model_name: str,
216
  generation_config_dict: dict,
217
+ safety_settings_list_of_dicts: list, # This list now contains dicts with ENUM values or STRINGS
218
  force_sandbox: bool = False):
219
  self.all_dataframes = {k: df.copy() for k, df in all_dataframes.items()}
220
  self.schemas_representation = self._get_enhanced_schemas_representation()
 
221
  self.chat_history = []
222
  self.llm_model_name = llm_model_name
223
  self.generation_config_dict = generation_config_dict
224
+ self.safety_settings_list_of_dicts = safety_settings_list_of_dicts
225
  self.embedding_model_name = embedding_model_name
226
  self.rag_system = AdvancedRAGSystem(rag_documents_df, self.embedding_model_name)
227
  self.force_sandbox = force_sandbox
 
239
  return "N/A"
240
 
241
  def _calculate_growth_rate(self, df: pd.DataFrame) -> str:
242
+ logging.debug("_calculate_growth_rate is a placeholder.")
243
  return "Growth rate calculation not implemented."
244
  def _analyze_engagement_trends(self, df: pd.DataFrame) -> str:
245
  logging.debug("_analyze_engagement_trends is a placeholder.")
 
260
  def _calculate_key_metrics(self, df: pd.DataFrame, df_type: str) -> dict:
261
  metrics = {}
262
  if 'follower' in df_type.lower():
263
+ metrics.update({'follower_growth_rate': self._calculate_growth_rate(df), 'engagement_trends': self._analyze_engagement_trends(df), 'demographic_distribution': self._analyze_demographics(df)})
 
 
 
 
264
  elif 'post' in df_type.lower():
265
+ metrics.update({'post_performance': self._analyze_post_performance(df), 'content_themes': self._extract_content_themes(df), 'optimal_posting_times': self._find_optimal_times(df)})
 
 
 
 
266
  elif 'mention' in df_type.lower():
267
  metrics['mention_volume_trend'] = "Mention volume trend not implemented."
268
  metrics['mention_sentiment_overview'] = "Mention sentiment overview not implemented."
 
269
  if not metrics:
270
  logging.debug(f"No specific key metrics defined for df_type: {df_type}")
271
  return {"info": "Standard metrics applicable."}
 
277
  try:
278
  max_date = df[col].max()
279
  if pd.notna(max_date):
280
+ days_diff = (datetime.now(max_date.tzinfo if max_date.tzinfo else None) - max_date).days
281
  return f"Data up to {max_date.strftime('%Y-%m-%d')} ({days_diff} days old)"
282
  except Exception: pass
283
  return "Freshness N/A (no clear date column)"
 
290
 
291
  def _assess_data_quality(self, df: pd.DataFrame) -> dict:
292
  completeness = (1 - (df.isnull().sum().sum() / (len(df) * len(df.columns)))) if len(df) > 0 and len(df.columns) > 0 else 0
293
+ return {'completeness_score': f"{completeness:.2%}", 'freshness_info': self._calculate_data_freshness(df), 'consistency_check': self._check_data_consistency(df), 'accuracy_flags_summary': self._identify_accuracy_issues(df), 'sample_size_notes': f"{len(df)} records. {'Adequate for basic analysis.' if len(df) >= 100 else 'Limited sample size; insights may be indicative.'}"}
 
 
 
 
 
 
294
 
295
  def _identify_patterns(self, df: pd.DataFrame, key: str) -> str:
296
  logging.debug(f"_identify_patterns for {key} is a placeholder.")
297
  return "Pattern identification not implemented."
298
 
299
  def _format_df_analysis(self, df_key: str, analysis: dict) -> str:
300
+ formatted_parts = [f"\n--- DataFrame: df_{df_key} ---", f" Shape: {analysis['shape']}", f" Date Range: {analysis['date_range']}", " Key Metrics:"]
301
+ for metric, value in analysis['key_metrics'].items(): formatted_parts.append(f" - {metric.replace('_', ' ').title()}: {value}")
 
 
 
 
302
  formatted_parts.append(" Data Quality Assessment:")
303
+ for aspect, value in analysis['data_quality'].items(): formatted_parts.append(f" - {aspect.replace('_', ' ').title()}: {value}")
 
304
  formatted_parts.append(f" Notable Patterns: {analysis['notable_patterns']}")
305
  return "\n".join(formatted_parts)
306
 
 
313
  if df.empty:
314
  schema_descriptions.append(f"\n--- DataFrame: df_{key} ---\nStatus: Empty. No analysis possible.")
315
  continue
316
+ analysis = {'shape': df.shape, 'date_range': self._get_date_range(df), 'key_metrics': self._calculate_key_metrics(df, key), 'data_quality': self._assess_data_quality(df), 'notable_patterns': self._identify_patterns(df, key)}
 
 
 
 
 
 
317
  schema_descriptions.append(self._format_df_analysis(key, analysis))
318
  return "\n".join(schema_descriptions)
319
 
 
344
  return enhanced_context
345
 
346
  async def _build_prompt_for_current_turn(self, raw_user_query: str) -> str:
347
+ prompt_parts = ["You are an expert Employer Branding Analyst...", "--- DETAILED DATA OVERVIEW ---", self.schemas_representation] # Truncated for brevity
 
 
 
 
 
 
 
 
 
348
  if self.rag_system.embeddings is not None and self.rag_system.embeddings.size > 0:
 
349
  base_rag_context = await self.rag_system.retrieve_relevant_info(raw_user_query)
350
  if base_rag_context:
 
351
  enhanced_rag_context = await self._enhance_rag_context(raw_user_query, base_rag_context)
352
+ prompt_parts.extend(["--- RELEVANT CONTEXTUAL INFORMATION (from documents & business knowledge) ---", enhanced_rag_context])
353
+ prompt_parts.extend(["--- USER REQUEST ---", f"Based on all the information above, please respond to the following user query:\n{raw_user_query}"])
 
 
 
 
 
354
  final_prompt = "\n".join(prompt_parts)
355
  logging.debug(f"Built prompt for current turn (first 300 chars): {final_prompt[:300]}")
356
  return final_prompt
357
 
358
  async def _process_structured_query(self, prompt: str) -> dict:
359
+ logging.debug("_process_structured_query is a placeholder.")
360
+ return {"Key Findings": ["Placeholder finding 1"], "Performance Metrics": ["Placeholder metric"], "Actionable Recommendations": {"Immediate Actions (0-30 days)": ["Placeholder action"]}, "Risk Assessment": ["Placeholder risk"], "Success Metrics to Track": ["Placeholder KPI"]}
 
 
 
 
 
 
 
 
 
 
361
 
362
  async def _generate_hr_insights(self, query: str, context: str) -> str:
363
+ insight_prompt = f"As an expert HR analytics consultant...\n{context}\nUser Query: {query}\nPlease provide insights in this structured format:\n## Key Findings\n- ...\n..." # Truncated
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
364
  if not client: return "Error: AI client not configured for generating HR insights."
365
  api_call_contents = [{"role": "user", "parts": [{"text": insight_prompt}]}]
366
 
 
367
  api_safety_settings_objects = []
368
  if types and hasattr(types, 'SafetySetting'):
369
  for ss_dict in self.safety_settings_list_of_dicts:
370
  try:
371
+ # Directly use the category and threshold from the dict,
372
+ # which should be enum members if types was available at DEFAULT_SAFETY_SETTINGS definition,
373
+ # or strings otherwise.
374
+ api_safety_settings_objects.append(types.SafetySetting(category=ss_dict['category'], threshold=ss_dict['threshold']))
375
+ except Exception as e_ss: # Catch if ss_dict values are not valid for SafetySetting
376
+ logging.warning(f"Could not create SafetySetting object from {ss_dict} for HR insights: {e_ss}. Using raw dict.")
377
+ api_safety_settings_objects.append(ss_dict)
378
+ else:
379
  api_safety_settings_objects = self.safety_settings_list_of_dicts
380
 
381
  api_generation_config_obj = None
382
  if types and hasattr(types, 'GenerateContentConfig'):
383
+ api_generation_config_obj = types.GenerateContentConfig(**self.generation_config_dict, safety_settings=api_safety_settings_objects)
384
+ else:
385
+ api_generation_config_obj = {**self.generation_config_dict, "safety_settings": api_safety_settings_objects}
 
 
 
 
 
386
 
387
  try:
388
+ response = await asyncio.to_thread(client.models.generate_content, model=self.llm_model_name, contents=api_call_contents, config=api_generation_config_obj)
 
 
 
 
 
389
  if not response.candidates: return "HR insights generation failed: No response from AI."
390
  return response.text.strip()
391
  except Exception as e:
 
393
  return f"Error generating HR insights: {str(e)}"
394
 
395
  def _validate_query(self, query: str) -> bool:
396
+ if not query or len(query.strip()) < 3: logging.warning(f"Query too short: '{query}'"); return False
 
 
397
  hr_keywords = ['employee', 'talent', 'hiring', 'culture', 'brand', 'engagement', 'retention', 'follower', 'post', 'mention', 'linkedin']
398
+ if not any(keyword in query.lower() for keyword in hr_keywords): logging.warning(f"Query may not be HR/LinkedIn-relevant: {query[:50]}")
 
399
  return True
400
 
401
  def _get_query_help_message(self) -> str:
402
+ return "I'm here to help with Employer Branding analysis... Example: 'What are the top industries of my followers?'"
 
 
403
 
404
  async def _check_system_readiness(self) -> dict:
405
  logging.debug("_check_system_readiness is a placeholder.")
406
  if not client: return {'ready': False, 'reason': 'AI Client not initialized.'}
407
+ if self.rag_system.embeddings is None: logging.warning("RAG embeddings not yet initialized.")
 
408
  return {'ready': True, 'reason': 'System appears ready.'}
409
 
410
  def _get_fallback_response(self, query: str) -> str:
411
  logging.error(f"Executing fallback response for query: {query[:50]}")
412
+ return "I encountered an unexpected issue..."
413
 
414
  async def _core_query_processing(self, raw_user_query_this_turn: str) -> str:
415
  augmented_current_user_prompt_text = await self._build_prompt_for_current_turn(raw_user_query_this_turn)
 
421
  if types and hasattr(types, 'SafetySetting'):
422
  for ss_dict in self.safety_settings_list_of_dicts:
423
  try:
424
+ # Directly use category/threshold from ss_dict. They should be enums or valid strings.
425
+ api_safety_settings_objects.append(types.SafetySetting(category=ss_dict['category'], threshold=ss_dict['threshold']))
 
 
 
 
 
 
 
 
 
 
 
 
426
  except Exception as e_ss_core:
427
+ logging.warning(f"Could not create SafetySetting object from {ss_dict} in core: {e_ss_core}. Using raw dict.")
428
+ api_safety_settings_objects.append(ss_dict) # Fallback to passing the dict itself
429
+ else: # Fallback if types.SafetySetting is not available
430
  api_safety_settings_objects = self.safety_settings_list_of_dicts
431
 
 
432
  api_generation_config_obj = None
433
  if types and hasattr(types, 'GenerateContentConfig'):
434
+ api_generation_config_obj = types.GenerateContentConfig(**self.generation_config_dict, safety_settings=api_safety_settings_objects)
 
 
 
435
  else:
436
  logging.error("GenerateContentConfig type not available. API call might fail.")
437
+ api_generation_config_obj = {**self.generation_config_dict, "safety_settings": api_safety_settings_objects}
 
 
 
 
 
 
 
 
438
 
439
+ response = await asyncio.to_thread(client.models.generate_content, model=self.llm_model_name, contents=api_call_contents, config=api_generation_config_obj)
440
  if not response.candidates:
441
  block_reason = response.prompt_feedback.block_reason if response.prompt_feedback else "Unknown"
442
  block_message = response.prompt_feedback.block_reason_message if response.prompt_feedback else ""
443
+ error_message = f"The AI's response was blocked. Reason: {block_reason}." + (f" Details: {block_message}" if block_message else "")
 
 
444
  return error_message
445
  return response.text.strip()
446
 
447
  async def _process_query_with_timeout(self, raw_user_query_this_turn: str, timeout_seconds: int = 60) -> str:
448
+ try: return await asyncio.wait_for(self._core_query_processing(raw_user_query_this_turn), timeout=timeout_seconds)
 
449
  except asyncio.TimeoutError:
450
+ logging.error(f"Query processing timed out for {timeout_seconds} seconds...")
451
+ return "I'm sorry, but your request took too long..."
452
 
453
  async def process_query(self, raw_user_query_this_turn: str) -> str:
454
+ if not client: return "Error: The AI Agent is not available..."
 
 
455
  if not self._validate_query(raw_user_query_this_turn): return self._get_query_help_message()
456
  readiness_check = await self._check_system_readiness()
457
  if not readiness_check['ready']: return f"System not ready: {readiness_check['reason']}"
 
458
  max_retries = 2
459
  for attempt in range(max_retries + 1):
460
  try:
 
463
  logging.info(f"Successfully received AI response (attempt {attempt+1}): {response_text[:100]}")
464
  return response_text
465
  except RateLimitError as rle:
466
+ if attempt == max_retries: return "The AI service is currently busy..."
 
 
 
467
  await asyncio.sleep(2 ** attempt)
468
+ except ValidationError as ve: return f"Query validation failed: {str(ve)}"
 
 
469
  except Exception as e:
470
+ if attempt == max_retries: return self._get_fallback_response(raw_user_query_this_turn)
 
 
 
471
  return self._get_fallback_response(raw_user_query_this_turn)
472
 
473
  def _classify_query_type(self, query: str) -> str:
474
+ # ... (implementation unchanged)
475
  query_lower = query.lower()
476
  if any(word in query_lower for word in ['trend', 'growth', 'change', 'time']): return 'trend_analysis'
477
  elif any(word in query_lower for word in ['compare', 'benchmark', 'versus']): return 'comparative_analysis'
 
480
  elif any(word in query_lower for word in ['what is', 'explain', 'define']): return 'definition_explanation'
481
  else: return 'general_inquiry'
482
 
483
+
484
  def clear_chat_history(self):
485
  self.chat_history = []
486
  logging.info("EmployerBrandingAgent chat history cleared by request.")
487
 
488
  def get_all_schemas_representation(all_dataframes: dict) -> str:
489
+ # ... (implementation unchanged)
490
  if not all_dataframes: return "No DataFrames are currently loaded."
491
  schema_descriptions = ["DataFrames currently available in the application state:"]
492
  for key, df in all_dataframes.items():
 
501
  schema_descriptions.append(schema)
502
  return "\n".join(schema_descriptions)
503
 
504
+
505
  async def test_rag_retrieval_accuracy():
506
+ # ... (implementation unchanged, ensure client and types are checked if used here)
507
  logging.info("Running RAG retrieval accuracy test...")
508
  test_embedding_model = GEMINI_EMBEDDING_MODEL_NAME
509
  if not client: