GuglielmoTor committed
Commit 003ceb6 · verified · 1 Parent(s): ee9de40

Update eb_agent_module.py

Files changed (1)
  1. eb_agent_module.py +33 -53
eb_agent_module.py CHANGED
@@ -8,8 +8,8 @@ import textwrap
 from datetime import datetime # Added for date calculations
 
 try:
-    from google import generativeai as genai
-    from google.generativeai import types # For GenerateContentConfig, SafetySetting, HarmCategory, HarmBlockThreshold etc.
+    from google import genai
+    from google.genai import types # For GenerateContentConfig, SafetySetting, HarmCategory, HarmBlockThreshold etc.
 except ImportError:
     logging.error("Google Generative AI library not found. Please install it: pip install google-generativeai", exc_info=True)
     # Define dummy classes/variables if import fails
@@ -30,7 +30,7 @@ except ImportError:
         BLOCK_LOW_AND_ABOVE = "BLOCK_LOW_AND_ABOVE"
         BLOCK_MEDIUM_AND_ABOVE = "BLOCK_MEDIUM_AND_ABOVE"
         BLOCK_ONLY_HIGH = "BLOCK_ONLY_HIGH" # Added for completeness, adjust if needed
-
+
 # --- Custom Exceptions ---
 class ValidationError(Exception):
     """Custom validation error for agent inputs"""
@@ -56,27 +56,9 @@ GENERATION_CONFIG_PARAMS = {
     "candidate_count": 1,
 }
 
-# Corrected to use direct enum members when available
-DEFAULT_SAFETY_SETTINGS = []
-if types and hasattr(types, 'HarmCategory') and hasattr(types, 'HarmBlockThreshold'):
-    DEFAULT_SAFETY_SETTINGS = [
-        {"category": types.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
-         "threshold": types.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE},
-        {"category": types.HarmCategory.HARM_CATEGORY_HARASSMENT,
-         "threshold": types.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE},
-        {"category": types.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
-         "threshold": types.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE},
-        {"category": types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
-         "threshold": types.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE},
-    ]
-else: # Fallback to strings if types or enums are not properly imported
-    logging.warning("Falling back to string representations for DEFAULT_SAFETY_SETTINGS due to missing types.HarmCategory or types.HarmBlockThreshold.")
-    DEFAULT_SAFETY_SETTINGS = [
-        {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
-        {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
-        {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
-        {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
-    ]
+# No safety settings by default as per user request
+DEFAULT_SAFETY_SETTINGS = []
+logging.info("Default safety settings are now empty (no explicit client-side safety settings).")
 
 
 df_rag_documents = pd.DataFrame({
@@ -214,18 +196,19 @@ class EmployerBrandingAgent:
                  llm_model_name: str,
                  embedding_model_name: str,
                  generation_config_dict: dict,
-                 safety_settings_list_of_dicts: list, # This list now contains dicts with ENUM values or STRINGS
+                 safety_settings_list: list,
                  force_sandbox: bool = False):
         self.all_dataframes = {k: df.copy() for k, df in all_dataframes.items()}
         self.schemas_representation = self._get_enhanced_schemas_representation()
         self.chat_history = []
         self.llm_model_name = llm_model_name
         self.generation_config_dict = generation_config_dict
-        self.safety_settings_list_of_dicts = safety_settings_list_of_dicts
+        # If an empty list is passed, it means no specific safety settings are enforced by the client.
+        self.safety_settings_list = safety_settings_list if safety_settings_list is not None else []
         self.embedding_model_name = embedding_model_name
         self.rag_system = AdvancedRAGSystem(rag_documents_df, self.embedding_model_name)
         self.force_sandbox = force_sandbox
-        logging.info(f"EmployerBrandingAgent initialized. LLM: {self.llm_model_name}, Embedding: {self.embedding_model_name}. RAG system created.")
+        logging.info(f"EmployerBrandingAgent initialized. LLM: {self.llm_model_name}, Embedding: {self.embedding_model_name}. Safety settings count: {len(self.safety_settings_list)}")
 
     def _get_date_range(self, df: pd.DataFrame) -> str:
         for col in df.columns:
@@ -344,7 +327,7 @@ Relevant Benchmarks: {benchmarks_val}"""
         return enhanced_context
 
     async def _build_prompt_for_current_turn(self, raw_user_query: str) -> str:
-        prompt_parts = ["You are an expert Employer Branding Analyst...", "--- DETAILED DATA OVERVIEW ---", self.schemas_representation] # Truncated for brevity
+        prompt_parts = ["You are an expert Employer Branding Analyst...", "--- DETAILED DATA OVERVIEW ---", self.schemas_representation]
         if self.rag_system.embeddings is not None and self.rag_system.embeddings.size > 0:
             base_rag_context = await self.rag_system.retrieve_relevant_info(raw_user_query)
             if base_rag_context:
@@ -360,28 +343,27 @@ Relevant Benchmarks: {benchmarks_val}"""
         return {"Key Findings": ["Placeholder finding 1"], "Performance Metrics": ["Placeholder metric"], "Actionable Recommendations": {"Immediate Actions (0-30 days)": ["Placeholder action"]}, "Risk Assessment": ["Placeholder risk"], "Success Metrics to Track": ["Placeholder KPI"]}
 
     async def _generate_hr_insights(self, query: str, context: str) -> str:
-        insight_prompt = f"As an expert HR analytics consultant...\n{context}\nUser Query: {query}\nPlease provide insights in this structured format:\n## Key Findings\n- ...\n..." # Truncated
+        insight_prompt = f"As an expert HR analytics consultant...\n{context}\nUser Query: {query}\nPlease provide insights in this structured format:\n## Key Findings\n- ...\n..."
         if not client: return "Error: AI client not configured for generating HR insights."
         api_call_contents = [{"role": "user", "parts": [{"text": insight_prompt}]}]
 
         api_safety_settings_objects = []
-        if types and hasattr(types, 'SafetySetting'):
-            for ss_dict in self.safety_settings_list_of_dicts:
+        # self.safety_settings_list is expected to be empty if no settings are desired
+        if types and hasattr(types, 'SafetySetting') and self.safety_settings_list:
+            for ss_item in self.safety_settings_list:
                 try:
-                    # Directly use the category and threshold from the dict,
-                    # which should be enum members if types was available at DEFAULT_SAFETY_SETTINGS definition,
-                    # or strings otherwise.
-                    api_safety_settings_objects.append(types.SafetySetting(category=ss_dict['category'], threshold=ss_dict['threshold']))
-                except Exception as e_ss: # Catch if ss_dict values are not valid for SafetySetting
-                    logging.warning(f"Could not create SafetySetting object from {ss_dict} for HR insights: {e_ss}. Using raw dict.")
-                    api_safety_settings_objects.append(ss_dict)
-        else:
-            api_safety_settings_objects = self.safety_settings_list_of_dicts
+                    api_safety_settings_objects.append(types.SafetySetting(category=ss_item['category'], threshold=ss_item['threshold']))
+                except Exception as e_ss:
+                    logging.warning(f"Could not create SafetySetting object from {ss_item} for HR insights: {e_ss}. Using raw item.")
+                    api_safety_settings_objects.append(ss_item)
+        elif self.safety_settings_list: # Fallback if types.SafetySetting not available but list is not empty
+            api_safety_settings_objects = self.safety_settings_list
+
 
         api_generation_config_obj = None
         if types and hasattr(types, 'GenerateContentConfig'):
             api_generation_config_obj = types.GenerateContentConfig(**self.generation_config_dict, safety_settings=api_safety_settings_objects)
-        else:
+        else: # Fallback if types.GenerateContentConfig is not available
             api_generation_config_obj = {**self.generation_config_dict, "safety_settings": api_safety_settings_objects}
 
         try:
@@ -418,21 +400,22 @@ Relevant Benchmarks: {benchmarks_val}"""
         logging.debug(f"Sending to GenAI. Total turns in content: {len(api_call_contents)}")
 
         api_safety_settings_objects = []
-        if types and hasattr(types, 'SafetySetting'):
-            for ss_dict in self.safety_settings_list_of_dicts:
+        # self.safety_settings_list is expected to be empty if no settings are desired
+        if types and hasattr(types, 'SafetySetting') and self.safety_settings_list:
+            for ss_item in self.safety_settings_list:
                 try:
-                    # Directly use category/threshold from ss_dict. They should be enums or valid strings.
-                    api_safety_settings_objects.append(types.SafetySetting(category=ss_dict['category'], threshold=ss_dict['threshold']))
+                    api_safety_settings_objects.append(types.SafetySetting(category=ss_item['category'], threshold=ss_item['threshold']))
                 except Exception as e_ss_core:
-                    logging.warning(f"Could not create SafetySetting object from {ss_dict} in core: {e_ss_core}. Using raw dict.")
-                    api_safety_settings_objects.append(ss_dict) # Fallback to passing the dict itself
-        else: # Fallback if types.SafetySetting is not available
-            api_safety_settings_objects = self.safety_settings_list_of_dicts
+                    logging.warning(f"Could not create SafetySetting object from {ss_item} in core: {e_ss_core}. Using raw item.")
+                    api_safety_settings_objects.append(ss_item)
+        elif self.safety_settings_list: # Fallback if types.SafetySetting not available but list is not empty
+            api_safety_settings_objects = self.safety_settings_list
+
 
 
         api_generation_config_obj = None
         if types and hasattr(types, 'GenerateContentConfig'):
             api_generation_config_obj = types.GenerateContentConfig(**self.generation_config_dict, safety_settings=api_safety_settings_objects)
-        else:
+        else: # Fallback if types.GenerateContentConfig is not available
             logging.error("GenerateContentConfig type not available. API call might fail.")
             api_generation_config_obj = {**self.generation_config_dict, "safety_settings": api_safety_settings_objects}
@@ -471,7 +454,6 @@ Relevant Benchmarks: {benchmarks_val}"""
             return self._get_fallback_response(raw_user_query_this_turn)
 
     def _classify_query_type(self, query: str) -> str:
-        # ... (implementation unchanged)
         query_lower = query.lower()
         if any(word in query_lower for word in ['trend', 'growth', 'change', 'time']): return 'trend_analysis'
         elif any(word in query_lower for word in ['compare', 'benchmark', 'versus']): return 'comparative_analysis'
@@ -496,7 +478,6 @@ def get_all_schemas_representation(all_dataframes: dict) -> str:
            schema = f"\n--- DataFrame: {df_name} ---\nStatus: Empty\nShape: {shape}\nColumns: {columns}"
        else:
            try:
-               # Attempt to use to_markdown, with a fallback to to_string
                sample_data_str = df.head(2).to_markdown(index=False)
            except ImportError:
                logging.warning("`tabulate` library not found. Falling back to `to_string()` for schema representation.")
@@ -511,7 +492,6 @@ def get_all_schemas_representation(all_dataframes: dict) -> str:
 
 
 async def test_rag_retrieval_accuracy():
-    # ... (implementation unchanged, ensure client and types are checked if used here)
     logging.info("Running RAG retrieval accuracy test...")
     test_embedding_model = GEMINI_EMBEDDING_MODEL_NAME
     if not client:
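For context on the import change above, here is a minimal, illustrative sketch (not code from this commit) of how the new-style `from google import genai` / `from google.genai import types` client is typically driven, with `types.GenerateContentConfig` carrying optional `types.SafetySetting` entries. The model name, temperature value, and the `GEMINI_API_KEY` environment variable are assumptions for illustration only.

import os

from google import genai
from google.genai import types

# Hypothetical client setup; the env var name is an assumption, not from this repo.
client = genai.Client(api_key=os.environ["GEMINI_API_KEY"])

config = types.GenerateContentConfig(
    temperature=0.3,      # illustrative value
    candidate_count=1,
    # Omitting safety_settings mirrors the commit's empty DEFAULT_SAFETY_SETTINGS;
    # a caller can still opt in with explicit SafetySetting objects, e.g.:
    safety_settings=[
        types.SafetySetting(
            category=types.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
            threshold=types.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
        )
    ],
)

response = client.models.generate_content(
    model="gemini-2.0-flash",  # assumed model name for illustration
    contents="Summarize our employer branding follower trends.",
    config=config,
)
print(response.text)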
 
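A caller-side sketch of the updated constructor, assuming the rest of the `__init__` signature matches how its body uses `all_dataframes`: since DEFAULT_SAFETY_SETTINGS is now empty, client-side safety settings apply only when a non-empty `safety_settings_list` is passed in, using the same category/threshold dict shape the agent iterates over. The dataframe contents, model names, and the `all_dataframes` keyword are placeholders/assumptions, not values confirmed by this diff.

import pandas as pd
from google.genai import types

from eb_agent_module import EmployerBrandingAgent, GENERATION_CONFIG_PARAMS

# Explicit settings in the dict shape the agent reads (category/threshold keys).
explicit_safety_settings = [
    {"category": types.HarmCategory.HARM_CATEGORY_HARASSMENT,
     "threshold": types.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE},
]

follower_stats_df = pd.DataFrame({"date": ["2024-01-01"], "followers": [1200]})  # placeholder data

agent = EmployerBrandingAgent(
    all_dataframes={"follower_stats": follower_stats_df},  # assumed keyword, inferred from the __init__ body
    llm_model_name="gemini-2.0-flash",                     # assumed model name
    embedding_model_name="text-embedding-004",             # assumed embedding model
    generation_config_dict=GENERATION_CONFIG_PARAMS,
    safety_settings_list=explicit_safety_settings,         # pass [] to keep the new no-settings default
)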