Update eb_agent_module.py
eb_agent_module.py  +33 -53  CHANGED
@@ -8,8 +8,8 @@ import textwrap
 from datetime import datetime # Added for date calculations
 
 try:
-    from google import
-    from google.
+    from google import genai
+    from google.genai import types # For GenerateContentConfig, SafetySetting, HarmCategory, HarmBlockThreshold etc.
 except ImportError:
     logging.error("Google Generative AI library not found. Please install it: pip install google-generativeai", exc_info=True)
     # Define dummy classes/variables if import fails
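The new imports follow the google-genai SDK layout (from google import genai, google.genai.types), while the install hint in the except branch still names the older google-generativeai package. A minimal, hedged sketch of how these imports are typically wired together; the client construction itself sits outside this hunk, and the environment variable and model name are assumptions:

import os
from google import genai
from google.genai import types

# Hypothetical wiring; the module's real `client` is created elsewhere in the file.
client = genai.Client(api_key=os.environ["GEMINI_API_KEY"])  # assumed env var name
config = types.GenerateContentConfig(temperature=0.7, safety_settings=[])  # empty list: service-side defaults apply
response = client.models.generate_content(
    model="gemini-2.0-flash",  # placeholder model name
    contents="Summarize follower growth trends.",
    config=config,
)
print(response.text)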
@@ -30,7 +30,7 @@ except ImportError:
         BLOCK_LOW_AND_ABOVE = "BLOCK_LOW_AND_ABOVE"
         BLOCK_MEDIUM_AND_ABOVE = "BLOCK_MEDIUM_AND_ABOVE"
         BLOCK_ONLY_HIGH = "BLOCK_ONLY_HIGH" # Added for completeness, adjust if needed
-
+
 # --- Custom Exceptions ---
 class ValidationError(Exception):
     """Custom validation error for agent inputs"""
@@ -56,27 +56,9 @@ GENERATION_CONFIG_PARAMS = {
     "candidate_count": 1,
 }
 
-#
-DEFAULT_SAFETY_SETTINGS = []
-
-    DEFAULT_SAFETY_SETTINGS = [
-        {"category": types.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
-         "threshold": types.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE},
-        {"category": types.HarmCategory.HARM_CATEGORY_HARASSMENT,
-         "threshold": types.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE},
-        {"category": types.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
-         "threshold": types.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE},
-        {"category": types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
-         "threshold": types.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE},
-    ]
-else: # Fallback to strings if types or enums are not properly imported
-    logging.warning("Falling back to string representations for DEFAULT_SAFETY_SETTINGS due to missing types.HarmCategory or types.HarmBlockThreshold.")
-    DEFAULT_SAFETY_SETTINGS = [
-        {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
-        {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
-        {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
-        {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
-    ]
+# No safety settings by default as per user request
+DEFAULT_SAFETY_SETTINGS = []
+logging.info("Default safety settings are now empty (no explicit client-side safety settings).")
 
 
 df_rag_documents = pd.DataFrame({
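With the module-level defaults emptied, any blocking now has to be opted in by the caller. A sketch of what an explicit list could look like, reusing the dict shape and enum names from the removed block; the constant name STRICT_SAFETY_SETTINGS is illustrative, not part of the module, and such a list would be passed through the agent's safety settings parameter added below:

from google.genai import types  # as imported in the try block above

STRICT_SAFETY_SETTINGS = [
    {"category": types.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
     "threshold": types.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE},
    {"category": types.HarmCategory.HARM_CATEGORY_HARASSMENT,
     "threshold": types.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE},
    {"category": types.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
     "threshold": types.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE},
    {"category": types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
     "threshold": types.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE},
]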
@@ -214,18 +196,19 @@ class EmployerBrandingAgent:
                  llm_model_name: str,
                  embedding_model_name: str,
                  generation_config_dict: dict,
-
+                 safety_settings_list: list,
                  force_sandbox: bool = False):
         self.all_dataframes = {k: df.copy() for k, df in all_dataframes.items()}
         self.schemas_representation = self._get_enhanced_schemas_representation()
         self.chat_history = []
         self.llm_model_name = llm_model_name
         self.generation_config_dict = generation_config_dict
-
+        # If an empty list is passed, it means no specific safety settings are enforced by the client.
+        self.safety_settings_list = safety_settings_list if safety_settings_list is not None else []
         self.embedding_model_name = embedding_model_name
         self.rag_system = AdvancedRAGSystem(rag_documents_df, self.embedding_model_name)
         self.force_sandbox = force_sandbox
-        logging.info(f"EmployerBrandingAgent initialized. LLM: {self.llm_model_name}, Embedding: {self.embedding_model_name}.
+        logging.info(f"EmployerBrandingAgent initialized. LLM: {self.llm_model_name}, Embedding: {self.embedding_model_name}. Safety settings count: {len(self.safety_settings_list)}")
 
     def _get_date_range(self, df: pd.DataFrame) -> str:
         for col in df.columns:
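The call site that constructs the agent is not part of this diff. A hedged sketch of how the updated signature would be used; the leading keyword names are inferred from the __init__ body, and the model name string is a placeholder:

# Hypothetical call site; only the parameters visible in the hunk above are certain.
agent = EmployerBrandingAgent(
    all_dataframes=all_dataframes,
    rag_documents_df=df_rag_documents,
    llm_model_name="gemini-2.0-flash",                # placeholder
    embedding_model_name=GEMINI_EMBEDDING_MODEL_NAME,
    generation_config_dict=GENERATION_CONFIG_PARAMS,
    safety_settings_list=DEFAULT_SAFETY_SETTINGS,     # [] after this change
    force_sandbox=False,
)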
@@ -344,7 +327,7 @@ Relevant Benchmarks: {benchmarks_val}"""
         return enhanced_context
 
     async def _build_prompt_for_current_turn(self, raw_user_query: str) -> str:
-        prompt_parts = ["You are an expert Employer Branding Analyst...", "--- DETAILED DATA OVERVIEW ---", self.schemas_representation]
+        prompt_parts = ["You are an expert Employer Branding Analyst...", "--- DETAILED DATA OVERVIEW ---", self.schemas_representation]
         if self.rag_system.embeddings is not None and self.rag_system.embeddings.size > 0:
             base_rag_context = await self.rag_system.retrieve_relevant_info(raw_user_query)
             if base_rag_context:
@@ -360,28 +343,27 @@ Relevant Benchmarks: {benchmarks_val}"""
         return {"Key Findings": ["Placeholder finding 1"], "Performance Metrics": ["Placeholder metric"], "Actionable Recommendations": {"Immediate Actions (0-30 days)": ["Placeholder action"]}, "Risk Assessment": ["Placeholder risk"], "Success Metrics to Track": ["Placeholder KPI"]}
 
     async def _generate_hr_insights(self, query: str, context: str) -> str:
-        insight_prompt = f"As an expert HR analytics consultant...\n{context}\nUser Query: {query}\nPlease provide insights in this structured format:\n## Key Findings\n- ...\n..."
+        insight_prompt = f"As an expert HR analytics consultant...\n{context}\nUser Query: {query}\nPlease provide insights in this structured format:\n## Key Findings\n- ...\n..."
         if not client: return "Error: AI client not configured for generating HR insights."
         api_call_contents = [{"role": "user", "parts": [{"text": insight_prompt}]}]
 
         api_safety_settings_objects = []
-        if
-
+        # self.safety_settings_list is expected to be empty if no settings are desired
+        if types and hasattr(types, 'SafetySetting') and self.safety_settings_list:
+            for ss_item in self.safety_settings_list:
                 try:
-
-
-
-                    api_safety_settings_objects.append(
-
-
-
-        else:
-            api_safety_settings_objects = self.safety_settings_list_of_dicts
+                    api_safety_settings_objects.append(types.SafetySetting(category=ss_item['category'], threshold=ss_item['threshold']))
+                except Exception as e_ss:
+                    logging.warning(f"Could not create SafetySetting object from {ss_item} for HR insights: {e_ss}. Using raw item.")
+                    api_safety_settings_objects.append(ss_item)
+        elif self.safety_settings_list: # Fallback if types.SafetySetting not available but list is not empty
+            api_safety_settings_objects = self.safety_settings_list
+
 
         api_generation_config_obj = None
         if types and hasattr(types, 'GenerateContentConfig'):
             api_generation_config_obj = types.GenerateContentConfig(**self.generation_config_dict, safety_settings=api_safety_settings_objects)
-        else:
+        else: # Fallback if types.GenerateContentConfig is not available
             api_generation_config_obj = {**self.generation_config_dict, "safety_settings": api_safety_settings_objects}
 
         try:
@@ -418,21 +400,22 @@ Relevant Benchmarks: {benchmarks_val}"""
         logging.debug(f"Sending to GenAI. Total turns in content: {len(api_call_contents)}")
 
         api_safety_settings_objects = []
-        if
-
+        # self.safety_settings_list is expected to be empty if no settings are desired
+        if types and hasattr(types, 'SafetySetting') and self.safety_settings_list:
+            for ss_item in self.safety_settings_list:
                 try:
-
-                    api_safety_settings_objects.append(types.SafetySetting(category=ss_dict['category'], threshold=ss_dict['threshold']))
+                    api_safety_settings_objects.append(types.SafetySetting(category=ss_item['category'], threshold=ss_item['threshold']))
                 except Exception as e_ss_core:
-                    logging.warning(f"Could not create SafetySetting object from {
-                    api_safety_settings_objects.append(
-
-            api_safety_settings_objects = self.
+                    logging.warning(f"Could not create SafetySetting object from {ss_item} in core: {e_ss_core}. Using raw item.")
+                    api_safety_settings_objects.append(ss_item)
+        elif self.safety_settings_list : # Fallback if types.SafetySetting not available but list is not empty
+            api_safety_settings_objects = self.safety_settings_list
+
 
         api_generation_config_obj = None
         if types and hasattr(types, 'GenerateContentConfig'):
             api_generation_config_obj = types.GenerateContentConfig(**self.generation_config_dict, safety_settings=api_safety_settings_objects)
-        else:
+        else: # Fallback if types.GenerateContentConfig is not available
             logging.error("GenerateContentConfig type not available. API call might fail.")
             api_generation_config_obj = {**self.generation_config_dict, "safety_settings": api_safety_settings_objects}
 
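The dict-to-SafetySetting conversion is now duplicated between _generate_hr_insights and this core call path. A sketch of the shared pattern as a helper; the method name _build_safety_setting_objects is hypothetical and relies on the module's existing types and logging imports:

def _build_safety_setting_objects(self) -> list:
    # Convert the configured {"category": ..., "threshold": ...} dicts into
    # types.SafetySetting objects when the SDK types imported; otherwise pass
    # the dicts through unchanged, matching the fallback behavior in the diff.
    objects = []
    if types and hasattr(types, 'SafetySetting') and self.safety_settings_list:
        for ss_item in self.safety_settings_list:
            try:
                objects.append(types.SafetySetting(category=ss_item['category'],
                                                   threshold=ss_item['threshold']))
            except Exception as exc:
                logging.warning(f"Could not create SafetySetting from {ss_item}: {exc}. Using raw item.")
                objects.append(ss_item)
    elif self.safety_settings_list:  # types.SafetySetting unavailable; pass dicts through
        objects = self.safety_settings_list
    return objects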
@@ -471,7 +454,6 @@ Relevant Benchmarks: {benchmarks_val}"""
             return self._get_fallback_response(raw_user_query_this_turn)
 
     def _classify_query_type(self, query: str) -> str:
-        # ... (implementation unchanged)
         query_lower = query.lower()
         if any(word in query_lower for word in ['trend', 'growth', 'change', 'time']): return 'trend_analysis'
         elif any(word in query_lower for word in ['compare', 'benchmark', 'versus']): return 'comparative_analysis'
@@ -496,7 +478,6 @@ def get_all_schemas_representation(all_dataframes: dict) -> str:
             schema = f"\n--- DataFrame: {df_name} ---\nStatus: Empty\nShape: {shape}\nColumns: {columns}"
         else:
             try:
-                # Attempt to use to_markdown, with a fallback to to_string
                 sample_data_str = df.head(2).to_markdown(index=False)
             except ImportError:
                 logging.warning("`tabulate` library not found. Falling back to `to_string()` for schema representation.")
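DataFrame.to_markdown() only works when the optional tabulate package is installed, which is what the except ImportError branch above guards against. The same pattern as a standalone sketch; the function name _sample_rows is illustrative:

import logging
import pandas as pd

def _sample_rows(df: pd.DataFrame) -> str:
    # to_markdown() raises ImportError without `tabulate`; fall back to plain text.
    try:
        return df.head(2).to_markdown(index=False)
    except ImportError:
        logging.warning("`tabulate` library not found. Falling back to `to_string()`.")
        return df.head(2).to_string(index=False)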
@@ -511,7 +492,6 @@ def get_all_schemas_representation(all_dataframes: dict) -> str:
 
 
 async def test_rag_retrieval_accuracy():
-    # ... (implementation unchanged, ensure client and types are checked if used here)
     logging.info("Running RAG retrieval accuracy test...")
     test_embedding_model = GEMINI_EMBEDDING_MODEL_NAME
     if not client: