GuglielmoTor committed on
Commit 97bdf15 · verified · 1 Parent(s): 69061c0

Update eb_agent_module.py

Files changed (1):
  1. eb_agent_module.py +472 -250

eb_agent_module.py CHANGED
@@ -5,63 +5,68 @@ import asyncio
 import logging
 import numpy as np
 import textwrap
 
 try:
     from google import generativeai as genai
-    from google.generativeai import types as genai_types  # For GenerateContentConfig, SafetySetting etc.
-    from google.generativeai.types import HarmCategory, HarmBlockThreshold  # Specific enums
 except ImportError:
     logging.error("Google Generative AI library not found. Please install it: pip install google-generativeai", exc_info=True)
-    # Define dummy classes/variables if import fails, so app.py can try to run
-    # (though app.py already has EB_AGENT_AVAILABLE check)
     class genai: Client = None  # type: ignore
-    class genai_types:  # type: ignore
         EmbedContentConfig = None
         GenerateContentConfig = None
         SafetySetting = None
-    class HarmCategory:  # type: ignore
-        HARM_CATEGORY_HATE_SPEECH = "HARM_CATEGORY_HATE_SPEECH"
-        HARM_CATEGORY_HARASSMENT = "HARM_CATEGORY_HARASSMENT"
-        HARM_CATEGORY_SEXUALLY_EXPLICIT = "HARM_CATEGORY_SEXUALLY_EXPLICIT"
-        HARM_CATEGORY_DANGEROUS_CONTENT = "HARM_CATEGORY_DANGEROUS_CONTENT"
-    class HarmBlockThreshold:  # type: ignore
-        BLOCK_LOW_AND_ABOVE = "BLOCK_LOW_AND_ABOVE"
-        BLOCK_MEDIUM_AND_ABOVE = "BLOCK_MEDIUM_AND_ABOVE"
-        BLOCK_NONE = "BLOCK_NONE"
-
 
 # --- Configuration Constants ---
-# These are defined here because app.py imports them.
-# User should ensure these are appropriate for their needs.
-
 GEMINI_API_KEY = os.getenv('GEMINI_API_KEY', "")
 if not GEMINI_API_KEY:
     logging.warning("GEMINI_API_KEY environment variable not set. EB Agent will not function.")
 
-# Model names (as used in app.py imports from this module)
-LLM_MODEL_NAME = "gemini-1.5-flash-latest"  # Changed to 1.5-flash as it's generally preferred; user had 2.0-flash. Adjust if needed.
-GEMINI_EMBEDDING_MODEL_NAME = "text-embedding-004"  # Common embedding model; user had gemini-embedding-exp-03-07. Adjust if needed.
 
-# Default Generation Config (app.py imports this as EB_AGENT_GEN_CONFIG)
 GENERATION_CONFIG_PARAMS = {
     "temperature": 0.7,
     "top_p": 0.95,
     "top_k": 40,
     "max_output_tokens": 8192,
-    "candidate_count": 1,  # Important for non-streaming
-    # "stop_sequences": [...]  # Optional
 }
 
-# Default Safety Settings (app.py imports this as EB_AGENT_SAFETY_SETTINGS)
 DEFAULT_SAFETY_SETTINGS = [
-    {"category": HarmCategory.HARM_CATEGORY_HATE_SPEECH, "threshold": HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE},
-    {"category": HarmCategory.HARM_CATEGORY_HARASSMENT, "threshold": HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE},
-    {"category": HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, "threshold": HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE},
-    {"category": HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, "threshold": HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE},
 ]
 
-# Placeholder for RAG documents DataFrame (app.py imports this as eb_agent_default_rag_docs)
-# In a real application, this would be loaded from a file or database.
 df_rag_documents = pd.DataFrame({
     'text': [
         "Employer branding focuses on how an organization is perceived as an employer by potential and current employees.",
@@ -72,12 +77,9 @@ df_rag_documents = pd.DataFrame({
 })
 
 # --- Client Initialization ---
-# This client will be used by the agent instances.
-# It's initialized once when the module is loaded.
 client = None
-if GEMINI_API_KEY and genai.Client:  # Check if genai.Client is not None (due to dummy class on import error)
     try:
-        # genai.configure(api_key=GEMINI_API_KEY)  # Alternative: global configuration
         client = genai.Client(api_key=GEMINI_API_KEY)
         logging.info("Google GenAI client initialized successfully.")
     except Exception as e:
@@ -87,43 +89,35 @@ else:
 
 
 class AdvancedRAGSystem:
-    """
-    Handles Retrieval Augmented Generation by embedding documents and finding relevant context for queries.
-    """
     def __init__(self, documents_df: pd.DataFrame, embedding_model_name: str):
-        self.documents_df = documents_df.copy()  # Work on a copy
         self.embedding_model_name = embedding_model_name
-        self.embeddings: np.ndarray | None = None  # Populated by async initialize_embeddings
         logging.info(f"AdvancedRAGSystem initialized with embedding model: {self.embedding_model_name}")
 
     def _embed_single_document_sync(self, text: str) -> np.ndarray:
-        """Synchronous helper to embed a single piece of text."""
         if not client:
             raise ConnectionError("GenAI client not initialized for RAG embedding.")
-        if not text or not isinstance(text, str):  # Basic validation
-            logging.warning("Attempted to embed empty or non-string text. Returning zero vector.")
-            # Attempt to get model's embedding dimension, otherwise use a common default (e.g., 768)
-            # This is tricky without a live model call. For now, let's assume it will be filtered or handled.
-            # If we must return a vector, its dimensionality needs to be known.
-            # For simplicity, errors during embedding will be logged and might lead to skipping the doc.
             raise ValueError("Cannot embed empty or non-string text.")
 
-        # Using client.models.embed_content as per user's provided snippets
         response = client.models.embed_content(
-            model=self.embedding_model_name,  # e.g., "text-embedding-004" or "gemini-embedding-exp-03-07"
-            contents=text,  # API takes 'contents' (plural) but can be a single string for single embedding
-            config=genai_types.EmbedContentConfig(task_type="SEMANTIC_SIMILARITY") if genai_types.EmbedContentConfig else None
         )
-        # Assuming response.embeddings is the list of floats for a single content string, as per user's snippet.
         return np.array(response.embeddings)
 
     async def initialize_embeddings(self):
-        """Asynchronously embeds all documents in the documents_df. Should be called once."""
         if self.documents_df.empty:
             logging.info("RAG documents DataFrame is empty. No embeddings to initialize.")
             self.embeddings = np.array([])
             return
-
         if not client:
             logging.error("GenAI client not available for RAG embedding initialization.")
             self.embeddings = np.array([])
@@ -137,11 +131,10 @@ class AdvancedRAGSystem:
                 logging.warning(f"Skipping document at index {index} due to invalid text: {text_to_embed}")
                 continue
             try:
-                # Wrap the synchronous SDK call in asyncio.to_thread
                 embedding_array = await asyncio.to_thread(self._embed_single_document_sync, text_to_embed)
                 embedded_docs_list.append(embedding_array)
             except Exception as e:
-                logging.error(f"Error embedding document text (index {index}) '{str(text_to_embed)[:50]}...': {e}", exc_info=False)  # exc_info=False for brevity in loop
 
         if not embedded_docs_list:
             self.embeddings = np.array([])
@@ -150,13 +143,20 @@ class AdvancedRAGSystem:
         try:
             self.embeddings = np.vstack(embedded_docs_list)
             logging.info(f"Successfully embedded {len(embedded_docs_list)} documents for RAG. Embedding matrix shape: {self.embeddings.shape}")
-        except ValueError as ve:  # Handles cases like empty list or inconsistent shapes if errors weren't caught properly
             logging.error(f"Error stacking embeddings: {ve}. Check individual embedding errors.", exc_info=True)
             self.embeddings = np.array([])
 
-    async def retrieve_relevant_info(self, query: str, top_k: int = 3) -> str:
-        """Retrieves relevant document snippets for a given query using vector similarity."""
         if self.embeddings is None or self.embeddings.size == 0 or self.documents_df.empty:
             logging.debug("RAG system not initialized or no documents/embeddings available for retrieval.")
             return ""
@@ -176,257 +176,441 @@ class AdvancedRAGSystem:
         if query_vector.ndim == 0 or query_vector.size == 0:
             logging.warning(f"Query vector embedding failed or is empty for query: {str(query)[:50]}")
             return ""
-        if query_vector.ndim > 1:  # Should be 1D
-            query_vector = query_vector.flatten()
-
         try:
-            # Cosine similarity is dot product of normalized vectors.
-            # For simplicity, using dot product directly. Normalize if true cosine sim is needed.
-            scores = np.dot(self.embeddings, query_vector)  # self.embeddings (N, D), query_vector (D,) -> scores (N,)
 
-            if scores.size == 0:
                 return ""
 
-            actual_top_k = min(top_k, len(self.documents_df), len(scores))
-            if actual_top_k <= 0: return ""  # Ensure top_k is positive
-
-            # Get indices of top_k scores in descending order
-            top_indices = np.argsort(scores)[-actual_top_k:][::-1]
 
-            valid_top_indices = [idx for idx in top_indices if 0 <= idx < len(self.documents_df)]
-            if not valid_top_indices: return ""
 
-            # Retrieve the 'text' field from the original DataFrame
-            context_parts = [self.documents_df.iloc[i]['text'] for i in valid_top_indices if 'text' in self.documents_df.columns]
             context = "\n\n---\n\n".join(context_parts)
             logging.debug(f"Retrieved RAG context for query '{str(query)[:50]}...':\n{context[:200]}...")
             return context
         except Exception as e:
-            logging.error(f"Error during RAG retrieval (dot product/sorting): {e}", exc_info=True)
             return ""
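The deleted shortcut above scored documents with a raw dot product, which systematically favors long embedding vectors over well-aligned ones; the commit replaces it with true cosine similarity further down. A standalone numpy sketch of the difference:

    import numpy as np

    q = np.array([1.0, 0.0])
    docs = np.array([[0.9, 0.1],    # well aligned, small norm
                     [5.0, 5.0]])   # poorly aligned, large norm

    print(docs @ q)                 # [0.9 5.0] -> raw dot product picks the wrong doc
    norms = np.linalg.norm(docs, axis=1) * np.linalg.norm(q)
    print((docs @ q) / norms)       # [0.99 0.71] -> cosine similarity picks the right one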
 
 
 
 class EmployerBrandingAgent:
-    """
-    An agent that uses Generative AI to provide insights on employer branding
-    based on provided DataFrames and RAG context.
-    """
     def __init__(self,
                  all_dataframes: dict,
-                 rag_documents_df: pd.DataFrame,  # For RAG system
                  llm_model_name: str,
-                 embedding_model_name: str,  # For RAG system
                  generation_config_dict: dict,
                  safety_settings_list_of_dicts: list,
-                 # client_instance,  # Using global client for simplicity now
-                 force_sandbox: bool = False  # Parameter from app.py, currently unused here
-                 ):
-        # self.client = client_instance  # If client were passed
-        self.all_dataframes = {k: df.copy() for k, df in all_dataframes.items()}  # Work with copies
-        self.schemas_representation = self._get_all_schemas_representation()  # Sync method
 
-        self.chat_history = []  # Stores chat in API format: [{"role": "user/model", "parts": [{"text": "..."}]}]
-        # This will be set by app.py before calling process_query
-
         self.llm_model_name = llm_model_name
         self.generation_config_dict = generation_config_dict
-        self.safety_settings_list_of_dicts = safety_settings_list_of_dicts
-
         self.embedding_model_name = embedding_model_name
         self.rag_system = AdvancedRAGSystem(rag_documents_df, self.embedding_model_name)
-        # Note: self.rag_system.initialize_embeddings() must be called externally (e.g., in app.py)
-
-        self.force_sandbox = force_sandbox  # Store if needed for tool use later
         logging.info(f"EmployerBrandingAgent initialized. LLM: {self.llm_model_name}, Embedding: {self.embedding_model_name}. RAG system created.")
 
-    def _get_all_schemas_representation(self) -> str:
-        """Generates a string representation of the schemas of all DataFrames."""
-        schema_descriptions = ["DataFrames available for analysis:"]
         for key, df in self.all_dataframes.items():
-            df_name = f"df_{key}"  # Consistent naming for the agent to refer to
-            columns = ", ".join(df.columns)
-            shape = df.shape
             if df.empty:
-                schema = f"\n--- DataFrame: {df_name} ---\nStatus: Empty\nShape: {shape}\nColumns: {columns}"
-            else:
-                # Basic stats for numeric columns, first few unique for objects
-                sample_info_parts = []
-                for col in df.columns:
-                    if pd.api.types.is_numeric_dtype(df[col]) and not df[col].empty:
-                        sample_info_parts.append(f"{col} (numeric, e.g., mean: {df[col].mean():.2f})")
-                    elif pd.api.types.is_datetime64_any_dtype(df[col]) and not df[col].empty:
-                        sample_info_parts.append(f"{col} (datetime, e.g., min: {df[col].min()}, max: {df[col].max()})")
-                    elif not df[col].empty:
-                        unique_vals = df[col].unique()
-                        display_unique = ', '.join(map(str, unique_vals[:3]))
-                        if len(unique_vals) > 3: display_unique += ", ..."
-                        sample_info_parts.append(f"{col} (object, e.g., {display_unique})")
-                    else:
-                        sample_info_parts.append(f"{col} (empty)")
-
-                schema = (f"\n--- DataFrame: {df_name} ---\nShape: {shape}\nColumns & Sample Info:\n  " + "\n  ".join(sample_info_parts))
-            schema_descriptions.append(schema)
         return "\n".join(schema_descriptions)
 
     async def _build_prompt_for_current_turn(self, raw_user_query: str) -> str:
-        """
-        Constructs the full prompt for the current turn, including system instructions,
-        DataFrame schemas, RAG context, and the user's query.
-        """
-        # System instruction part
         prompt_parts = [
             "You are an expert Employer Branding Analyst and a helpful AI assistant. "
             "Your goal is to provide insightful analysis based on the provided LinkedIn data. "
             "When asked to generate Pandas code, ensure it is correct, runnable, and clearly explained. "
-            "When providing insights, be specific and refer to the data where possible."
         ]
-
-        # Schema information
-        prompt_parts.append("\n\n--- AVAILABLE DATA ---")
         prompt_parts.append(self.schemas_representation)
 
-        # RAG context
-        if self.rag_system.embeddings is not None and self.rag_system.embeddings.size > 0:  # Check if RAG is initialized
-            logging.debug(f"Retrieving RAG context for query: {raw_user_query[:50]}...")
-            rag_context = await self.rag_system.retrieve_relevant_info(raw_user_query)
-            if rag_context:
-                prompt_parts.append("\n\n--- RELEVANT CONTEXTUAL INFORMATION (from documents) ---")
-                prompt_parts.append(rag_context)
-            else:
-                logging.debug("No relevant RAG context found.")
-        else:
-            logging.debug("RAG system not initialized or embeddings not available, skipping RAG context retrieval.")
-
 
-        # User's current query
         prompt_parts.append("\n\n--- USER REQUEST ---")
         prompt_parts.append(f"Based on all the information above, please respond to the following user query:\n{raw_user_query}")
-
         final_prompt = "\n".join(prompt_parts)
         logging.debug(f"Built prompt for current turn (first 300 chars): {final_prompt[:300]}")
         return final_prompt
 
-    async def process_query(self, raw_user_query_this_turn: str) -> str:
-        """
-        Processes the user's query, incorporating chat history, DataFrame schemas, and RAG.
-        The agent's self.chat_history is expected to be set by the calling application (app.py)
-        and should contain the history *before* the current raw_user_query_this_turn.
-        This method returns the AI's response string. app.py will then update the agent's
-        chat history with the raw_user_query_this_turn and this response.
-        """
-        if not client:
-            logging.error("GenAI client not initialized. Cannot process query.")
-            return "Error: The AI Agent is not available due to a configuration issue with the AI service."
 
-        if not raw_user_query_this_turn.strip():
-            return "Please provide a query."
 
-        # 1. Prepare the augmented prompt for the *current* user query
-        #    This prompt includes system instructions, schemas, RAG, and the current raw query.
         augmented_current_user_prompt_text = await self._build_prompt_for_current_turn(raw_user_query_this_turn)
-
-        # 2. Construct the full list of contents for the API call
-        #    self.chat_history should be in API format: [{"role": "user/model", "parts": [{"text": "..."}]}]
-        #    It contains history *before* the current raw_user_query_this_turn.
-        api_call_contents = []
-        if self.chat_history:  # Add previous turns if any
-            api_call_contents.extend(self.chat_history)
-
-        # Add the current user turn, using the fully augmented prompt as its content
         api_call_contents.append({"role": "user", "parts": [{"text": augmented_current_user_prompt_text}]})
-
         logging.debug(f"Sending to GenAI. Total turns in content: {len(api_call_contents)}")
-        if api_call_contents:
-            logging.debug(f"Last turn role: {api_call_contents[-1]['role']}, text start: {api_call_contents[-1]['parts'][0]['text'][:100]}")
-
 
-        # 3. Prepare API configuration
-        #    Convert safety settings from list of dicts to list of SafetySetting objects if genai_types are available
-        api_safety_settings = []
-        if genai_types.SafetySetting:
             for ss_dict in self.safety_settings_list_of_dicts:
                 try:
-                    api_safety_settings.append(genai_types.SafetySetting(**ss_dict))
-                except TypeError:  # Handles if HarmCategory/HarmBlockThreshold were strings due to import error
-                    logging.warning(f"Could not create SafetySetting object from dict: {ss_dict}. Using dict directly.")
-                    api_safety_settings.append(ss_dict)  # Fallback to dict
-        else:  # genai_types not available
-            api_safety_settings = self.safety_settings_list_of_dicts
 
 
-        api_generation_config = None
-        if genai_types.GenerateContentConfig:
-            try:
-                api_generation_config = genai_types.GenerateContentConfig(
-                    **self.generation_config_dict,
-                    safety_settings=api_safety_settings  # This should be a list of SafetySetting objects or dicts
-                )
-            except TypeError:
-                logging.warning("Could not create GenerateContentConfig object. Using dicts directly for config.")
-                # Fallback: if GenerateContentConfig fails, try to pass dicts (might not be supported by client.models.generate_content's 'config' param)
-                # The user's snippet uses config=types.GenerateContentConfig(...), so this object is important.
-                # If it fails, the call might fail.
-                api_generation_config = self.generation_config_dict  # This is not ideal for the 'config' parameter.
-                # The 'config' parameter of client.models.generate_content expects a GenerateContentConfig object.
-                # If we can't create it, we should signal an error or try a different call structure if available.
-                # For now, proceed and let the API call potentially fail if config is malformed.
-                # A better fallback would be to construct the config parts individually if the main object fails.
-                # However, the user's snippet is clear: config=types.GenerateContentConfig(...)
-                # So, if genai_types.GenerateContentConfig is None, this will be an issue.
-
-        else:  # genai_types.GenerateContentConfig is None (likely import error)
-            logging.error("genai_types.GenerateContentConfig not available. Cannot form API config.")
-            return "Error: AI Agent configuration problem (GenerateContentConfig type missing)."
-
-
-        # 4. Make the API call (synchronous SDK call wrapped in asyncio.to_thread)
-        try:
-            response = await asyncio.to_thread(
-                client.models.generate_content,  # As per user's snippet
-                model=self.llm_model_name,
-                contents=api_call_contents,
-                config=api_generation_config  # Pass the GenerateContentConfig object
             )
-            # Extract text. User's snippet uses response.text
-            # Check for blocked content or other issues
-            if not response.candidates:
-                block_reason = response.prompt_feedback.block_reason if response.prompt_feedback else "Unknown"
-                logging.warning(f"AI response blocked or empty. Reason: {block_reason}")
-                # You might want to inspect response.prompt_feedback for block reasons
-                error_message = f"The AI's response was blocked. Reason: {block_reason}."
-                if response.prompt_feedback and response.prompt_feedback.block_reason_message:
-                    error_message += f" Details: {response.prompt_feedback.block_reason_message}"
-                return error_message
-
-            answer = response.text.strip()
-            logging.info(f"Successfully received AI response (first 100 chars): {answer[:100]}")
 
-        except Exception as e:
-            logging.error(f"Error during GenAI call: {e}", exc_info=True)
-            # Check if it's a Google specific API error for more details
-            # from google.api_core import exceptions as google_exceptions
-            # if isinstance(e, google_exceptions.GoogleAPIError):
-            #     answer = f"API Error: {e.message}"
-            # else:
-            answer = f"# Error during AI processing:\n{type(e).__name__}: {str(e)}"
-
-        return answer
 
-    def clear_chat_history(self):  # This method is called by app.py
-        """Clears the agent's internal chat history."""
         self.chat_history = []
         logging.info("EmployerBrandingAgent chat history cleared by request.")
 
-# --- Module-level function for schema display in app.py ---
 def get_all_schemas_representation(all_dataframes: dict) -> str:
-    """
-    Generates a string representation of the schemas of all DataFrames,
-    intended for display in the Gradio UI.
-    This is a standalone function as it's imported directly by app.py.
-    """
-    if not all_dataframes:
-        return "No DataFrames are currently loaded."
-
     schema_descriptions = ["DataFrames currently available in the application state:"]
     for key, df in all_dataframes.items():
         df_name = f"df_{key}"
@@ -435,9 +619,47 @@ def get_all_schemas_representation(all_dataframes: dict) -> str:
         if df.empty:
             schema = f"\n--- DataFrame: {df_name} ---\nStatus: Empty\nShape: {shape}\nColumns: {columns}"
         else:
-            # Provide a bit more detail for UI display
-            sample_data_str = df.head(2).to_markdown(index=False)  # Use markdown for better UI rendering
             schema = (f"\n--- DataFrame: {df_name} ---\nShape: {shape}\nColumns: {columns}\n\n<details><summary>Sample Data (first 2 rows of {df_name}):</summary>\n\n{sample_data_str}\n\n</details>")
         schema_descriptions.append(schema)
     return "\n".join(schema_descriptions)
 import logging
 import numpy as np
 import textwrap
+from datetime import datetime  # Added for date calculations
 
 try:
     from google import generativeai as genai
+    from google.generativeai import types  # For GenerateContentConfig, SafetySetting, HarmCategory, HarmBlockThreshold etc.
 except ImportError:
     logging.error("Google Generative AI library not found. Please install it: pip install google-generativeai", exc_info=True)
+    # Define dummy classes/variables if import fails
     class genai: Client = None  # type: ignore
+    class types:  # type: ignore
         EmbedContentConfig = None
         GenerateContentConfig = None
         SafetySetting = None
+        # Define HarmCategory and HarmBlockThreshold as inner classes or attributes for the dummy types
+        class HarmCategory:  # type: ignore
+            HARM_CATEGORY_HATE_SPEECH = "HARM_CATEGORY_HATE_SPEECH"
+            HARM_CATEGORY_HARASSMENT = "HARM_CATEGORY_HARASSMENT"
+            HARM_CATEGORY_SEXUALLY_EXPLICIT = "HARM_CATEGORY_SEXUALLY_EXPLICIT"
+            HARM_CATEGORY_DANGEROUS_CONTENT = "HARM_CATEGORY_DANGEROUS_CONTENT"
+        class HarmBlockThreshold:  # type: ignore
+            BLOCK_LOW_AND_ABOVE = "BLOCK_LOW_AND_ABOVE"
+            BLOCK_MEDIUM_AND_ABOVE = "BLOCK_MEDIUM_AND_ABOVE"
+            BLOCK_NONE = "BLOCK_NONE"
+
+# --- Custom Exceptions ---
+class ValidationError(Exception):
+    """Custom validation error for agent inputs"""
+    pass
+
+class RateLimitError(Exception):
+    """Placeholder for rate limit errors."""
+    pass
 
 # --- Configuration Constants ---
 GEMINI_API_KEY = os.getenv('GEMINI_API_KEY', "")
 if not GEMINI_API_KEY:
     logging.warning("GEMINI_API_KEY environment variable not set. EB Agent will not function.")
 
+LLM_MODEL_NAME = "gemini-1.5-flash-latest"
+GEMINI_EMBEDDING_MODEL_NAME = "text-embedding-004"
 
 GENERATION_CONFIG_PARAMS = {
     "temperature": 0.7,
     "top_p": 0.95,
     "top_k": 40,
     "max_output_tokens": 8192,
+    "candidate_count": 1,
 }
 
+# Updated to use types.HarmCategory and types.HarmBlockThreshold
 DEFAULT_SAFETY_SETTINGS = [
+    {"category": types.HarmCategory.HARM_CATEGORY_HATE_SPEECH if types and hasattr(types, 'HarmCategory') else "HARM_CATEGORY_HATE_SPEECH",
+     "threshold": types.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE if types and hasattr(types, 'HarmBlockThreshold') else "BLOCK_MEDIUM_AND_ABOVE"},
+    {"category": types.HarmCategory.HARM_CATEGORY_HARASSMENT if types and hasattr(types, 'HarmCategory') else "HARM_CATEGORY_HARASSMENT",
+     "threshold": types.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE if types and hasattr(types, 'HarmBlockThreshold') else "BLOCK_MEDIUM_AND_ABOVE"},
+    {"category": types.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT if types and hasattr(types, 'HarmCategory') else "HARM_CATEGORY_SEXUALLY_EXPLICIT",
+     "threshold": types.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE if types and hasattr(types, 'HarmBlockThreshold') else "BLOCK_MEDIUM_AND_ABOVE"},
+    {"category": types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT if types and hasattr(types, 'HarmCategory') else "HARM_CATEGORY_DANGEROUS_CONTENT",
+     "threshold": types.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE if types and hasattr(types, 'HarmBlockThreshold') else "BLOCK_MEDIUM_AND_ABOVE"},
 ]
 
+
 df_rag_documents = pd.DataFrame({
     'text': [
         "Employer branding focuses on how an organization is perceived as an employer by potential and current employees.",
 })
 
 # --- Client Initialization ---
 client = None
+if GEMINI_API_KEY and genai.Client:
     try:
         client = genai.Client(api_key=GEMINI_API_KEY)
         logging.info("Google GenAI client initialized successfully.")
     except Exception as e:
 
 
 class AdvancedRAGSystem:
     def __init__(self, documents_df: pd.DataFrame, embedding_model_name: str):
+        self.documents_df = documents_df.copy()
         self.embedding_model_name = embedding_model_name
+        self.embeddings: np.ndarray | None = None
         logging.info(f"AdvancedRAGSystem initialized with embedding model: {self.embedding_model_name}")
 
     def _embed_single_document_sync(self, text: str) -> np.ndarray:
         if not client:
             raise ConnectionError("GenAI client not initialized for RAG embedding.")
+        if not text or not isinstance(text, str):
             raise ValueError("Cannot embed empty or non-string text.")
+
+        # Ensure types.EmbedContentConfig is available before using it
+        embed_config = None
+        if types and hasattr(types, 'EmbedContentConfig'):
+            embed_config = types.EmbedContentConfig(task_type="SEMANTIC_SIMILARITY")
 
         response = client.models.embed_content(
+            model=self.embedding_model_name,
+            contents=text,
+            config=embed_config
         )
         return np.array(response.embeddings)
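`initialize_embeddings` below pushes this synchronous SDK call onto a worker thread with `asyncio.to_thread`, keeping the event loop responsive. The pattern in isolation (the blocking function here is a stand-in, not the real SDK call):

    import asyncio
    import time

    def blocking_embed(text: str) -> list[float]:
        time.sleep(0.1)  # stands in for a synchronous network call
        return [float(len(text))]

    async def main() -> None:
        vec = await asyncio.to_thread(blocking_embed, "hello")  # runs in a thread
        print(vec)

    asyncio.run(main())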
 
     async def initialize_embeddings(self):
         if self.documents_df.empty:
             logging.info("RAG documents DataFrame is empty. No embeddings to initialize.")
             self.embeddings = np.array([])
             return
         if not client:
             logging.error("GenAI client not available for RAG embedding initialization.")
             self.embeddings = np.array([])
                 logging.warning(f"Skipping document at index {index} due to invalid text: {text_to_embed}")
                 continue
             try:
                 embedding_array = await asyncio.to_thread(self._embed_single_document_sync, text_to_embed)
                 embedded_docs_list.append(embedding_array)
             except Exception as e:
+                logging.error(f"Error embedding document text (index {index}) '{str(text_to_embed)[:50]}...': {e}", exc_info=False)
 
         if not embedded_docs_list:
             self.embeddings = np.array([])
         try:
             self.embeddings = np.vstack(embedded_docs_list)
             logging.info(f"Successfully embedded {len(embedded_docs_list)} documents for RAG. Embedding matrix shape: {self.embeddings.shape}")
+        except ValueError as ve:
             logging.error(f"Error stacking embeddings: {ve}. Check individual embedding errors.", exc_info=True)
             self.embeddings = np.array([])
 
+    def _calculate_cosine_similarity(self, embeddings_matrix: np.ndarray, query_vector: np.ndarray) -> np.ndarray:
+        """Calculate normalized cosine similarity between a matrix of embeddings and a query vector."""
+        query_vector = query_vector.flatten()
+        norm_matrix = np.linalg.norm(embeddings_matrix, axis=1, keepdims=True)
+        normalized_embeddings_matrix = embeddings_matrix / (norm_matrix + 1e-8)
+        norm_query = np.linalg.norm(query_vector)
+        normalized_query_vector = query_vector / (norm_query + 1e-8)
+        return np.dot(normalized_embeddings_matrix, normalized_query_vector)
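On toy data the new helper reduces to a few numpy lines (a standalone sketch; the 1e-8 terms guard against division by zero for all-zero vectors):

    import numpy as np

    docs = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])  # one embedding per row
    query = np.array([1.0, 0.2])

    docs_n = docs / (np.linalg.norm(docs, axis=1, keepdims=True) + 1e-8)
    query_n = query / (np.linalg.norm(query) + 1e-8)
    print(docs_n @ query_n)  # cosine score per document; highest = most similar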
 
+    async def retrieve_relevant_info(self, query: str, top_k: int = 3, min_similarity: float = 0.3) -> str:
         if self.embeddings is None or self.embeddings.size == 0 or self.documents_df.empty:
             logging.debug("RAG system not initialized or no documents/embeddings available for retrieval.")
             return ""
         if query_vector.ndim == 0 or query_vector.size == 0:
             logging.warning(f"Query vector embedding failed or is empty for query: {str(query)[:50]}")
             return ""
+
         try:
+            similarity_scores = self._calculate_cosine_similarity(self.embeddings, query_vector)
 
+            if similarity_scores.size == 0:
                 return ""
 
+            relevant_indices_after_threshold = np.where(similarity_scores >= min_similarity)[0]
+            if len(relevant_indices_after_threshold) == 0:
+                logging.debug(f"No documents met the minimum similarity threshold of {min_similarity} for query: {query[:50]}")
+                return ""
 
+            relevant_scores = similarity_scores[relevant_indices_after_threshold]
+            sorted_relevant_indices_local = np.argsort(relevant_scores)[::-1]
+            top_original_indices = relevant_indices_after_threshold[sorted_relevant_indices_local[:top_k]]
+
+            if len(top_original_indices) == 0: return ""
 
+            context_parts = [self.documents_df.iloc[i]['text'] for i in top_original_indices if 'text' in self.documents_df.columns]
             context = "\n\n---\n\n".join(context_parts)
             logging.debug(f"Retrieved RAG context for query '{str(query)[:50]}...':\n{context[:200]}...")
             return context
         except Exception as e:
+            logging.error(f"Error during RAG retrieval (similarity/sorting): {e}", exc_info=True)
             return ""
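The threshold-then-rank logic above can be verified in isolation (self-contained numpy sketch mirroring those lines):

    import numpy as np

    scores = np.array([0.15, 0.82, 0.40, 0.65])
    min_similarity, top_k = 0.3, 2

    keep = np.where(scores >= min_similarity)[0]   # drop weak matches -> [1 2 3]
    order = np.argsort(scores[keep])[::-1]         # rank the survivors, best first
    top = keep[order[:top_k]]
    print(top, scores[top])                        # -> [1 3] [0.82 0.65]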
 
 
 class EmployerBrandingAgent:
     def __init__(self,
                  all_dataframes: dict,
+                 rag_documents_df: pd.DataFrame,
                  llm_model_name: str,
+                 embedding_model_name: str,
                  generation_config_dict: dict,
                  safety_settings_list_of_dicts: list,
+                 force_sandbox: bool = False):
+        self.all_dataframes = {k: df.copy() for k, df in all_dataframes.items()}
+        self.schemas_representation = self._get_enhanced_schemas_representation()
 
+        self.chat_history = []
         self.llm_model_name = llm_model_name
         self.generation_config_dict = generation_config_dict
+        self.safety_settings_list_of_dicts = safety_settings_list_of_dicts  # These are dicts
         self.embedding_model_name = embedding_model_name
         self.rag_system = AdvancedRAGSystem(rag_documents_df, self.embedding_model_name)
+        self.force_sandbox = force_sandbox
         logging.info(f"EmployerBrandingAgent initialized. LLM: {self.llm_model_name}, Embedding: {self.embedding_model_name}. RAG system created.")
 
+    def _get_date_range(self, df: pd.DataFrame) -> str:
+        for col in df.columns:
+            if pd.api.types.is_datetime64_any_dtype(df[col]):
+                try:
+                    min_date = df[col].min()
+                    max_date = df[col].max()
+                    if pd.notna(min_date) and pd.notna(max_date):
+                        return f"{min_date.strftime('%Y-%m-%d')} to {max_date.strftime('%Y-%m-%d')}"
+                except Exception: pass
+        return "N/A"
+
+    def _calculate_growth_rate(self, df: pd.DataFrame) -> str:
+        logging.debug("_calculate_growth_rate is a placeholder.")  # Changed to debug
+        return "Growth rate calculation not implemented."
+    def _analyze_engagement_trends(self, df: pd.DataFrame) -> str:
+        logging.debug("_analyze_engagement_trends is a placeholder.")
+        return "Engagement trend analysis not implemented."
+    def _analyze_demographics(self, df: pd.DataFrame) -> str:
+        logging.debug("_analyze_demographics is a placeholder.")
+        return "Demographic analysis not implemented."
+    def _analyze_post_performance(self, df: pd.DataFrame) -> str:
+        logging.debug("_analyze_post_performance is a placeholder.")
+        return "Post performance analysis not implemented."
+    def _extract_content_themes(self, df: pd.DataFrame) -> str:
+        logging.debug("_extract_content_themes is a placeholder.")
+        return "Content theme extraction not implemented."
+    def _find_optimal_times(self, df: pd.DataFrame) -> str:
+        logging.debug("_find_optimal_times is a placeholder.")
+        return "Optimal posting time analysis not implemented."
+
+    def _calculate_key_metrics(self, df: pd.DataFrame, df_type: str) -> dict:
+        metrics = {}
+        if 'follower' in df_type.lower():
+            metrics.update({
+                'follower_growth_rate': self._calculate_growth_rate(df),
+                'engagement_trends': self._analyze_engagement_trends(df),
+                'demographic_distribution': self._analyze_demographics(df)
+            })
+        elif 'post' in df_type.lower():
+            metrics.update({
+                'post_performance': self._analyze_post_performance(df),
+                'content_themes': self._extract_content_themes(df),
+                'optimal_posting_times': self._find_optimal_times(df)
+            })
+        elif 'mention' in df_type.lower():
+            metrics['mention_volume_trend'] = "Mention volume trend not implemented."
+            metrics['mention_sentiment_overview'] = "Mention sentiment overview not implemented."
+
+        if not metrics:
+            logging.debug(f"No specific key metrics defined for df_type: {df_type}")
+            return {"info": "Standard metrics applicable."}
+        return metrics
+
+    def _calculate_data_freshness(self, df: pd.DataFrame) -> str:
+        for col in df.columns:
+            if pd.api.types.is_datetime64_any_dtype(df[col]):
+                try:
+                    max_date = df[col].max()
+                    if pd.notna(max_date):
+                        days_diff = (datetime.now(max_date.tzinfo) - max_date).days  # tz aware
+                        return f"Data up to {max_date.strftime('%Y-%m-%d')} ({days_diff} days old)"
+                except Exception: pass
+        return "Freshness N/A (no clear date column)"
+    def _check_data_consistency(self, df: pd.DataFrame) -> str:
+        logging.debug("_check_data_consistency is a placeholder.")
+        return "Consistency checks not implemented."
+    def _identify_accuracy_issues(self, df: pd.DataFrame) -> str:
+        logging.debug("_identify_accuracy_issues is a placeholder.")
+        return "Accuracy issue identification not implemented."
+
+    def _assess_data_quality(self, df: pd.DataFrame) -> dict:
+        completeness = (1 - (df.isnull().sum().sum() / (len(df) * len(df.columns)))) if len(df) > 0 and len(df.columns) > 0 else 0
+        return {
+            'completeness_score': f"{completeness:.2%}",
+            'freshness_info': self._calculate_data_freshness(df),
+            'consistency_check': self._check_data_consistency(df),
+            'accuracy_flags_summary': self._identify_accuracy_issues(df),
+            'sample_size_notes': f"{len(df)} records. {'Adequate for basic analysis.' if len(df) >= 100 else 'Limited sample size; insights may be indicative.'}"
+        }
+
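The completeness score is simply the share of non-null cells; on a toy frame (illustrative data only):

    import pandas as pd

    df = pd.DataFrame({"a": [1, 2, None], "b": [None, "x", "y"]})
    completeness = 1 - df.isnull().sum().sum() / (len(df) * len(df.columns))
    print(f"{completeness:.2%}")  # 4 of 6 cells filled -> 66.67%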
+    def _identify_patterns(self, df: pd.DataFrame, key: str) -> str:
+        logging.debug(f"_identify_patterns for {key} is a placeholder.")
+        return "Pattern identification not implemented."
+
+    def _format_df_analysis(self, df_key: str, analysis: dict) -> str:
+        formatted_parts = [f"\n--- DataFrame: df_{df_key} ---"]
+        formatted_parts.append(f"  Shape: {analysis['shape']}")
+        formatted_parts.append(f"  Date Range: {analysis['date_range']}")
+        formatted_parts.append("  Key Metrics:")
+        for metric, value in analysis['key_metrics'].items():
+            formatted_parts.append(f"    - {metric.replace('_', ' ').title()}: {value}")
+        formatted_parts.append("  Data Quality Assessment:")
+        for aspect, value in analysis['data_quality'].items():
+            formatted_parts.append(f"    - {aspect.replace('_', ' ').title()}: {value}")
+        formatted_parts.append(f"  Notable Patterns: {analysis['notable_patterns']}")
+        return "\n".join(formatted_parts)
+
+    def _get_enhanced_schemas_representation(self) -> str:
+        schema_descriptions = ["=== DETAILED LINKEDIN DATA OVERVIEW ==="]
+        if not self.all_dataframes:
+            schema_descriptions.append("No dataframes available for analysis.")
+            return "\n".join(schema_descriptions)
         for key, df in self.all_dataframes.items():
             if df.empty:
+                schema_descriptions.append(f"\n--- DataFrame: df_{key} ---\nStatus: Empty. No analysis possible.")
+                continue
+            analysis = {
+                'shape': df.shape,
+                'date_range': self._get_date_range(df),
+                'key_metrics': self._calculate_key_metrics(df, key),
+                'data_quality': self._assess_data_quality(df),
+                'notable_patterns': self._identify_patterns(df, key)
+            }
+            schema_descriptions.append(self._format_df_analysis(key, analysis))
         return "\n".join(schema_descriptions)
+
+    def _extract_query_intent(self, query: str) -> str:
+        logging.debug("_extract_query_intent is a placeholder.")
+        if "compare" in query.lower() or "benchmark" in query.lower(): return "comparison"
+        if "trend" in query.lower(): return "trend_analysis"
+        return "general"
+
+    async def _get_business_context(self, intent: str) -> str:
+        logging.debug("_get_business_context is a placeholder.")
+        if intent == "comparison": return "Company is focused on outperforming competitors in tech hiring."
+        return "Company aims to improve overall employer brand perception."
+
+    async def _get_industry_benchmarks(self, intent: str) -> str:
+        logging.debug("_get_industry_benchmarks is a placeholder.")
+        if intent == "trend_analysis": return "Typical follower growth in this sector is 5-10% MoM."
+        return "Average engagement rate for similar companies is 2-3%."
+
+    async def _enhance_rag_context(self, query: str, base_context: str) -> str:
+        intent = self._extract_query_intent(query)
+        business_context_val = await self._get_business_context(intent)
+        benchmarks_val = await self._get_industry_benchmarks(intent)
+        enhanced_context = f"""{base_context}
+--- ADDITIONAL CONTEXT FOR YOUR ANALYSIS ---
+Business Focus: {business_context_val}
+Relevant Benchmarks: {benchmarks_val}"""
+        return enhanced_context
 
     async def _build_prompt_for_current_turn(self, raw_user_query: str) -> str:
         prompt_parts = [
             "You are an expert Employer Branding Analyst and a helpful AI assistant. "
             "Your goal is to provide insightful analysis based on the provided LinkedIn data. "
             "When asked to generate Pandas code, ensure it is correct, runnable, and clearly explained. "
+            "When providing insights, be specific and refer to the data where possible. "
+            "Use the detailed data overview and any contextual information provided."
         ]
+        prompt_parts.append("\n\n--- DETAILED DATA OVERVIEW ---")
         prompt_parts.append(self.schemas_representation)
 
+        if self.rag_system.embeddings is not None and self.rag_system.embeddings.size > 0:
+            logging.debug(f"Retrieving base RAG context for query: {raw_user_query[:50]}...")
+            base_rag_context = await self.rag_system.retrieve_relevant_info(raw_user_query)
+            if base_rag_context:
+                logging.debug(f"Enhancing RAG context for query: {raw_user_query[:50]}...")
+                enhanced_rag_context = await self._enhance_rag_context(raw_user_query, base_rag_context)
+                prompt_parts.append("\n\n--- RELEVANT CONTEXTUAL INFORMATION (from documents & business knowledge) ---")
+                prompt_parts.append(enhanced_rag_context)
+            else: logging.debug("No base RAG context found.")
+        else: logging.debug("RAG system not initialized or embeddings not available, skipping RAG context retrieval.")
 
         prompt_parts.append("\n\n--- USER REQUEST ---")
         prompt_parts.append(f"Based on all the information above, please respond to the following user query:\n{raw_user_query}")
         final_prompt = "\n".join(prompt_parts)
         logging.debug(f"Built prompt for current turn (first 300 chars): {final_prompt[:300]}")
         return final_prompt
 
+    async def _process_structured_query(self, prompt: str) -> dict:
+        logging.debug("_process_structured_query is a placeholder. Returning dummy structure.")
+        return {
+            "Key Findings": ["Placeholder finding 1", "Placeholder finding 2"],
+            "Performance Metrics": ["Placeholder metric performance"],
+            "Actionable Recommendations": {
+                "Immediate Actions (0-30 days)": ["Placeholder immediate action"],
+                "Short-term Strategy (1-3 months)": ["Placeholder short-term strategy"],
+                "Long-term Vision (3-12 months)": ["Placeholder long-term vision"]
+            },
+            "Risk Assessment": ["Placeholder risk"],
+            "Success Metrics to Track": ["Placeholder KPI"]
+        }
+
+    async def _generate_hr_insights(self, query: str, context: str) -> str:
+        insight_prompt = f"""
+As an expert HR analytics consultant, analyze the following LinkedIn employer branding data:
+{context}
+User Query: {query}
+Please provide insights in this structured format:
+## Key Findings
+- [3-5 bullet points of main discoveries]
+## Performance Metrics
+- Current performance vs industry benchmarks
+- Trend analysis (improving/declining/stable)
+## Actionable Recommendations
+1. **Immediate Actions (0-30 days)**
+   - [Specific, measurable actions]
+2. **Short-term Strategy (1-3 months)**
+   - [Strategic initiatives]
+3. **Long-term Vision (3-12 months)**
+   - [Comprehensive improvements]
+## Risk Assessment
+- Potential challenges or red flags
+- Mitigation strategies
+## Success Metrics to Track
+- KPIs to monitor progress
+- Reporting frequency recommendations
+"""
+        if not client: return "Error: AI client not configured for generating HR insights."
+        api_call_contents = [{"role": "user", "parts": [{"text": insight_prompt}]}]
+
+        # Construct SafetySetting objects if types.SafetySetting is available
+        api_safety_settings_objects = []
+        if types and hasattr(types, 'SafetySetting'):
+            for ss_dict in self.safety_settings_list_of_dicts:
+                try:
+                    # Use types.HarmCategory and types.HarmBlockThreshold directly
+                    category = getattr(types.HarmCategory, ss_dict['category'].split('.')[-1] if isinstance(ss_dict['category'], str) else ss_dict['category'].name, types.HarmCategory.HARM_CATEGORY_UNSPECIFIED)
+                    threshold = getattr(types.HarmBlockThreshold, ss_dict['threshold'].split('.')[-1] if isinstance(ss_dict['threshold'], str) else ss_dict['threshold'].name, types.HarmBlockThreshold.BLOCK_NONE)
+                    api_safety_settings_objects.append(types.SafetySetting(category=category, threshold=threshold))
+                except Exception as e_ss:
+                    logging.warning(f"Could not create SafetySetting object from {ss_dict}: {e_ss}. Using dict.")
+                    api_safety_settings_objects.append(ss_dict)  # Fallback to dict if creation fails
+        else:  # Fallback if types.SafetySetting is not available
+            api_safety_settings_objects = self.safety_settings_list_of_dicts
+
+        api_generation_config_obj = None
+        if types and hasattr(types, 'GenerateContentConfig'):
+            api_generation_config_obj = types.GenerateContentConfig(
+                **self.generation_config_dict,
+                safety_settings=api_safety_settings_objects
+            )
+        else:  # Fallback if types.GenerateContentConfig is not available
+            config_dict_for_api = {**self.generation_config_dict, "safety_settings": api_safety_settings_objects}
+            api_generation_config_obj = config_dict_for_api
 
+        try:
+            response = await asyncio.to_thread(
+                client.models.generate_content,
+                model=self.llm_model_name,
+                contents=api_call_contents,
+                config=api_generation_config_obj
+            )
+            if not response.candidates: return "HR insights generation failed: No response from AI."
+            return response.text.strip()
+        except Exception as e:
+            logging.error(f"Error generating HR insights: {e}", exc_info=True)
+            return f"Error generating HR insights: {str(e)}"
+
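The `getattr` lookups above accept either an enum member or its string name (with or without a class prefix) and map it onto the real enum, falling back to a default when the name is unknown. The mechanism in miniature, on a toy enum (illustrative only):

    from enum import Enum

    class Color(Enum):
        RED = 1
        BLUE = 2

    for raw in ["Color.RED", "BLUE", Color.BLUE]:
        name = raw.split(".")[-1] if isinstance(raw, str) else raw.name
        print(getattr(Color, name, Color.RED))  # -> Color.RED, Color.BLUE, Color.BLUE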
+    def _validate_query(self, query: str) -> bool:
+        if not query or len(query.strip()) < 3:
+            logging.warning(f"Query too short: '{query}'")
+            return False
+        hr_keywords = ['employee', 'talent', 'hiring', 'culture', 'brand', 'engagement', 'retention', 'follower', 'post', 'mention', 'linkedin']
+        if not any(keyword in query.lower() for keyword in hr_keywords):
+            logging.warning(f"Query may not be HR/LinkedIn-relevant: {query[:50]}")
+        return True
+
+    def _get_query_help_message(self) -> str:
+        return ("I'm here to help with Employer Branding analysis on LinkedIn data. "
+                "Please ask specific questions about your followers, posts, or mentions. "
+                "For example: 'What are the top industries of my followers?' or 'Analyze the engagement trend of my recent posts.'")
+
+    async def _check_system_readiness(self) -> dict:
+        logging.debug("_check_system_readiness is a placeholder.")
+        if not client: return {'ready': False, 'reason': 'AI Client not initialized.'}
+        if self.rag_system.embeddings is None:
+            logging.warning("RAG embeddings not yet initialized. Proceeding, but RAG context will be unavailable.")
+        return {'ready': True, 'reason': 'System appears ready.'}
+
+    def _get_fallback_response(self, query: str) -> str:
+        logging.error(f"Executing fallback response for query: {query[:50]}")
+        return "I encountered an unexpected issue while processing your request. Please try rephrasing your query or try again later."
+
+    async def _core_query_processing(self, raw_user_query_this_turn: str) -> str:
         augmented_current_user_prompt_text = await self._build_prompt_for_current_turn(raw_user_query_this_turn)
+        api_call_contents = list(self.chat_history)
         api_call_contents.append({"role": "user", "parts": [{"text": augmented_current_user_prompt_text}]})
         logging.debug(f"Sending to GenAI. Total turns in content: {len(api_call_contents)}")
 
+        api_safety_settings_objects = []
+        if types and hasattr(types, 'SafetySetting'):
             for ss_dict in self.safety_settings_list_of_dicts:
                 try:
+                    category_enum_val = ss_dict['category']
+                    threshold_enum_val = ss_dict['threshold']
+                    # If they are already enum members, use them directly
+                    if not isinstance(category_enum_val, str):  # Assumes it's an enum member
+                        category = category_enum_val
+                    else:  # If string, try to get from types.HarmCategory
+                        category = getattr(types.HarmCategory, category_enum_val.split('.')[-1], types.HarmCategory.HARM_CATEGORY_UNSPECIFIED)
+
+                    if not isinstance(threshold_enum_val, str):  # Assumes it's an enum member
+                        threshold = threshold_enum_val
+                    else:  # If string, try to get from types.HarmBlockThreshold
+                        threshold = getattr(types.HarmBlockThreshold, threshold_enum_val.split('.')[-1], types.HarmBlockThreshold.BLOCK_NONE)
+
+                    api_safety_settings_objects.append(types.SafetySetting(category=category, threshold=threshold))
+                except Exception as e_ss_core:
+                    logging.warning(f"Could not create SafetySetting object from {ss_dict} in core: {e_ss_core}. Using dict.")
+                    api_safety_settings_objects.append(ss_dict)  # Fallback
+        else:
+            api_safety_settings_objects = self.safety_settings_list_of_dicts
 
+        api_generation_config_obj = None
+        if types and hasattr(types, 'GenerateContentConfig'):
+            api_generation_config_obj = types.GenerateContentConfig(
+                **self.generation_config_dict,
+                safety_settings=api_safety_settings_objects
+            )
+        else:
+            logging.error("GenerateContentConfig type not available. API call might fail.")
+            config_dict_for_api = {**self.generation_config_dict, "safety_settings": api_safety_settings_objects}
+            api_generation_config_obj = config_dict_for_api
+
+        response = await asyncio.to_thread(
+            client.models.generate_content,
+            model=self.llm_model_name,
+            contents=api_call_contents,
+            config=api_generation_config_obj
+        )
 
+        if not response.candidates:
+            block_reason = response.prompt_feedback.block_reason if response.prompt_feedback else "Unknown"
+            block_message = response.prompt_feedback.block_reason_message if response.prompt_feedback else ""
+            logging.warning(f"AI response blocked or empty. Reason: {block_reason}, Msg: {block_message}")
+            error_message = f"The AI's response was blocked. Reason: {block_reason}."
+            if block_message: error_message += f" Details: {block_message}"
+            return error_message
+        return response.text.strip()
+
+    async def _process_query_with_timeout(self, raw_user_query_this_turn: str, timeout_seconds: int = 60) -> str:
+        try:
+            return await asyncio.wait_for(self._core_query_processing(raw_user_query_this_turn), timeout=timeout_seconds)
+        except asyncio.TimeoutError:
+            logging.error(f"Query processing timed out after {timeout_seconds} seconds for query: {raw_user_query_this_turn[:50]}")
+            return "I'm sorry, but your request took too long to process. Please try a simpler query or try again later."
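`asyncio.wait_for` cancels the wrapped coroutine once the deadline passes and raises `TimeoutError` in the caller; a minimal standalone illustration of the pattern:

    import asyncio

    async def slow_job() -> str:
        await asyncio.sleep(5)
        return "done"

    async def main() -> None:
        try:
            print(await asyncio.wait_for(slow_job(), timeout=1))
        except asyncio.TimeoutError:
            print("timed out")  # slow_job was cancelled at the 1s deadline

    asyncio.run(main())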
 
+    async def process_query(self, raw_user_query_this_turn: str) -> str:
+        if not client:
+            logging.error("GenAI client not initialized. Cannot process query.")
+            return "Error: The AI Agent is not available due to a configuration issue with the AI service."
+        if not self._validate_query(raw_user_query_this_turn): return self._get_query_help_message()
+        readiness_check = await self._check_system_readiness()
+        if not readiness_check['ready']: return f"System not ready: {readiness_check['reason']}"
+
+        max_retries = 2
+        for attempt in range(max_retries + 1):
+            try:
+                response_text = await self._process_query_with_timeout(raw_user_query_this_turn)
+                if "The AI's response was blocked" in response_text: return response_text
+                logging.info(f"Successfully received AI response (attempt {attempt+1}): {response_text[:100]}")
+                return response_text
+            except RateLimitError as rle:
+                logging.warning(f"Rate limit encountered on attempt {attempt + 1}: {rle}. Retrying after backoff...")
+                if attempt == max_retries:
+                    logging.error("Max retries reached due to rate limiting.")
+                    return "The AI service is currently busy. Please try again in a few moments."
+                await asyncio.sleep(2 ** attempt)
+            except ValidationError as ve:
+                logging.warning(f"Validation error during processing: {ve}")
+                return f"Query validation failed: {str(ve)}"
+            except Exception as e:
+                logging.error(f"Error during GenAI call on attempt {attempt + 1}: {e}", exc_info=True)
+                if attempt == max_retries:
+                    logging.error("Max retries reached due to general errors.")
+                    return self._get_fallback_response(raw_user_query_this_turn)
+        return self._get_fallback_response(raw_user_query_this_turn)
+
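The `2 ** attempt` sleep yields 1s, then 2s between retries; the same loop stripped to its skeleton, with jitter added to avoid synchronized retries (a generic sketch, names are illustrative):

    import asyncio
    import random

    async def call_with_backoff(make_call, max_retries: int = 2):
        for attempt in range(max_retries + 1):
            try:
                return await make_call()
            except Exception:
                if attempt == max_retries:
                    raise  # out of retries; surface the error
                await asyncio.sleep(2 ** attempt + random.random())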
+    def _classify_query_type(self, query: str) -> str:
+        query_lower = query.lower()
+        if any(word in query_lower for word in ['trend', 'growth', 'change', 'time']): return 'trend_analysis'
+        elif any(word in query_lower for word in ['compare', 'benchmark', 'versus']): return 'comparative_analysis'
+        elif any(word in query_lower for word in ['predict', 'forecast', 'future']): return 'predictive_analysis'
+        elif any(word in query_lower for word in ['recommend', 'suggest', 'improve', 'advice', 'help me with']): return 'recommendation_engine'
+        elif any(word in query_lower for word in ['what is', 'explain', 'define']): return 'definition_explanation'
+        else: return 'general_inquiry'
+
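Keyword routing like this is order-sensitive (the trend branch wins any tie) and easy to sanity-check against a constructed `agent` instance (assumed already initialized):

    for q in ["Compare us versus competitors",
              "Predict follower count next quarter",
              "What is employer branding?"]:
        print(q, "->", agent._classify_query_type(q))
    # -> comparative_analysis, predictive_analysis, definition_explanation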
+    def clear_chat_history(self):
         self.chat_history = []
         logging.info("EmployerBrandingAgent chat history cleared by request.")
 
 def get_all_schemas_representation(all_dataframes: dict) -> str:
+    if not all_dataframes: return "No DataFrames are currently loaded."
     schema_descriptions = ["DataFrames currently available in the application state:"]
     for key, df in all_dataframes.items():
         df_name = f"df_{key}"
         if df.empty:
             schema = f"\n--- DataFrame: {df_name} ---\nStatus: Empty\nShape: {shape}\nColumns: {columns}"
         else:
+            sample_data_str = df.head(2).to_markdown(index=False)
             schema = (f"\n--- DataFrame: {df_name} ---\nShape: {shape}\nColumns: {columns}\n\n<details><summary>Sample Data (first 2 rows of {df_name}):</summary>\n\n{sample_data_str}\n\n</details>")
         schema_descriptions.append(schema)
     return "\n".join(schema_descriptions)
 
+async def test_rag_retrieval_accuracy():
+    logging.info("Running RAG retrieval accuracy test...")
+    test_embedding_model = GEMINI_EMBEDDING_MODEL_NAME
+    if not client:
+        logging.error("Cannot run RAG test: GenAI client not initialized.")
+        return
+    test_docs_data = {
+        'text': [
+            'Strategies for improving employee engagement include regular feedback and recognition programs.',
+            'Effective talent acquisition requires a strong employer brand and a streamlined hiring process.',
+            'Company culture is a key driver of employee satisfaction and retention.',
+            'Analyzing LinkedIn post performance can reveal insights into content effectiveness.'
+        ]
+    }
+    test_docs_df = pd.DataFrame(test_docs_data)
+    rag_system = AdvancedRAGSystem(test_docs_df, test_embedding_model)
+    logging.info("Test RAG: Initializing embeddings...")
+    await rag_system.initialize_embeddings()
+    if rag_system.embeddings is None or rag_system.embeddings.size == 0:
+        logging.error("Test RAG: Embeddings not initialized properly.")
+        return
+    test_queries = {
+        "employee engagement": "engagement",
+        "hiring talent": "acquisition",
+        "company culture": "culture",
+        "linkedin posts": "linkedin"
+    }
+    all_tests_passed = True
+    for query, keyword in test_queries.items():
+        logging.info(f"Test RAG: Retrieving for query: '{query}'")
+        result = await rag_system.retrieve_relevant_info(query, top_k=1, min_similarity=0.1)
+        if result and keyword.lower() in result.lower():
+            logging.info(f"Test RAG: PASSED for query '{query}'. Found relevant doc.")
+        else:
+            logging.error(f"Test RAG: FAILED for query '{query}'. Expected keyword '{keyword}', got: {result[:100]}...")
+            all_tests_passed = False
+    if all_tests_passed: logging.info("All RAG retrieval accuracy tests passed.")
+    else: logging.error("Some RAG retrieval accuracy tests FAILED.")
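The test coroutine has no entry point in this commit; a minimal way to invoke it locally (assumes GEMINI_API_KEY is set so the module-level client exists):

    if __name__ == "__main__":
        import asyncio
        logging.basicConfig(level=logging.INFO)
        asyncio.run(test_rag_retrieval_accuracy())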