GuglielmoTor committed on
Commit 514ad52 · verified · 1 Parent(s): 003ceb6

Update eb_agent_module.py

Files changed (1):
  1. eb_agent_module.py (+524 -408)

eb_agent_module.py CHANGED
@@ -4,529 +4,645 @@ import os
  import asyncio
  import logging
  import numpy as np
- import textwrap
- from datetime import datetime # Added for date calculations

  try:
      from google import genai
-     from google.genai import types # For GenerateContentConfig, SafetySetting, HarmCategory, HarmBlockThreshold etc.
  except ImportError:
-     logging.error("Google Generative AI library not found. Please install it: pip install google-generativeai", exc_info=True)
-     # Define dummy classes/variables if import fails
-     class genai: Client = None # type: ignore
-     class types: # type: ignore
-         EmbedContentConfig = None
-         GenerateContentConfig = None
          SafetySetting = None
-         # Define HarmCategory and HarmBlockThreshold as inner classes or attributes for the dummy types
-         class HarmCategory: # type: ignore
              HARM_CATEGORY_UNSPECIFIED = "HARM_CATEGORY_UNSPECIFIED"
              HARM_CATEGORY_HATE_SPEECH = "HARM_CATEGORY_HATE_SPEECH"
              HARM_CATEGORY_HARASSMENT = "HARM_CATEGORY_HARASSMENT"
              HARM_CATEGORY_SEXUALLY_EXPLICIT = "HARM_CATEGORY_SEXUALLY_EXPLICIT"
              HARM_CATEGORY_DANGEROUS_CONTENT = "HARM_CATEGORY_DANGEROUS_CONTENT"
-         class HarmBlockThreshold: # type: ignore
              BLOCK_NONE = "BLOCK_NONE"
              BLOCK_LOW_AND_ABOVE = "BLOCK_LOW_AND_ABOVE"
              BLOCK_MEDIUM_AND_ABOVE = "BLOCK_MEDIUM_AND_ABOVE"
-             BLOCK_ONLY_HIGH = "BLOCK_ONLY_HIGH" # Added for completeness, adjust if needed
-

  # --- Custom Exceptions ---
  class ValidationError(Exception):
      """Custom validation error for agent inputs"""
      pass

- class RateLimitError(Exception):
      """Placeholder for rate limit errors."""
      pass

  # --- Configuration Constants ---
  GEMINI_API_KEY = os.getenv('GEMINI_API_KEY', "")
- if not GEMINI_API_KEY:
-     logging.warning("GEMINI_API_KEY environment variable not set. EB Agent will not function.")
-
- LLM_MODEL_NAME = "gemini-1.5-flash-latest"
- GEMINI_EMBEDDING_MODEL_NAME = "text-embedding-004"

  GENERATION_CONFIG_PARAMS = {
      "temperature": 0.7,
      "top_p": 0.95,
      "top_k": 40,
-     "max_output_tokens": 8192,
      "candidate_count": 1,
  }

- # No safety settings by default as per user request
- DEFAULT_SAFETY_SETTINGS = []
- logging.info("Default safety settings are now empty (no explicit client-side safety settings).")
-
- df_rag_documents = pd.DataFrame({
      'text': [
          "Employer branding focuses on how an organization is perceived as an employer by potential and current employees.",
          "Key metrics for employer branding include employee engagement, candidate quality, and retention rates.",
          "LinkedIn is a crucial platform for showcasing company culture and attracting talent.",
-         "Analyzing follower demographics and post engagement helps refine employer branding strategies."
      ]
  })

  # --- Client Initialization ---
  client = None
- if GEMINI_API_KEY and genai.Client:
      try:
          client = genai.Client(api_key=GEMINI_API_KEY)
-         logging.info("Google GenAI client initialized successfully.")
      except Exception as e:
-         logging.error(f"Failed to initialize Google GenAI client: {e}", exc_info=True)
  else:
-     logging.warning("Google GenAI client could not be initialized (GEMINI_API_KEY missing or library import failed).")

  class AdvancedRAGSystem:
      def __init__(self, documents_df: pd.DataFrame, embedding_model_name: str):
-         self.documents_df = documents_df.copy()
-         self.embedding_model_name = embedding_model_name
-         self.embeddings: np.ndarray | None = None
-         logging.info(f"AdvancedRAGSystem initialized with embedding model: {self.embedding_model_name}")
-
-     def _embed_single_document_sync(self, text: str) -> np.ndarray:
          if not client:
              raise ConnectionError("GenAI client not initialized for RAG embedding.")
          if not text or not isinstance(text, str):
-             raise ValueError("Cannot embed empty or non-string text.")

-         embed_config = None
-         if types and hasattr(types, 'EmbedContentConfig'):
-             embed_config = types.EmbedContentConfig(task_type="SEMANTIC_SIMILARITY")
-
-         response = client.models.embed_content(
-             model=self.embedding_model_name,
-             contents=text,
-             config=embed_config
-         )
-         return np.array(response.embeddings)

      async def initialize_embeddings(self):
-         if self.documents_df.empty:
-             logging.info("RAG documents DataFrame is empty. No embeddings to initialize.")
              self.embeddings = np.array([])
              return
-         if not client:
              logging.error("GenAI client not available for RAG embedding initialization.")
              self.embeddings = np.array([])
              return

          logging.info(f"Starting RAG document embedding for {len(self.documents_df)} documents...")
          embedded_docs_list = []
          for index, row in self.documents_df.iterrows():
-             text_to_embed = row.get('text')
              if not text_to_embed or not isinstance(text_to_embed, str):
-                 logging.warning(f"Skipping document at index {index} due to invalid text: {text_to_embed}")
                  continue
              try:
                  embedding_array = await asyncio.to_thread(self._embed_single_document_sync, text_to_embed)
-                 embedded_docs_list.append(embedding_array)
              except Exception as e:
-                 logging.error(f"Error embedding document text (index {index}) '{str(text_to_embed)[:50]}...': {e}", exc_info=False)

          if not embedded_docs_list:
              self.embeddings = np.array([])
-             logging.warning("No documents were successfully embedded for RAG.")
          else:
              try:
                  self.embeddings = np.vstack(embedded_docs_list)
-                 logging.info(f"Successfully embedded {len(embedded_docs_list)} documents for RAG. Embedding matrix shape: {self.embeddings.shape}")
              except ValueError as ve:
-                 logging.error(f"Error stacking embeddings: {ve}. Check individual embedding errors.", exc_info=True)
                  self.embeddings = np.array([])

      def _calculate_cosine_similarity(self, embeddings_matrix: np.ndarray, query_vector: np.ndarray) -> np.ndarray:
-         query_vector = query_vector.flatten()
          norm_matrix = np.linalg.norm(embeddings_matrix, axis=1, keepdims=True)
-         normalized_embeddings_matrix = embeddings_matrix / (norm_matrix + 1e-8)
-         norm_query = np.linalg.norm(query_vector)
-         normalized_query_vector = query_vector / (norm_query + 1e-8)
-         return np.dot(normalized_embeddings_matrix, normalized_query_vector)

      async def retrieve_relevant_info(self, query: str, top_k: int = 3, min_similarity: float = 0.3) -> str:
-         if self.embeddings is None or self.embeddings.size == 0 or self.documents_df.empty:
-             logging.debug("RAG system not initialized or no documents/embeddings available for retrieval.")
              return ""
          if not query or not isinstance(query, str):
              logging.debug("Empty or invalid query for RAG retrieval.")
              return ""
-         if not client:
              logging.error("GenAI client not available for RAG query embedding.")
              return ""

          try:
-             query_vector = await asyncio.to_thread(self._embed_single_document_sync, query)
-         except Exception as e:
-             logging.error(f"Error embedding query '{str(query)[:50]}...': {e}", exc_info=False)
-             return ""

-         if query_vector.ndim == 0 or query_vector.size == 0:
-             logging.warning(f"Query vector embedding failed or is empty for query: {str(query)[:50]}")
-             return ""
-
-         try:
              similarity_scores = self._calculate_cosine_similarity(self.embeddings, query_vector)
-             if similarity_scores.size == 0: return ""
-             relevant_indices_after_threshold = np.where(similarity_scores >= min_similarity)[0]
-             if len(relevant_indices_after_threshold) == 0:
-                 logging.debug(f"No documents met the minimum similarity threshold of {min_similarity} for query: {query[:50]}")
                  return ""
-             relevant_scores = similarity_scores[relevant_indices_after_threshold]
-             sorted_relevant_indices_local = np.argsort(relevant_scores)[::-1]
-             top_original_indices = relevant_indices_after_threshold[sorted_relevant_indices_local[:top_k]]
-             if len(top_original_indices) == 0: return ""
-             context_parts = [self.documents_df.iloc[i]['text'] for i in top_original_indices if 'text' in self.documents_df.columns]
              context = "\n\n---\n\n".join(context_parts)
-             logging.debug(f"Retrieved RAG context for query '{str(query)[:50]}...':\n{context[:200]}...")
              return context
          except Exception as e:
-             logging.error(f"Error during RAG retrieval (similarity/sorting): {e}", exc_info=True)
              return ""

-
  class EmployerBrandingAgent:
      def __init__(self,
-                  all_dataframes: dict,
-                  rag_documents_df: pd.DataFrame,
-                  llm_model_name: str,
-                  embedding_model_name: str,
-                  generation_config_dict: dict,
-                  safety_settings_list: list,
-                  force_sandbox: bool = False):
-         self.all_dataframes = {k: df.copy() for k, df in all_dataframes.items()}
-         self.schemas_representation = self._get_enhanced_schemas_representation()
-         self.chat_history = []
          self.llm_model_name = llm_model_name
-         self.generation_config_dict = generation_config_dict
-         # If an empty list is passed, it means no specific safety settings are enforced by the client.
-         self.safety_settings_list = safety_settings_list if safety_settings_list is not None else []
-         self.embedding_model_name = embedding_model_name
-         self.rag_system = AdvancedRAGSystem(rag_documents_df, self.embedding_model_name)
-         self.force_sandbox = force_sandbox
-         logging.info(f"EmployerBrandingAgent initialized. LLM: {self.llm_model_name}, Embedding: {self.embedding_model_name}. Safety settings count: {len(self.safety_settings_list)}")
-
-     def _get_date_range(self, df: pd.DataFrame) -> str:
-         for col in df.columns:
-             if pd.api.types.is_datetime64_any_dtype(df[col]):
-                 try:
-                     min_date = df[col].min()
-                     max_date = df[col].max()
-                     if pd.notna(min_date) and pd.notna(max_date):
-                         return f"{min_date.strftime('%Y-%m-%d')} to {max_date.strftime('%Y-%m-%d')}"
-                 except Exception: pass
-         return "N/A"
-
-     def _calculate_growth_rate(self, df: pd.DataFrame) -> str:
-         logging.debug("_calculate_growth_rate is a placeholder.")
-         return "Growth rate calculation not implemented."
-     def _analyze_engagement_trends(self, df: pd.DataFrame) -> str:
-         logging.debug("_analyze_engagement_trends is a placeholder.")
-         return "Engagement trend analysis not implemented."
-     def _analyze_demographics(self, df: pd.DataFrame) -> str:
-         logging.debug("_analyze_demographics is a placeholder.")
-         return "Demographic analysis not implemented."
-     def _analyze_post_performance(self, df: pd.DataFrame) -> str:
-         logging.debug("_analyze_post_performance is a placeholder.")
-         return "Post performance analysis not implemented."
-     def _extract_content_themes(self, df: pd.DataFrame) -> str:
-         logging.debug("_extract_content_themes is a placeholder.")
-         return "Content theme extraction not implemented."
-     def _find_optimal_times(self, df: pd.DataFrame) -> str:
-         logging.debug("_find_optimal_times is a placeholder.")
-         return "Optimal posting time analysis not implemented."
-
-     def _calculate_key_metrics(self, df: pd.DataFrame, df_type: str) -> dict:
-         metrics = {}
-         if 'follower' in df_type.lower():
-             metrics.update({'follower_growth_rate': self._calculate_growth_rate(df), 'engagement_trends': self._analyze_engagement_trends(df), 'demographic_distribution': self._analyze_demographics(df)})
-         elif 'post' in df_type.lower():
-             metrics.update({'post_performance': self._analyze_post_performance(df), 'content_themes': self._extract_content_themes(df), 'optimal_posting_times': self._find_optimal_times(df)})
-         elif 'mention' in df_type.lower():
-             metrics['mention_volume_trend'] = "Mention volume trend not implemented."
-             metrics['mention_sentiment_overview'] = "Mention sentiment overview not implemented."
-         if not metrics:
-             logging.debug(f"No specific key metrics defined for df_type: {df_type}")
-             return {"info": "Standard metrics applicable."}
-         return metrics
-
-     def _calculate_data_freshness(self, df: pd.DataFrame) -> str:
-         for col in df.columns:
-             if pd.api.types.is_datetime64_any_dtype(df[col]):
-                 try:
-                     max_date = df[col].max()
-                     if pd.notna(max_date):
-                         days_diff = (datetime.now(max_date.tzinfo if max_date.tzinfo else None) - max_date).days
-                         return f"Data up to {max_date.strftime('%Y-%m-%d')} ({days_diff} days old)"
-                 except Exception: pass
-         return "Freshness N/A (no clear date column)"
-     def _check_data_consistency(self, df: pd.DataFrame) -> str:
-         logging.debug("_check_data_consistency is a placeholder.")
-         return "Consistency checks not implemented."
-     def _identify_accuracy_issues(self, df: pd.DataFrame) -> str:
-         logging.debug("_identify_accuracy_issues is a placeholder.")
-         return "Accuracy issue identification not implemented."
-
-     def _assess_data_quality(self, df: pd.DataFrame) -> dict:
-         completeness = (1 - (df.isnull().sum().sum() / (len(df) * len(df.columns)))) if len(df) > 0 and len(df.columns) > 0 else 0
-         return {'completeness_score': f"{completeness:.2%}", 'freshness_info': self._calculate_data_freshness(df), 'consistency_check': self._check_data_consistency(df), 'accuracy_flags_summary': self._identify_accuracy_issues(df), 'sample_size_notes': f"{len(df)} records. {'Adequate for basic analysis.' if len(df) >= 100 else 'Limited sample size; insights may be indicative.'}"}
-
-     def _identify_patterns(self, df: pd.DataFrame, key: str) -> str:
-         logging.debug(f"_identify_patterns for {key} is a placeholder.")
-         return "Pattern identification not implemented."
-
-     def _format_df_analysis(self, df_key: str, analysis: dict) -> str:
-         formatted_parts = [f"\n--- DataFrame: df_{df_key} ---", f" Shape: {analysis['shape']}", f" Date Range: {analysis['date_range']}", " Key Metrics:"]
-         for metric, value in analysis['key_metrics'].items(): formatted_parts.append(f" - {metric.replace('_', ' ').title()}: {value}")
-         formatted_parts.append(" Data Quality Assessment:")
-         for aspect, value in analysis['data_quality'].items(): formatted_parts.append(f" - {aspect.replace('_', ' ').title()}: {value}")
-         formatted_parts.append(f" Notable Patterns: {analysis['notable_patterns']}")
-         return "\n".join(formatted_parts)
-
-     def _get_enhanced_schemas_representation(self) -> str:
-         schema_descriptions = ["=== DETAILED LINKEDIN DATA OVERVIEW ==="]
-         if not self.all_dataframes:
-             schema_descriptions.append("No dataframes available for analysis.")
-             return "\n".join(schema_descriptions)
-         for key, df in self.all_dataframes.items():
-             if df.empty:
-                 schema_descriptions.append(f"\n--- DataFrame: df_{key} ---\nStatus: Empty. No analysis possible.")
-                 continue
-             analysis = {'shape': df.shape, 'date_range': self._get_date_range(df), 'key_metrics': self._calculate_key_metrics(df, key), 'data_quality': self._assess_data_quality(df), 'notable_patterns': self._identify_patterns(df, key)}
-             schema_descriptions.append(self._format_df_analysis(key, analysis))
-         return "\n".join(schema_descriptions)
-
-     def _extract_query_intent(self, query: str) -> str:
-         logging.debug("_extract_query_intent is a placeholder.")
-         if "compare" in query.lower() or "benchmark" in query.lower(): return "comparison"
-         if "trend" in query.lower(): return "trend_analysis"
-         return "general"
-
-     async def _get_business_context(self, intent: str) -> str:
-         logging.debug("_get_business_context is a placeholder.")
-         if intent == "comparison": return "Company is focused on outperforming competitors in tech hiring."
-         return "Company aims to improve overall employer brand perception."
-
-     async def _get_industry_benchmarks(self, intent: str) -> str:
-         logging.debug("_get_industry_benchmarks is a placeholder.")
-         if intent == "trend_analysis": return "Typical follower growth in this sector is 5-10% MoM."
-         return "Average engagement rate for similar companies is 2-3%."
-
-     async def _enhance_rag_context(self, query: str, base_context: str) -> str:
-         intent = self._extract_query_intent(query)
-         business_context_val = await self._get_business_context(intent)
-         benchmarks_val = await self._get_industry_benchmarks(intent)
-         enhanced_context = f"""{base_context}
- --- ADDITIONAL CONTEXT FOR YOUR ANALYSIS ---
- Business Focus: {business_context_val}
- Relevant Benchmarks: {benchmarks_val}"""
-         return enhanced_context
-
-     async def _build_prompt_for_current_turn(self, raw_user_query: str) -> str:
-         prompt_parts = ["You are an expert Employer Branding Analyst...", "--- DETAILED DATA OVERVIEW ---", self.schemas_representation]
-         if self.rag_system.embeddings is not None and self.rag_system.embeddings.size > 0:
-             base_rag_context = await self.rag_system.retrieve_relevant_info(raw_user_query)
-             if base_rag_context:
-                 enhanced_rag_context = await self._enhance_rag_context(raw_user_query, base_rag_context)
-                 prompt_parts.extend(["--- RELEVANT CONTEXTUAL INFORMATION (from documents & business knowledge) ---", enhanced_rag_context])
-         prompt_parts.extend(["--- USER REQUEST ---", f"Based on all the information above, please respond to the following user query:\n{raw_user_query}"])
-         final_prompt = "\n".join(prompt_parts)
-         logging.debug(f"Built prompt for current turn (first 300 chars): {final_prompt[:300]}")
-         return final_prompt
-
-     async def _process_structured_query(self, prompt: str) -> dict:
-         logging.debug("_process_structured_query is a placeholder.")
-         return {"Key Findings": ["Placeholder finding 1"], "Performance Metrics": ["Placeholder metric"], "Actionable Recommendations": {"Immediate Actions (0-30 days)": ["Placeholder action"]}, "Risk Assessment": ["Placeholder risk"], "Success Metrics to Track": ["Placeholder KPI"]}
-
-     async def _generate_hr_insights(self, query: str, context: str) -> str:
-         insight_prompt = f"As an expert HR analytics consultant...\n{context}\nUser Query: {query}\nPlease provide insights in this structured format:\n## Key Findings\n- ...\n..."
-         if not client: return "Error: AI client not configured for generating HR insights."
-         api_call_contents = [{"role": "user", "parts": [{"text": insight_prompt}]}]

-         api_safety_settings_objects = []
-         # self.safety_settings_list is expected to be empty if no settings are desired
-         if types and hasattr(types, 'SafetySetting') and self.safety_settings_list:
-             for ss_item in self.safety_settings_list:
                  try:
-                     api_safety_settings_objects.append(types.SafetySetting(category=ss_item['category'], threshold=ss_item['threshold']))
-                 except Exception as e_ss:
-                     logging.warning(f"Could not create SafetySetting object from {ss_item} for HR insights: {e_ss}. Using raw item.")
-                     api_safety_settings_objects.append(ss_item)
-         elif self.safety_settings_list: # Fallback if types.SafetySetting not available but list is not empty
-             api_safety_settings_objects = self.safety_settings_list

-         api_generation_config_obj = None
-         if types and hasattr(types, 'GenerateContentConfig'):
-             api_generation_config_obj = types.GenerateContentConfig(**self.generation_config_dict, safety_settings=api_safety_settings_objects)
-         else: # Fallback if types.GenerateContentConfig is not available
-             api_generation_config_obj = {**self.generation_config_dict, "safety_settings": api_safety_settings_objects}

          try:
-             response = await asyncio.to_thread(client.models.generate_content, model=self.llm_model_name, contents=api_call_contents, config=api_generation_config_obj)
-             if not response.candidates: return "HR insights generation failed: No response from AI."
-             return response.text.strip()
          except Exception as e:
-             logging.error(f"Error generating HR insights: {e}", exc_info=True)
-             return f"Error generating HR insights: {str(e)}"

      def _validate_query(self, query: str) -> bool:
-         if not query or len(query.strip()) < 3: logging.warning(f"Query too short: '{query}'"); return False
-         hr_keywords = ['employee', 'talent', 'hiring', 'culture', 'brand', 'engagement', 'retention', 'follower', 'post', 'mention', 'linkedin']
-         if not any(keyword in query.lower() for keyword in hr_keywords): logging.warning(f"Query may not be HR/LinkedIn-relevant: {query[:50]}")
          return True

-     def _get_query_help_message(self) -> str:
-         return "I'm here to help with Employer Branding analysis... Example: 'What are the top industries of my followers?'"
-
-     async def _check_system_readiness(self) -> dict:
-         logging.debug("_check_system_readiness is a placeholder.")
-         if not client: return {'ready': False, 'reason': 'AI Client not initialized.'}
-         if self.rag_system.embeddings is None: logging.warning("RAG embeddings not yet initialized.")
-         return {'ready': True, 'reason': 'System appears ready.'}
-
-     def _get_fallback_response(self, query: str) -> str:
-         logging.error(f"Executing fallback response for query: {query[:50]}")
-         return "I encountered an unexpected issue..."
-
-     async def _core_query_processing(self, raw_user_query_this_turn: str) -> str:
-         augmented_current_user_prompt_text = await self._build_prompt_for_current_turn(raw_user_query_this_turn)
-         api_call_contents = list(self.chat_history)
-         api_call_contents.append({"role": "user", "parts": [{"text": augmented_current_user_prompt_text}]})
-         logging.debug(f"Sending to GenAI. Total turns in content: {len(api_call_contents)}")
-
-         api_safety_settings_objects = []
-         # self.safety_settings_list is expected to be empty if no settings are desired
-         if types and hasattr(types, 'SafetySetting') and self.safety_settings_list:
-             for ss_item in self.safety_settings_list:
-                 try:
-                     api_safety_settings_objects.append(types.SafetySetting(category=ss_item['category'], threshold=ss_item['threshold']))
-                 except Exception as e_ss_core:
-                     logging.warning(f"Could not create SafetySetting object from {ss_item} in core: {e_ss_core}. Using raw item.")
-                     api_safety_settings_objects.append(ss_item)
-         elif self.safety_settings_list: # Fallback if types.SafetySetting not available but list is not empty
-             api_safety_settings_objects = self.safety_settings_list
-
-         api_generation_config_obj = None
-         if types and hasattr(types, 'GenerateContentConfig'):
-             api_generation_config_obj = types.GenerateContentConfig(**self.generation_config_dict, safety_settings=api_safety_settings_objects)
-         else: # Fallback if types.GenerateContentConfig is not available
-             logging.error("GenerateContentConfig type not available. API call might fail.")
-             api_generation_config_obj = {**self.generation_config_dict, "safety_settings": api_safety_settings_objects}
-
-         response = await asyncio.to_thread(client.models.generate_content, model=self.llm_model_name, contents=api_call_contents, config=api_generation_config_obj)
-         if not response.candidates:
-             block_reason = response.prompt_feedback.block_reason if response.prompt_feedback else "Unknown"
-             block_message = response.prompt_feedback.block_reason_message if response.prompt_feedback else ""
-             error_message = f"The AI's response was blocked. Reason: {block_reason}." + (f" Details: {block_message}" if block_message else "")
-             return error_message
-         return response.text.strip()
-
-     async def _process_query_with_timeout(self, raw_user_query_this_turn: str, timeout_seconds: int = 60) -> str:
-         try: return await asyncio.wait_for(self._core_query_processing(raw_user_query_this_turn), timeout=timeout_seconds)
-         except asyncio.TimeoutError:
-             logging.error(f"Query processing timed out for {timeout_seconds} seconds...")
-             return "I'm sorry, but your request took too long..."
-
-     async def process_query(self, raw_user_query_this_turn: str) -> str:
-         if not client: return "Error: The AI Agent is not available..."
-         if not self._validate_query(raw_user_query_this_turn): return self._get_query_help_message()
-         readiness_check = await self._check_system_readiness()
-         if not readiness_check['ready']: return f"System not ready: {readiness_check['reason']}"
-         max_retries = 2
-         for attempt in range(max_retries + 1):
-             try:
-                 response_text = await self._process_query_with_timeout(raw_user_query_this_turn)
-                 if "The AI's response was blocked" in response_text: return response_text
-                 logging.info(f"Successfully received AI response (attempt {attempt+1}): {response_text[:100]}")
-                 return response_text
-             except RateLimitError as rle:
-                 if attempt == max_retries: return "The AI service is currently busy..."
-                 await asyncio.sleep(2 ** attempt)
-             except ValidationError as ve: return f"Query validation failed: {str(ve)}"
-             except Exception as e:
-                 if attempt == max_retries: return self._get_fallback_response(raw_user_query_this_turn)
-                 return self._get_fallback_response(raw_user_query_this_turn)

-     def _classify_query_type(self, query: str) -> str:
-         query_lower = query.lower()
-         if any(word in query_lower for word in ['trend', 'growth', 'change', 'time']): return 'trend_analysis'
-         elif any(word in query_lower for word in ['compare', 'benchmark', 'versus']): return 'comparative_analysis'
-         elif any(word in query_lower for word in ['predict', 'forecast', 'future']): return 'predictive_analysis'
-         elif any(word in query_lower for word in ['recommend', 'suggest', 'improve', 'advice', 'help me with']): return 'recommendation_engine'
-         elif any(word in query_lower for word in ['what is', 'explain', 'define']): return 'definition_explanation'
-         else: return 'general_inquiry'

      def clear_chat_history(self):
          self.chat_history = []
-         logging.info("EmployerBrandingAgent chat history cleared by request.")
-
- def get_all_schemas_representation(all_dataframes: dict) -> str:
-     if not all_dataframes: return "No DataFrames are currently loaded."
-     schema_descriptions = ["DataFrames currently available in the application state:"]
-     for key, df in all_dataframes.items():
-         df_name = f"df_{key}"
-         columns = ", ".join(df.columns)
-         shape = df.shape
-         if df.empty:
-             schema = f"\n--- DataFrame: {df_name} ---\nStatus: Empty\nShape: {shape}\nColumns: {columns}"
-         else:
-             try:
-                 sample_data_str = df.head(2).to_markdown(index=False)
-             except ImportError:
-                 logging.warning("`tabulate` library not found. Falling back to `to_string()` for schema representation.")
-                 sample_data_str = df.head(2).to_string(index=False)
-             except Exception as e:
-                 logging.error(f"Error formatting DataFrame sample for {df_name} with to_markdown: {e}. Falling back to to_string().")
-                 sample_data_str = df.head(2).to_string(index=False)
-
-             schema = (f"\n--- DataFrame: {df_name} ---\nShape: {shape}\nColumns: {columns}\n\n<details><summary>Sample Data (first 2 rows of {df_name}):</summary>\n\n```text\n{sample_data_str}\n```\n\n</details>")
-         schema_descriptions.append(schema)
-     return "\n".join(schema_descriptions)
-
-
- async def test_rag_retrieval_accuracy():
-     logging.info("Running RAG retrieval accuracy test...")
-     test_embedding_model = GEMINI_EMBEDDING_MODEL_NAME
-     if not client:
-         logging.error("Cannot run RAG test: GenAI client not initialized.")
-         return
-     test_docs_data = {
-         'text': [
-             'Strategies for improving employee engagement include regular feedback and recognition programs.',
-             'Effective talent acquisition requires a strong employer brand and a streamlined hiring process.',
-             'Company culture is a key driver of employee satisfaction and retention.',
-             'Analyzing LinkedIn post performance can reveal insights into content effectiveness.'
-         ]
-     }
-     test_docs_df = pd.DataFrame(test_docs_data)
-     rag_system = AdvancedRAGSystem(test_docs_df, test_embedding_model)
-     logging.info("Test RAG: Initializing embeddings...")
-     await rag_system.initialize_embeddings()
-     if rag_system.embeddings is None or rag_system.embeddings.size == 0:
-         logging.error("Test RAG: Embeddings not initialized properly.")
-         return
-     test_queries = {
-         "employee engagement": "engagement",
-         "hiring talent": "acquisition",
-         "company culture": "culture",
-         "linkedin posts": "linkedin"
-     }
-     all_tests_passed = True
-     for query, keyword in test_queries.items():
-         logging.info(f"Test RAG: Retrieving for query: '{query}'")
-         result = await rag_system.retrieve_relevant_info(query, top_k=1, min_similarity=0.1)
-         if result and keyword.lower() in result.lower():
-             logging.info(f"Test RAG: PASSED for query '{query}'. Found relevant doc.")
-         else:
-             logging.error(f"Test RAG: FAILED for query '{query}'. Expected keyword '{keyword}', got: {result[:100]}...")
-             all_tests_passed = False
-     if all_tests_passed: logging.info("All RAG retrieval accuracy tests passed.")
-     else: logging.error("Some RAG retrieval accuracy tests FAILED.")

  import asyncio
  import logging
  import numpy as np
+ import textwrap  # Used by _build_system_prompt below
+ from datetime import datetime  # Not used, but kept from original
+ from typing import Dict, List, Optional, Union, Any
+ import traceback
+
+ # Configure logging
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(module)s - %(message)s')

  try:
      from google import genai
+     from google.genai import types  # Assuming this provides necessary types like SafetySetting, HarmCategory etc.
+     # If GenerationConfig or EmbedContentConfig are from a different submodule, adjust imports.
+     # For google-generativeai, GenerationConfig is often passed as a dict or genai.types.GenerationConfig,
+     # and EmbedContentConfig might be implicit or part of task_type.
+     GENAI_AVAILABLE = True
+     logging.info("Google Generative AI library imported successfully.")
  except ImportError:
+     logging.warning("Google Generative AI library not found. Please install it: pip install google-generativeai")
+     GENAI_AVAILABLE = False
+
+     # Dummy classes for graceful degradation (simplified)
+     class genai:
+         Client = None
+         # If using google-generativeai, these would be different:
+         # GenerativeModel = None
+         # def configure(*args, **kwargs): pass
+         # def embed_content(*args, **kwargs): return {}
+
+     class types:  # Placeholder for types used in the original code
+         EmbedContentConfig = None  # Placeholder
+         GenerationConfig = None  # Placeholder
          SafetySetting = None
+         Candidate = type('Candidate', (), {'FinishReason': type('FinishReason', (), {'STOP': 'STOP'})})  # Dummy for FinishReason
+
+         class HarmCategory:
              HARM_CATEGORY_UNSPECIFIED = "HARM_CATEGORY_UNSPECIFIED"
              HARM_CATEGORY_HATE_SPEECH = "HARM_CATEGORY_HATE_SPEECH"
              HARM_CATEGORY_HARASSMENT = "HARM_CATEGORY_HARASSMENT"
              HARM_CATEGORY_SEXUALLY_EXPLICIT = "HARM_CATEGORY_SEXUALLY_EXPLICIT"
              HARM_CATEGORY_DANGEROUS_CONTENT = "HARM_CATEGORY_DANGEROUS_CONTENT"
+
+         class HarmBlockThreshold:
              BLOCK_NONE = "BLOCK_NONE"
              BLOCK_LOW_AND_ABOVE = "BLOCK_LOW_AND_ABOVE"
              BLOCK_MEDIUM_AND_ABOVE = "BLOCK_MEDIUM_AND_ABOVE"
+             BLOCK_ONLY_HIGH = "BLOCK_ONLY_HIGH"
+
+         class generation_types:  # Dummy for BlockedPromptException
+             BlockedPromptException = type('BlockedPromptException', (Exception,), {})

  # --- Custom Exceptions ---
  class ValidationError(Exception):
      """Custom validation error for agent inputs"""
      pass

+ class RateLimitError(Exception):  # Not used, but kept
      """Placeholder for rate limit errors."""
      pass

+ class AgentNotReadyError(Exception):
+     """Agent is not properly initialized"""
+     pass
+
  # --- Configuration Constants ---
  GEMINI_API_KEY = os.getenv('GEMINI_API_KEY', "")
+ LLM_MODEL_NAME = "gemini-1.5-flash-latest"  # For google-generativeai, the model name is used directly.
+                                             # For client.models.generate_content, it might need "models/gemini-1.5-flash-latest".
+ GEMINI_EMBEDDING_MODEL_NAME = "text-embedding-004"  # Similarly, might need "models/text-embedding-004"

  GENERATION_CONFIG_PARAMS = {
      "temperature": 0.7,
      "top_p": 0.95,
      "top_k": 40,
+     "max_output_tokens": 8192,  # Ensure this is supported
      "candidate_count": 1,
  }

+ DEFAULT_SAFETY_SETTINGS = []  # User can populate this with {'category': HarmCategory.HARM_CATEGORY_X, 'threshold': HarmBlockThreshold.BLOCK_Y}

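# ---------------------------------------------------------------------------
# Editor's example (not part of the commit): one way DEFAULT_SAFETY_SETTINGS
# could be populated, using the dict shape that the agent's __init__ below
# expects ('category' and 'threshold' keys). With the real google.genai
# library, the corresponding enum members would be used instead of the
# string-valued dummies defined above.
#
# DEFAULT_SAFETY_SETTINGS = [
#     {
#         "category": HarmCategory.HARM_CATEGORY_HATE_SPEECH,
#         "threshold": HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
#     },
#     {
#         "category": HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
#         "threshold": HarmBlockThreshold.BLOCK_ONLY_HIGH,
#     },
# ]
# ---------------------------------------------------------------------------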
+ # Default RAG documents
+ DEFAULT_RAG_DOCUMENTS = pd.DataFrame({
      'text': [
          "Employer branding focuses on how an organization is perceived as an employer by potential and current employees.",
          "Key metrics for employer branding include employee engagement, candidate quality, and retention rates.",
          "LinkedIn is a crucial platform for showcasing company culture and attracting talent.",
+         "Analyzing follower demographics and post engagement helps refine employer branding strategies.",
+         "Content strategy should align with company values to attract the right talent.",
+         "Employee advocacy programs can significantly boost employer brand reach and authenticity."
      ]
  })

  # --- Client Initialization ---
  client = None
+ if GEMINI_API_KEY and GENAI_AVAILABLE:
      try:
+         # This call is SDK-specific. If using google-generativeai, this would be genai.configure(api_key=...)
          client = genai.Client(api_key=GEMINI_API_KEY)
+         logging.info("Google GenAI client initialized successfully (using genai.Client).")
      except Exception as e:
+         logging.error(f"Failed to initialize Google GenAI client (using genai.Client): {e}")
+         client = None
  else:
+     if not GEMINI_API_KEY:
+         logging.warning("GEMINI_API_KEY environment variable not set.")
+     if not GENAI_AVAILABLE:
+         logging.warning("Google GenAI library not available.")
+
+ # --- Utility functions to get DataFrame schema representations ---
+ def get_df_schema_representation(df: pd.DataFrame, df_name: str) -> str:
+     """Generates a string representation of a DataFrame's schema and a small sample."""
+     if not isinstance(df, pd.DataFrame):
+         return f"Item '{df_name}' is not a DataFrame.\n"
+     if df.empty:
+         return f"DataFrame '{df_name}': Empty\n"
+
+     schema_parts = [f"DataFrame '{df_name}':"]
+     schema_parts.append(f"  Shape: {df.shape}")
+     schema_parts.append("  Columns:")
+     for col in df.columns:
+         col_type = str(df[col].dtype)
+         null_count = df[col].isnull().sum()
+         unique_count = df[col].nunique()
+         schema_parts.append(f"    - {col} (Type: {col_type}, Nulls: {null_count}/{len(df)}, Uniques: {unique_count})")
+
+     if not df.empty:
+         schema_parts.append("  Sample Data (first 2 rows):")
+         try:
+             sample_df_str = df.head(2).to_string(index=True, max_colwidth=50)  # Show index for context
+             indented_sample_df = "\n".join(["    " + line for line in sample_df_str.split('\n')])
+             schema_parts.append(indented_sample_df)
+         except Exception as e:
+             schema_parts.append(f"    Could not generate sample data: {e}")
+
+     return "\n".join(schema_parts) + "\n"

+ def get_all_schemas_representation(dataframes: Dict[str, pd.DataFrame]) -> str:
+     """Generates a string representation of all DataFrame schemas."""
+     if not dataframes:
+         return "No DataFrames available to the agent."
+
+     full_representation = ["=== Available DataFrame Schemas for Analysis ==="]
+     for name, df_instance in dataframes.items():
+         full_representation.append(get_df_schema_representation(df_instance, name))
+     return "\n".join(full_representation)

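# ---------------------------------------------------------------------------
# Editor's example (not part of the commit): a minimal, self-contained check
# of the schema helpers above, using a toy DataFrame. The df name and columns
# here are hypothetical.
#
# import pandas as pd
#
# df_posts = pd.DataFrame({
#     "post_id": [1, 2],
#     "impressions": [1200, 950],
#     "clicks": [34, None],  # one null to exercise the Nulls count
# })
# print(get_all_schemas_representation({"posts": df_posts}))
# # Prints the "=== Available DataFrame Schemas for Analysis ===" header,
# # then the shape, per-column dtype/null/unique stats, and a 2-row sample.
# ---------------------------------------------------------------------------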
  class AdvancedRAGSystem:
      def __init__(self, documents_df: pd.DataFrame, embedding_model_name: str):
+         self.documents_df = documents_df.copy() if not documents_df.empty else DEFAULT_RAG_DOCUMENTS.copy()
+         # Ensure the 'text' column exists
+         if 'text' not in self.documents_df.columns and not self.documents_df.empty:
+             logging.warning("'text' column not found in RAG documents. RAG might not work.")
+             # Create an empty text column if the df is not empty but lacks it, to prevent errors later
+             self.documents_df['text'] = ""
+
+         self.embedding_model_name = embedding_model_name  # e.g., "models/text-embedding-004" or just "text-embedding-004"
+         self.embeddings: Optional[np.ndarray] = None
+         self.is_initialized = False
+         logging.info(f"AdvancedRAGSystem initialized with {len(self.documents_df)} documents. Model: {self.embedding_model_name}")
+
+     def _embed_single_document_sync(self, text: str) -> Optional[np.ndarray]:
          if not client:
              raise ConnectionError("GenAI client not initialized for RAG embedding.")
          if not text or not isinstance(text, str):
+             logging.warning("Cannot embed empty or non-string text for RAG.")
+             return None
+
+         try:
+             # Standard google-generativeai call:
+             # embedding_response = genai.embed_content(
+             #     model=self.embedding_model_name,  # e.g., "models/text-embedding-004"
+             #     content=text,
+             #     task_type="RETRIEVAL_DOCUMENT"  # or "SEMANTIC_SIMILARITY"
+             # )
+             # return np.array(embedding_response['embedding'])
+
+             # Using the provided client.models.embed_content structure:
+             # This might require specific types for config.
+             embed_config_payload = None
+             if GENAI_AVAILABLE and hasattr(types, 'EmbedContentConfig'):  # Assuming types.EmbedContentConfig is relevant
+                 # The task_type for EmbedContentConfig might differ, e.g., "SEMANTIC_SIMILARITY" or "RETRIEVAL_DOCUMENT"
+                 embed_config_payload = types.EmbedContentConfig(task_type="RETRIEVAL_DOCUMENT")
+
+             response = client.models.embed_content(  # This is the user's original call structure
+                 model=f"models/{self.embedding_model_name}" if not self.embedding_model_name.startswith("models/") else self.embedding_model_name,
+                 contents=text,  # Original used 'contents'; genai.embed_content uses 'content'
+                 config=embed_config_payload  # Original passed 'config'
+             )
+
+             # Adapt response parsing based on actual client.models.embed_content behavior
+             if hasattr(response, 'embeddings') and isinstance(response.embeddings, list) and len(response.embeddings) > 0:
+                 # This structure `response.embeddings[0]` seems specific.
+                 # Standard genai.embed_content returns a dict `{'embedding': [values]}`
+                 return np.array(response.embeddings[0])
+             elif hasattr(response, 'embedding'):  # Common for genai.embed_content
+                 return np.array(response.embedding)
+             else:
+                 logging.error(f"Unexpected embedding response format: {response}")
+                 return None
+         except Exception as e:
+             logging.error(f"Error in _embed_single_document_sync for text '{text[:50]}...': {e}", exc_info=True)
+             raise

      async def initialize_embeddings(self):
+         if self.documents_df.empty or 'text' not in self.documents_df.columns:
+             logging.warning("RAG documents DataFrame is empty or lacks a 'text' column. Skipping embedding.")
              self.embeddings = np.array([])
+             self.is_initialized = True  # Initialized, but with no embeddings
              return
+
+         if not client and not (GENAI_AVAILABLE and os.getenv('GEMINI_API_KEY')):  # Check if standard genai can be used
              logging.error("GenAI client not available for RAG embedding initialization.")
              self.embeddings = np.array([])
              return

          logging.info(f"Starting RAG document embedding for {len(self.documents_df)} documents...")
          embedded_docs_list = []
+
          for index, row in self.documents_df.iterrows():
+             text_to_embed = row.get('text', '')
              if not text_to_embed or not isinstance(text_to_embed, str):
+                 logging.warning(f"Skipping RAG document at index {index} due to invalid/empty text.")
                  continue
+
              try:
+                 # Use asyncio.to_thread for the synchronous embedding call
                  embedding_array = await asyncio.to_thread(self._embed_single_document_sync, text_to_embed)
+                 if embedding_array is not None and embedding_array.size > 0:
+                     embedded_docs_list.append(embedding_array)
+                 else:
+                     logging.warning(f"Empty or failed embedding for RAG document at index {index}.")
              except Exception as e:
+                 logging.error(f"Error embedding RAG document at index {index}: {e}")
+                 continue  # Continue with other documents

          if not embedded_docs_list:
              self.embeddings = np.array([])
+             logging.warning("No RAG documents were successfully embedded.")
          else:
              try:
+                 # Ensure all embeddings have the same shape before vstack
+                 first_shape = embedded_docs_list[0].shape
+                 if not all(emb.shape == first_shape for emb in embedded_docs_list):
+                     logging.error("Inconsistent embedding shapes found. Cannot stack for RAG.")
+                     # Attempt to filter out malformed embeddings if possible, or fail.
+                     # For now, we'll fail stacking if shapes are inconsistent.
+                     self.embeddings = np.array([])
+                     return  # Exit if shapes are inconsistent
+
                  self.embeddings = np.vstack(embedded_docs_list)
+                 logging.info(f"Successfully embedded {len(embedded_docs_list)} RAG documents. Embeddings shape: {self.embeddings.shape}")
              except ValueError as ve:
+                 logging.error(f"Error stacking embeddings (likely due to inconsistent shapes): {ve}")
                  self.embeddings = np.array([])
+
+         self.is_initialized = True

      def _calculate_cosine_similarity(self, embeddings_matrix: np.ndarray, query_vector: np.ndarray) -> np.ndarray:
+         if embeddings_matrix.ndim == 1:  # Handle case of a single document embedding
+             embeddings_matrix = embeddings_matrix.reshape(1, -1)
+         if query_vector.ndim == 1:
+             query_vector = query_vector.reshape(1, -1)
+
+         if embeddings_matrix.size == 0 or query_vector.size == 0:
+             return np.array([])
+
+         # Normalize embeddings_matrix rows
          norm_matrix = np.linalg.norm(embeddings_matrix, axis=1, keepdims=True)
+         # Add a small epsilon to avoid division by zero for zero vectors
+         normalized_embeddings_matrix = np.divide(embeddings_matrix, norm_matrix + 1e-8, where=norm_matrix != 0)
+
+         # Normalize query_vector
+         norm_query = np.linalg.norm(query_vector, axis=1, keepdims=True)
+         normalized_query_vector = np.divide(query_vector, norm_query + 1e-8, where=norm_query != 0)
+
+         # Calculate the dot product
+         return np.dot(normalized_embeddings_matrix, normalized_query_vector.T).flatten()

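# ---------------------------------------------------------------------------
# Editor's example (not part of the commit): the normalize-then-dot-product
# logic above, checked on toy vectors. With unit-normalized rows, the dot
# product equals the cosine similarity.
#
# import numpy as np
# import pandas as pd
#
# docs = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
# query = np.array([1.0, 0.0])
# rag = AdvancedRAGSystem(pd.DataFrame({"text": ["a", "b", "c"]}), "text-embedding-004")
# print(rag._calculate_cosine_similarity(docs, query))
# # Expected: [1.0, 0.0, ~0.7071] (cos 0°, cos 90°, cos 45°)
# ---------------------------------------------------------------------------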
      async def retrieve_relevant_info(self, query: str, top_k: int = 3, min_similarity: float = 0.3) -> str:
+         if not self.is_initialized:
+             logging.debug("RAG system not initialized. Cannot retrieve info.")
+             return ""
+         if self.embeddings is None or self.embeddings.size == 0:
+             logging.debug("RAG embeddings not available. Cannot retrieve info.")
              return ""
          if not query or not isinstance(query, str):
              logging.debug("Empty or invalid query for RAG retrieval.")
              return ""
+
+         if not client and not (GENAI_AVAILABLE and os.getenv('GEMINI_API_KEY')):
              logging.error("GenAI client not available for RAG query embedding.")
              return ""

          try:
+             query_vector = await asyncio.to_thread(self._embed_single_document_sync, query)  # Embed the query
+             if query_vector is None or query_vector.size == 0:
+                 logging.warning("Query vector embedding failed or is empty for RAG.")
+                 return ""

              similarity_scores = self._calculate_cosine_similarity(self.embeddings, query_vector)
+             if similarity_scores.size == 0:
                  return ""
+
+             relevant_indices = np.where(similarity_scores >= min_similarity)[0]
+             if len(relevant_indices) == 0:
+                 logging.debug(f"No RAG documents met the minimum similarity threshold of {min_similarity} for query: '{query[:50]}...'")
+                 return ""
+
+             # Get scores for relevant documents and sort
+             relevant_scores = similarity_scores[relevant_indices]
+             # Argsort returns indices that sort relevant_scores; apply them to relevant_indices
+             sorted_relevant_indices_of_original = relevant_indices[np.argsort(relevant_scores)[::-1]]
+
+             top_indices = sorted_relevant_indices_of_original[:top_k]
+
+             context_parts = []
+             if 'text' in self.documents_df.columns:
+                 for i in top_indices:
+                     if 0 <= i < len(self.documents_df):
+                         context_parts.append(self.documents_df.iloc[i]['text'])
+
              context = "\n\n---\n\n".join(context_parts)
+             logging.debug(f"Retrieved RAG context with {len(context_parts)} documents for query: '{query[:50]}...'")
              return context
+
          except Exception as e:
+             logging.error(f"Error during RAG retrieval for query '{query[:50]}...': {e}", exc_info=True)
              return ""

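# ---------------------------------------------------------------------------
# Editor's sketch (not part of the commit): the intended RAG flow, assuming
# GEMINI_API_KEY is set and the client initialized. Embedding calls hit the
# API, so this is illustrative rather than a unit test.
#
# async def demo_rag():
#     rag = AdvancedRAGSystem(DEFAULT_RAG_DOCUMENTS.copy(), GEMINI_EMBEDDING_MODEL_NAME)
#     await rag.initialize_embeddings()
#     context = await rag.retrieve_relevant_info("How do I measure employer branding?", top_k=2)
#     print(context or "(no document met the similarity threshold)")
#
# # asyncio.run(demo_rag())
# ---------------------------------------------------------------------------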
  class EmployerBrandingAgent:
      def __init__(self,
+                  all_dataframes: Optional[Dict[str, pd.DataFrame]] = None,
+                  rag_documents_df: Optional[pd.DataFrame] = None,
+                  llm_model_name: str = LLM_MODEL_NAME,
+                  embedding_model_name: str = GEMINI_EMBEDDING_MODEL_NAME,
+                  generation_config_dict: Optional[Dict] = None,
+                  safety_settings_list: Optional[List] = None):  # safety_settings_list expects a list of dicts or SafetySetting objects
+
+         self.all_dataframes = {k: v.copy() for k, v in (all_dataframes or {}).items()}  # Deep copy
+
+         _rag_docs_df = rag_documents_df if rag_documents_df is not None else DEFAULT_RAG_DOCUMENTS.copy()
+         self.rag_system = AdvancedRAGSystem(_rag_docs_df, embedding_model_name)
+
          self.llm_model_name = llm_model_name
+         self.generation_config_dict = generation_config_dict or GENERATION_CONFIG_PARAMS

+         # Ensure safety settings are in the correct format if using google-generativeai directly
+         self.safety_settings_list = []
+         if safety_settings_list and GENAI_AVAILABLE and hasattr(types, 'SafetySetting'):
+             for ss_dict in safety_settings_list:
                  try:
+                     # Assuming ss_dict is like {'category': HarmCategory.XYZ, 'threshold': HarmBlockThreshold.ABC}
+                     self.safety_settings_list.append(types.SafetySetting(category=ss_dict['category'], threshold=ss_dict['threshold']))
+                 except Exception as e:
+                     logging.warning(f"Could not convert safety setting dict to SafetySetting object: {ss_dict} - {e}")
+         elif safety_settings_list:  # If not using types.SafetySetting, pass as is (e.g. for client.models)
+             self.safety_settings_list = safety_settings_list
+
+         self.chat_history: List[Dict[str, str]] = []  # Stores {"role": "user/model", "content": "..."}
+         self.is_ready = False
+         self.llm_model_instance = None  # For google-generativeai
+
+         if GENAI_AVAILABLE and client is None and GEMINI_API_KEY:  # If genai.Client failed but standard genai can be used
+             try:
+                 genai.configure(api_key=GEMINI_API_KEY)
+                 self.llm_model_instance = genai.GenerativeModel(self.llm_model_name)
+                 logging.info(f"Initialized GenerativeModel '{self.llm_model_name}' via google-generativeai.")
+             except Exception as e:
+                 logging.error(f"Failed to initialize google-generativeai.GenerativeModel: {e}")
+
+         logging.info(f"EmployerBrandingAgent initialized. LLM: {self.llm_model_name}. RAG docs: {len(self.rag_system.documents_df)}. DataFrames: {list(self.all_dataframes.keys())}")

+     async def initialize(self) -> bool:
+         """Initializes asynchronous components of the agent, primarily RAG embeddings."""
          try:
+             if not client and not self.llm_model_instance:  # Check if any LLM access is configured
+                 logging.error("Cannot initialize agent: GenAI client (genai.Client or google.generativeai) not available/configured.")
+                 return False
+
+             await self.rag_system.initialize_embeddings()  # This sets rag_system.is_initialized
+             self.is_ready = self.rag_system.is_initialized  # Agent is ready if RAG is (even if RAG has no docs)
+             logging.info(f"EmployerBrandingAgent.initialize completed. RAG initialized: {self.rag_system.is_initialized}. Agent ready: {self.is_ready}")
+             return True
          except Exception as e:
+             logging.error(f"Error during EmployerBrandingAgent.initialize: {e}", exc_info=True)
+             self.is_ready = False
+             return False
+
+     def _get_dataframes_summary(self) -> str:
+         return get_all_schemas_representation(self.all_dataframes)
+
+     def _build_system_prompt(self) -> str:
+         # This prompt provides overall guidance to the LLM.
+         return textwrap.dedent("""
+             You are an expert Employer Branding Analyst AI. Your primary function is to analyze the LinkedIn data provided (follower statistics, post performance, mentions) and offer actionable insights, data-driven recommendations, and, if requested, Python Pandas code snippets for further analysis.
+
+             When providing insights or recommendations:
+             - Be specific and base your conclusions on the data summaries and context provided.
+             - Structure responses clearly, perhaps using bullet points for key findings or actions.
+             - Focus on practical advice that can help improve employer branding efforts.
+
+             When asked to generate Pandas code:
+             - Assume the data is available in pandas DataFrames named exactly as in the 'Available DataFrame Schemas' section (e.g., `df_follower_stats`, `df_posts`).
+             - Generate executable Python code using pandas.
+             - Ensure the code is directly relevant to the user's query and the available data.
+             - Briefly explain what the code does.
+             - If a query implies data not present in the schemas, state that and do not attempt to fabricate code for it.
+             - Do not generate code that modifies DataFrames in place unless explicitly asked. Prefer returning new DataFrames or Series.
+             - Handle potential errors in data (e.g., missing values if relevant to the operation) gracefully if simple to do so.
+             - Output the code in a single, copy-pasteable block.
+
+             Always refer to the provided DataFrame schemas to understand available columns and data types. Do not hallucinate columns or data.
+             If a query is ambiguous or requires data not present, ask for clarification or state the limitation.
+         """).strip()

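# ---------------------------------------------------------------------------
# Editor's example (not part of the commit): the kind of snippet the system
# prompt above asks the LLM to produce. The DataFrame and column names here
# are hypothetical stand-ins for whatever the schema section actually lists.
#
# # Top 5 posts by engagement rate, returned as a new DataFrame:
# top_posts = (
#     df_posts.assign(engagement_rate=df_posts["clicks"] / df_posts["impressions"])
#     .sort_values("engagement_rate", ascending=False)
#     .head(5)
# )
# ---------------------------------------------------------------------------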
427
+ async def _generate_response(self, current_user_query: str) -> str:
428
+ """
429
+ Generates a response from the LLM based on the current query, system prompts,
430
+ data summaries, RAG context, and the agent's chat history.
431
+ Assumes self.chat_history is already populated by app.py and includes the current_user_query as the last entry.
432
+ """
433
+ if not self.is_ready:
434
+ return "Agent is not ready. Please initialize."
435
+ if not client and not self.llm_model_instance:
436
+ return "Error: AI service is not available. Check API configuration."
437
+
438
+ try:
439
+ system_prompt_text = self._build_system_prompt()
440
+ data_summary_text = self._get_dataframes_summary()
441
+ rag_context_text = await self.rag_system.retrieve_relevant_info(current_user_query, top_k=2, min_similarity=0.25) # Fine-tuned RAG params
442
+
443
+ # Construct the messages for the LLM API call
444
+ # The history (self.chat_history) is set by app.py and includes the current user query.
445
+ llm_messages = []
446
+
447
+ # 1. System-level instructions and context (as a first "user" turn)
448
+ initial_context_prompt = (
449
+ f"{system_prompt_text}\n\n"
450
+ f"## Available Data Overview:\n{data_summary_text}\n\n"
451
+ f"## Relevant Background Information (if any):\n{rag_context_text if rag_context_text else 'No specific background information retrieved for this query.'}\n\n"
452
+ f"Given this context, please respond to the user queries that follow in the chat history."
453
+ )
454
+ llm_messages.append({"role": "user", "parts": [{"text": initial_context_prompt}]})
455
+ # 2. Priming assistant message
456
+ llm_messages.append({"role": "model", "parts": [{"text": "Understood. I have reviewed the context and data overview. I am ready to assist with your Employer Branding analysis based on our conversation."}]})
457
+
458
+ # 3. Append the actual conversation history (already includes the current user query)
459
+ for entry in self.chat_history: # self.chat_history is set by app.py
460
+ llm_messages.append({"role": entry["role"], "parts": [{"text": entry["content"]}]})
461
+
462
+ # Prepare generation config and safety settings for the API
463
+ gen_config_payload = self.generation_config_dict
464
+ safety_settings_payload = self.safety_settings_list # Already formatted if types.SafetySetting used
465
+
466
+ if GENAI_AVAILABLE and hasattr(types, 'GenerationConfig') and not isinstance(self.generation_config_dict, types.GenerationConfig):
467
+ try:
468
+ gen_config_payload = types.GenerationConfig(**self.generation_config_dict)
469
+ except Exception as e:
470
+ logging.warning(f"Could not convert gen_config_dict to types.GenerationConfig: {e}")
471
+
472
+
+
+            # --- Make the API call ---
+            response_text = ""
+            if self.llm_model_instance:  # Standard google-generativeai usage
+                logging.debug(f"Using google-generativeai GenerativeModel.generate_content_async for LLM call. History length: {len(llm_messages)}")
+                api_response = await self.llm_model_instance.generate_content_async(
+                    contents=llm_messages,
+                    generation_config=gen_config_payload,
+                    safety_settings=safety_settings_payload
+                )
+                response_text = api_response.text  # Simplification: assumes a single-part text response
+            elif client:  # Original client.models.generate_content structure
+                logging.debug(f"Using client.models.generate_content for LLM call. History length: {len(llm_messages)}")
+                # client.models.generate_content is synchronous, so run it in a worker thread
+                model_path = f"models/{self.llm_model_name}" if not self.llm_model_name.startswith("models/") else self.llm_model_name
+                api_response = await asyncio.to_thread(
+                    client.models.generate_content,
+                    model=model_path,
+                    contents=llm_messages,
+                    generation_config=gen_config_payload,  # Ensure this is the correct type for client.models
+                    safety_settings=safety_settings_payload  # Ensure this is the correct type
+                )
+                # Parse the response from the client.models structure
+                if api_response.candidates and api_response.candidates[0].content and api_response.candidates[0].content.parts:
+                    response_text_parts = [part.text for part in api_response.candidates[0].content.parts if hasattr(part, 'text')]
+                    response_text = "".join(response_text_parts).strip()
+                else:  # Handle blocked or empty responses from client.models
+                    if hasattr(api_response, 'prompt_feedback') and api_response.prompt_feedback and api_response.prompt_feedback.block_reason:
+                        logging.warning(f"Prompt blocked by client.models: {api_response.prompt_feedback.block_reason}")
+                        return f"I'm sorry, your request was blocked. Reason: {api_response.prompt_feedback.block_reason_message or api_response.prompt_feedback.block_reason}"
+                    if api_response.candidates:
+                        finish_reason = api_response.candidates[0].finish_reason
+                        # Compare by name: types.Candidate.FinishReason.STOP may not exist on the
+                        # imported "types" module (or its dummy fallback), so avoid referencing it directly.
+                        if getattr(finish_reason, 'name', str(finish_reason)) != "STOP":
+                            logging.warning(f"Content generation stopped by client.models due to: {finish_reason}. Safety: {getattr(api_response.candidates[0], 'safety_ratings', 'N/A')}")
+                            return f"I couldn't complete the response. Reason: {finish_reason}. Please try rephrasing."
+                    return "I apologize, but I couldn't generate a response from client.models."
+
+            else:
+                raise ConnectionError("No valid LLM client or model instance available.")
+
+            return response_text.strip()
+
+        except Exception as e:
+            # BlockedPromptException is specific to google-generativeai; match it by name because
+            # the "types" imported here is google.genai.types (or a dummy fallback) and may not expose it.
+            if type(e).__name__ == "BlockedPromptException":
+                logging.error(f"BlockedPromptException from LLM: {e}", exc_info=True)
+                return f"I'm sorry, your request was blocked by the safety filter. Please rephrase your query. Details: {e}"
+            logging.error(f"Error in _generate_response: {e}", exc_info=True)
+            return f"I encountered an error while processing your request: {type(e).__name__} - {str(e)}"
+
 
     def _validate_query(self, query: str) -> bool:
+        if not query or not isinstance(query, str) or len(query.strip()) < 3:
+            logging.warning(f"Invalid query: too short or not a string. Query: '{query}'")
+            return False
+        if len(query) > 3000:  # Generous upper bound to keep prompts a manageable size
+            logging.warning(f"Invalid query: too long. Length: {len(query)}")
+            return False
         return True
 
+    async def process_query(self, user_query: str) -> str:
+        """
+        Processes the user's query.
+
+        Relies on self.chat_history being set externally (by app.py) to hold the full
+        conversation context, including the current user_query as the last "user" message.
+        This method calls _generate_response to obtain the AI's reply; it does NOT modify
+        self.chat_history itself, since app.py owns the history via Gradio state.
+        """
+        if not self._validate_query(user_query):
+            # user_query is the text from the Gradio input; it is also the last entry in self.chat_history
+            return "Please provide a valid query (3 to 3000 characters)."
+
+        if not self.is_ready:
+            logging.warning("process_query called but agent is not ready. Attempting re-initialization.")
+            # Fallback only; ideally initialize() is called once and its success confirmed upfront.
+            init_success = await self.initialize()
+            if not init_success:
+                return "The agent is not properly initialized and could not be started. Please check configuration and logs."
+
+        # self.chat_history (set by app.py) should already contain user_query as the last message;
+        # it is passed here primarily so _generate_response can retrieve RAG context for this turn.
+        response_text = await self._generate_response(user_query)
+        return response_text
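+
+    # Illustrative host-side flow (names `history` and `user_msg` are hypothetical app.py variables):
+    #   agent.chat_history = history + [{"role": "user", "content": user_msg}]
+    #   reply = await agent.process_query(user_msg)
+    #   history.append({"role": "model", "content": reply})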
 
 
+    def update_dataframes(self, new_dataframes: Dict[str, pd.DataFrame]):
+        """Updates the agent's DataFrames. Does not automatically re-initialize RAG or the LLM."""
+        self.all_dataframes = {k: v.copy() for k, v in new_dataframes.items()}  # Copy each DataFrame so caller state is not shared
+        logging.info(f"Agent DataFrames updated. Keys: {list(self.all_dataframes.keys())}")
+        # Note: if RAG documents were derived from these DataFrames, the RAG system would need
+        # re-initialization; for now it uses a static document set.
+
     def clear_chat_history(self):
+        """Clears the agent's internal chat history. app.py should also clear the Gradio state."""
         self.chat_history = []
+        logging.info("EmployerBrandingAgent internal chat history cleared.")
+
+    def get_status(self) -> Dict[str, Any]:
+        return {
+            "is_ready": self.is_ready,
+            "has_api_key": bool(GEMINI_API_KEY),
+            "genai_available": GENAI_AVAILABLE,
+            "client_type": "genai.Client" if client else ("google-generativeai" if self.llm_model_instance else "None"),
+            "rag_initialized": self.rag_system.is_initialized,
+            "num_dataframes": len(self.all_dataframes),
+            "dataframe_keys": list(self.all_dataframes.keys()),
+            "num_rag_documents": len(self.rag_system.documents_df) if self.rag_system.documents_df is not None else 0,
+            "llm_model_name": self.llm_model_name,
+            "embedding_model_name": self.embedding_model_name,
+        }
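+    # Example return value (illustrative):
+    #   {"is_ready": True, "has_api_key": True, "genai_available": True,
+    #    "client_type": "genai.Client", "rag_initialized": True, "num_dataframes": 2,
+    #    "dataframe_keys": ["followers", "posts"], "num_rag_documents": 4,
+    #    "llm_model_name": "gemini-1.5-flash-latest", "embedding_model_name": "text-embedding-004"}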
+
+
+# --- Helpers for Gradio integration (app.py instantiates the agent directly, but these are convenient) ---
+def create_agent_instance(dataframes: Optional[Dict[str, pd.DataFrame]] = None,
+                          rag_docs: Optional[pd.DataFrame] = None) -> EmployerBrandingAgent:
+    logging.info("Creating new EmployerBrandingAgent instance via helper function.")
+    return EmployerBrandingAgent(all_dataframes=dataframes, rag_documents_df=rag_docs)
+
+
+async def initialize_agent_async(agent: EmployerBrandingAgent) -> bool:
+    logging.info("Initializing agent via async helper function.")
+    return await agent.initialize()
+
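+
+# Example wiring (illustrative; `build_dataframes()` stands in for however app.py loads data):
+#   agent = create_agent_instance(dataframes=build_dataframes())
+#   ready = asyncio.run(initialize_agent_async(agent))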
+
+if __name__ == "__main__":
+    async def test_agent_logic():
+        print("--- Testing Employer Branding Agent ---")
+        if not GEMINI_API_KEY:
+            print("GEMINI_API_KEY not set. Skipping live API tests.")
+            return
+
+        sample_dfs = {
+            "followers": pd.DataFrame({'date': pd.to_datetime(['2023-01-01']), 'count': [100]}),
+            "posts": pd.DataFrame({'title': ['My first post'], 'likes': [10]})
+        }
+
+        # Test RAG document loading with a custom document set
+        custom_rag = pd.DataFrame({'text': ["Custom RAG context about LinkedIn engagement."]})
+
+        agent = EmployerBrandingAgent(
+            all_dataframes=sample_dfs,
+            rag_documents_df=custom_rag,
+            llm_model_name=LLM_MODEL_NAME,
+            embedding_model_name=GEMINI_EMBEDDING_MODEL_NAME
+        )
+        print("Agent Status (pre-init):", agent.get_status())
+
+        init_success = await agent.initialize()
+        print(f"Agent Initialization Success: {init_success}")
+        print("Agent Status (post-init):", agent.get_status())
+
+        if not init_success:
+            print("Agent initialization failed. Cannot proceed with query test.")
+            return
+
+        # Simulate app.py setting the history before each turn
+        test_query1 = "What are the key columns in my followers data?"
+        agent.chat_history = [{"role": "user", "content": test_query1}]  # app.py would do this
+
+        print(f"\nProcessing Query 1: '{test_query1}'")
+        response1 = await agent.process_query(user_query=test_query1)  # Current query is passed for RAG retrieval
+        print(f"Agent Response 1:\n{response1}")
+
+        # Simulate app.py updating the history for the next turn
+        agent.chat_history.append({"role": "model", "content": response1})
+
+        test_query2 = "Generate pandas code to get the total follower count."
+        agent.chat_history.append({"role": "user", "content": test_query2})
+
+        print(f"\nProcessing Query 2: '{test_query2}'")
+        response2 = await agent.process_query(user_query=test_query2)
+        print(f"Agent Response 2:\n{response2}")
+
+        agent.chat_history.append({"role": "model", "content": response2})
+        print("\nFinal Agent Chat History (internal):")
+        for item in agent.chat_history:
+            print(f"- {item['role']}: {item['content'][:100]}...")
+
+        print("\n--- Test Complete ---")
+
+    asyncio.run(test_agent_logic())