GuglielmoTor commited on
Commit
09757d6
·
verified ·
1 Parent(s): 25e22b8

Update eb_agent_module.py

Browse files
Files changed (1) hide show
  1. eb_agent_module.py +4 -216
eb_agent_module.py CHANGED
@@ -80,18 +80,6 @@ GENERATION_CONFIG_PARAMS = {
80
 
81
  DEFAULT_SAFETY_SETTINGS = []
82
 
83
- # Default RAG documents
84
- DEFAULT_RAG_DOCUMENTS = pd.DataFrame({
85
- 'text': [
86
- "Employer branding focuses on how an organization is perceived as an employer by potential and current employees.",
87
- "Key metrics for employer branding include employee engagement, candidate quality, and retention rates.",
88
- "LinkedIn is a crucial platform for showcasing company culture and attracting talent.",
89
- "Analyzing follower demographics and post engagement helps refine employer branding strategies.",
90
- "Content strategy should align with company values to attract the right talent.",
91
- "Employee advocacy programs can significantly boost employer brand reach and authenticity."
92
- ]
93
- })
94
-
95
  # --- Client Initialization ---
96
  client = None
97
  if GEMINI_API_KEY and GENAI_AVAILABLE:
@@ -163,189 +151,11 @@ def get_all_schemas_representation(dataframes: Dict[str, pd.DataFrame]) -> str:
163
  full_representation.append(get_df_schema_representation(df_instance, name))
164
  return "\n".join(full_representation)
165
 
166
- class AdvancedRAGSystem:
167
- def __init__(self, documents_df: pd.DataFrame, embedding_model_name: str):
168
- self.documents_df = documents_df.copy() if not documents_df.empty else DEFAULT_RAG_DOCUMENTS.copy()
169
- # Ensure 'text' column exists
170
- if 'text' not in self.documents_df.columns and not self.documents_df.empty:
171
- logging.warning("'text' column not found in RAG documents. RAG might not work.")
172
- self.documents_df['text'] = ""
173
-
174
- self.embedding_model_name = embedding_model_name
175
- self.embeddings: Optional[np.ndarray] = None
176
- self.is_initialized = False
177
- logging.info(f"AdvancedRAGSystem initialized with {len(self.documents_df)} documents. Model: {self.embedding_model_name}")
178
-
179
- def _embed_single_document_sync(self, text: str) -> Optional[np.ndarray]:
180
- if not client:
181
- raise ConnectionError("GenAI client not initialized for RAG embedding.")
182
- if not text or not isinstance(text, str):
183
- logging.warning("Cannot embed empty or non-string text for RAG.")
184
- return None
185
-
186
- try:
187
- embed_config_payload = None
188
- if GENAI_AVAILABLE and hasattr(types, 'EmbedContentConfig'):
189
- embed_config_payload = types.EmbedContentConfig(task_type="RETRIEVAL_DOCUMENT")
190
-
191
- response = client.models.embed_content(
192
- model=f"models/{self.embedding_model_name}" if not self.embedding_model_name.startswith("models/") else self.embedding_model_name,
193
- contents=text,
194
- config=embed_config_payload
195
- )
196
-
197
- # Fix: Handle ContentEmbedding objects properly
198
- if hasattr(response, 'embeddings') and isinstance(response.embeddings, list) and len(response.embeddings) > 0:
199
- embedding_obj = response.embeddings[0]
200
-
201
- # Extract values from ContentEmbedding object
202
- if hasattr(embedding_obj, 'values'):
203
- embedding_values = embedding_obj.values
204
- elif hasattr(embedding_obj, 'embedding'):
205
- embedding_values = embedding_obj.embedding
206
- elif isinstance(embedding_obj, (list, tuple)):
207
- embedding_values = embedding_obj
208
- else:
209
- # Try to convert to list/array if it's a different object type
210
- try:
211
- embedding_values = list(embedding_obj)
212
- except:
213
- logging.error(f"Cannot extract embedding values from object type: {type(embedding_obj)}")
214
- return None
215
-
216
- return np.array(embedding_values, dtype=np.float32)
217
- else:
218
- logging.error(f"Unexpected response structure")
219
- return None
220
-
221
- except Exception as e:
222
- logging.error(f"Error in _embed_single_document_sync for text '{text[:50]}...': {e}", exc_info=True)
223
- raise
224
-
225
- async def initialize_embeddings(self):
226
- if self.documents_df.empty or 'text' not in self.documents_df.columns:
227
- logging.warning("RAG documents DataFrame is empty or lacks 'text' column. Skipping embedding.")
228
- self.embeddings = np.array([])
229
- self.is_initialized = True
230
- return
231
-
232
- if not client and not (GENAI_AVAILABLE and os.getenv('GEMINI_API_KEY')):
233
- logging.error("GenAI client not available for RAG embedding initialization.")
234
- self.embeddings = np.array([])
235
- return
236
-
237
- logging.info(f"Starting RAG document embedding for {len(self.documents_df)} documents...")
238
- embedded_docs_list = []
239
-
240
- for index, row in self.documents_df.iterrows():
241
- text_to_embed = row.get('text', '')
242
- if not text_to_embed or not isinstance(text_to_embed, str):
243
- logging.warning(f"Skipping RAG document at index {index} due to invalid/empty text.")
244
- continue
245
-
246
- try:
247
- embedding_array = await asyncio.to_thread(self._embed_single_document_sync, text_to_embed)
248
- if embedding_array is not None and embedding_array.size > 0:
249
- embedded_docs_list.append(embedding_array)
250
- else:
251
- logging.warning(f"Empty or failed embedding for RAG document at index {index}.")
252
- except Exception as e:
253
- logging.error(f"Error embedding RAG document at index {index}: {e}")
254
- continue
255
 
256
- if not embedded_docs_list:
257
- self.embeddings = np.array([])
258
- logging.warning("No RAG documents were successfully embedded.")
259
- else:
260
- try:
261
- first_shape = embedded_docs_list[0].shape
262
- if not all(emb.shape == first_shape for emb in embedded_docs_list):
263
- logging.error("Inconsistent embedding shapes found. Cannot stack for RAG.")
264
- self.embeddings = np.array([])
265
- return
266
-
267
- self.embeddings = np.vstack(embedded_docs_list)
268
- logging.info(f"Successfully embedded {len(embedded_docs_list)} RAG documents. Embeddings shape: {self.embeddings.shape}")
269
- except ValueError as ve:
270
- logging.error(f"Error stacking embeddings: {ve}")
271
- self.embeddings = np.array([])
272
-
273
- self.is_initialized = True
274
-
275
- def _calculate_cosine_similarity(self, embeddings_matrix: np.ndarray, query_vector: np.ndarray) -> np.ndarray:
276
- # Ensure inputs are numpy arrays with proper dtype
277
- embeddings_matrix = np.asarray(embeddings_matrix, dtype=np.float32)
278
- query_vector = np.asarray(query_vector, dtype=np.float32)
279
-
280
- if embeddings_matrix.ndim == 1:
281
- embeddings_matrix = embeddings_matrix.reshape(1, -1)
282
- if query_vector.ndim == 1:
283
- query_vector = query_vector.reshape(1, -1)
284
-
285
- if embeddings_matrix.size == 0 or query_vector.size == 0:
286
- return np.array([])
287
-
288
- norm_matrix = np.linalg.norm(embeddings_matrix, axis=1, keepdims=True)
289
- normalized_embeddings_matrix = np.divide(embeddings_matrix, norm_matrix + 1e-8, where=norm_matrix!=0)
290
-
291
- norm_query = np.linalg.norm(query_vector, axis=1, keepdims=True)
292
- normalized_query_vector = np.divide(query_vector, norm_query + 1e-8, where=norm_query!=0)
293
-
294
- return np.dot(normalized_embeddings_matrix, normalized_query_vector.T).flatten()
295
-
296
- async def retrieve_relevant_info(self, query: str, top_k: int = 3, min_similarity: float = 0.3) -> str:
297
- if not self.is_initialized:
298
- logging.debug("RAG system not initialized. Cannot retrieve info.")
299
- return ""
300
- if self.embeddings is None or self.embeddings.size == 0:
301
- logging.debug("RAG embeddings not available. Cannot retrieve info.")
302
- return ""
303
- if not query or not isinstance(query, str):
304
- logging.debug("Empty or invalid query for RAG retrieval.")
305
- return ""
306
-
307
- if not client and not (GENAI_AVAILABLE and os.getenv('GEMINI_API_KEY')):
308
- logging.error("GenAI client not available for RAG query embedding.")
309
- return ""
310
-
311
- try:
312
- query_vector = await asyncio.to_thread(self._embed_single_document_sync, query)
313
- if query_vector is None or query_vector.size == 0:
314
- logging.warning("Query vector embedding failed or is empty for RAG.")
315
- return ""
316
-
317
- similarity_scores = self._calculate_cosine_similarity(self.embeddings, query_vector)
318
- if similarity_scores.size == 0:
319
- return ""
320
-
321
- relevant_indices = np.where(similarity_scores >= min_similarity)[0]
322
- if len(relevant_indices) == 0:
323
- logging.debug(f"No RAG documents met minimum similarity threshold of {min_similarity} for query: '{query[:50]}...'")
324
- return ""
325
-
326
- relevant_scores = similarity_scores[relevant_indices]
327
- sorted_relevant_indices_of_original = relevant_indices[np.argsort(relevant_scores)[::-1]]
328
-
329
- top_indices = sorted_relevant_indices_of_original[:top_k]
330
-
331
- context_parts = []
332
- if 'text' in self.documents_df.columns:
333
- for i in top_indices:
334
- if 0 <= i < len(self.documents_df):
335
- context_parts.append(self.documents_df.iloc[i]['text'])
336
-
337
- context = "\n\n---\n\n".join(context_parts)
338
- logging.debug(f"Retrieved RAG context with {len(context_parts)} documents for query: '{query[:50]}...'")
339
- return context
340
-
341
- except Exception as e:
342
- logging.error(f"Error during RAG retrieval for query '{query[:50]}...': {e}", exc_info=True)
343
- return ""
344
 
345
  class EmployerBrandingAgent:
346
  def __init__(self,
347
  all_dataframes: Optional[Dict[str, pd.DataFrame]] = None,
348
- rag_documents_df: Optional[pd.DataFrame] = None,
349
  llm_model_name: str = LLM_MODEL_NAME,
350
  embedding_model_name: str = GEMINI_EMBEDDING_MODEL_NAME,
351
  generation_config_dict: Optional[Dict] = None,
@@ -353,9 +163,6 @@ class EmployerBrandingAgent:
353
 
354
  self.all_dataframes = {k: v.copy() for k, v in (all_dataframes or {}).items()}
355
 
356
- _rag_docs_df = rag_documents_df if rag_documents_df is not None else DEFAULT_RAG_DOCUMENTS.copy()
357
- self.rag_system = AdvancedRAGSystem(_rag_docs_df, embedding_model_name)
358
-
359
  self.llm_model_name = llm_model_name
360
  self.generation_config_dict = generation_config_dict or GENERATION_CONFIG_PARAMS
361
  self.safety_settings_list = safety_settings_list or DEFAULT_SAFETY_SETTINGS
@@ -371,8 +178,6 @@ class EmployerBrandingAgent:
371
  self.pandas_agent = None
372
  self._initialize_pandas_agent()
373
 
374
- logging.info(f"EnhancedEmployerBrandingAgent initialized. LLM: {self.llm_model_name}. RAG docs: {len(self.rag_system.documents_df)}. DataFrames: {list(self.all_dataframes.keys())}")
375
-
376
  def _initialize_pandas_agent(self):
377
  """Initialize PandasAI with enhanced configuration for chart generation"""
378
  if not self.all_dataframes or not GEMINI_API_KEY:
@@ -475,8 +280,7 @@ class EmployerBrandingAgent:
475
  if not client: # Fix: Remove reference to llm_model_instance
476
  logging.error("Cannot initialize agent: GenAI client not available/configured.")
477
  return False
478
-
479
- await self.rag_system.initialize_embeddings()
480
 
481
  # Verify PandasAI agent is ready
482
  pandas_ready = self.pandas_agent is not None
@@ -485,8 +289,6 @@ class EmployerBrandingAgent:
485
  self._initialize_pandas_agent()
486
  pandas_ready = self.pandas_agent is not None
487
 
488
- self.is_ready = self.rag_system.is_initialized and pandas_ready
489
- logging.info(f"EnhancedEmployerBrandingAgent.initialize completed. RAG: {self.rag_system.is_initialized}, PandasAI: {pandas_ready}, Agent ready: {self.is_ready}")
490
  return self.is_ready
491
 
492
  except Exception as e:
@@ -814,7 +616,6 @@ class EmployerBrandingAgent:
814
  try:
815
  system_prompt = self._build_system_prompt()
816
  data_summary = self._get_dataframes_summary()
817
- rag_context = await self.rag_system.retrieve_relevant_info(query, top_k=2, min_similarity=0.25)
818
 
819
  # Build enhanced prompt based on query type and available results
820
  if query_type == "data" and pandas_result:
@@ -828,7 +629,6 @@ class EmployerBrandingAgent:
828
  {pandas_result}
829
 
830
  ## Additional Knowledge Context:
831
- {rag_context if rag_context else 'No additional context retrieved.'}
832
 
833
  ## User Query:
834
  {query}
@@ -844,7 +644,6 @@ class EmployerBrandingAgent:
844
  {data_summary}
845
 
846
  ## Knowledge Base Context:
847
- {rag_context if rag_context else 'No specific background information retrieved.'}
848
 
849
  ## User Query:
850
  {query}
@@ -1005,14 +804,7 @@ class EmployerBrandingAgent:
1005
 
1006
  # Reinitialize PandasAI agent with new data
1007
  self._initialize_pandas_agent()
1008
-
1009
- # Note: RAG system uses static documents and doesn't need reinitialization
1010
-
1011
- def update_rag_documents(self, new_rag_df: pd.DataFrame):
1012
- """Updates RAG documents and reinitializes embeddings"""
1013
- self.rag_system.documents_df = new_rag_df.copy()
1014
- logging.info(f"RAG documents updated. Count: {len(new_rag_df)}")
1015
- # Note: Embeddings will need to be reinitialized - call initialize() after this
1016
 
1017
  def clear_chat_history(self):
1018
  """Clears the agent's internal chat history"""
@@ -1026,13 +818,10 @@ class EmployerBrandingAgent:
1026
  "has_api_key": bool(GEMINI_API_KEY),
1027
  "genai_available": GENAI_AVAILABLE,
1028
  "client_type": "genai.Client" if client else "None", # Fix: Remove reference to llm_model_instance
1029
- "rag_initialized": self.rag_system.is_initialized,
1030
  "pandas_agent_ready": self.pandas_agent is not None,
1031
  "num_dataframes": len(self.all_dataframes),
1032
  "dataframe_keys": list(self.all_dataframes.keys()),
1033
- "num_rag_documents": len(self.rag_system.documents_df) if self.rag_system.documents_df is not None else 0,
1034
  "llm_model_name": self.llm_model_name,
1035
- "embedding_model_name": self.rag_system.embedding_model_name,
1036
  "chat_history_length": len(self.chat_history),
1037
  "charts_save_path_pandasai": pai.config.save_charts_path if pai.config.llm else "PandasAI not configured"
1038
  }
@@ -1074,11 +863,10 @@ class EmployerBrandingAgent:
1074
  return suggestions[:10] # Limit to top 10 suggestions
1075
 
1076
  # --- Helper Functions for External Integration ---
1077
- def create_agent_instance(dataframes: Optional[Dict[str, pd.DataFrame]] = None,
1078
- rag_docs: Optional[pd.DataFrame] = None) -> EmployerBrandingAgent:
1079
  """Factory function to create a new agent instance"""
1080
  logging.info("Creating new EnhancedEmployerBrandingAgent instance via helper function.")
1081
- return EmployerBrandingAgent(all_dataframes=dataframes, rag_documents_df=rag_docs)
1082
 
1083
  async def initialize_agent_async(agent: EmployerBrandingAgent) -> bool:
1084
  """Async helper to initialize an agent instance"""
 
80
 
81
  DEFAULT_SAFETY_SETTINGS = []
82
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  # --- Client Initialization ---
84
  client = None
85
  if GEMINI_API_KEY and GENAI_AVAILABLE:
 
151
  full_representation.append(get_df_schema_representation(df_instance, name))
152
  return "\n".join(full_representation)
153
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
 
156
  class EmployerBrandingAgent:
157
  def __init__(self,
158
  all_dataframes: Optional[Dict[str, pd.DataFrame]] = None,
 
159
  llm_model_name: str = LLM_MODEL_NAME,
160
  embedding_model_name: str = GEMINI_EMBEDDING_MODEL_NAME,
161
  generation_config_dict: Optional[Dict] = None,
 
163
 
164
  self.all_dataframes = {k: v.copy() for k, v in (all_dataframes or {}).items()}
165
 
 
 
 
166
  self.llm_model_name = llm_model_name
167
  self.generation_config_dict = generation_config_dict or GENERATION_CONFIG_PARAMS
168
  self.safety_settings_list = safety_settings_list or DEFAULT_SAFETY_SETTINGS
 
178
  self.pandas_agent = None
179
  self._initialize_pandas_agent()
180
 
 
 
181
  def _initialize_pandas_agent(self):
182
  """Initialize PandasAI with enhanced configuration for chart generation"""
183
  if not self.all_dataframes or not GEMINI_API_KEY:
 
280
  if not client: # Fix: Remove reference to llm_model_instance
281
  logging.error("Cannot initialize agent: GenAI client not available/configured.")
282
  return False
283
+
 
284
 
285
  # Verify PandasAI agent is ready
286
  pandas_ready = self.pandas_agent is not None
 
289
  self._initialize_pandas_agent()
290
  pandas_ready = self.pandas_agent is not None
291
 
 
 
292
  return self.is_ready
293
 
294
  except Exception as e:
 
616
  try:
617
  system_prompt = self._build_system_prompt()
618
  data_summary = self._get_dataframes_summary()
 
619
 
620
  # Build enhanced prompt based on query type and available results
621
  if query_type == "data" and pandas_result:
 
629
  {pandas_result}
630
 
631
  ## Additional Knowledge Context:
 
632
 
633
  ## User Query:
634
  {query}
 
644
  {data_summary}
645
 
646
  ## Knowledge Base Context:
 
647
 
648
  ## User Query:
649
  {query}
 
804
 
805
  # Reinitialize PandasAI agent with new data
806
  self._initialize_pandas_agent()
807
+
 
 
 
 
 
 
 
808
 
809
  def clear_chat_history(self):
810
  """Clears the agent's internal chat history"""
 
818
  "has_api_key": bool(GEMINI_API_KEY),
819
  "genai_available": GENAI_AVAILABLE,
820
  "client_type": "genai.Client" if client else "None", # Fix: Remove reference to llm_model_instance
 
821
  "pandas_agent_ready": self.pandas_agent is not None,
822
  "num_dataframes": len(self.all_dataframes),
823
  "dataframe_keys": list(self.all_dataframes.keys()),
 
824
  "llm_model_name": self.llm_model_name,
 
825
  "chat_history_length": len(self.chat_history),
826
  "charts_save_path_pandasai": pai.config.save_charts_path if pai.config.llm else "PandasAI not configured"
827
  }
 
863
  return suggestions[:10] # Limit to top 10 suggestions
864
 
865
  # --- Helper Functions for External Integration ---
866
+ def create_agent_instance(dataframes: Optional[Dict[str, pd.DataFrame]] = None) -> EmployerBrandingAgent:
 
867
  """Factory function to create a new agent instance"""
868
  logging.info("Creating new EnhancedEmployerBrandingAgent instance via helper function.")
869
+ return EmployerBrandingAgent(all_dataframes=dataframes)
870
 
871
  async def initialize_agent_async(agent: EmployerBrandingAgent) -> bool:
872
  """Async helper to initialize an agent instance"""