Update eb_agent_module.py

eb_agent_module.py CHANGED (+479 −362)

--- eb_agent_module.py (before)
@@ -4,40 +4,37 @@ import os
 import asyncio
 import logging
 import numpy as np
-import textwrap
-from datetime import datetime
 from typing import Dict, List, Optional, Union, Any
 import traceback

 # Configure logging
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(module)s - %(message)s')

 try:
     from google import genai
-    from google.genai import types
     from google.genai import errors
-    # If GenerationConfig or EmbedContentConfig are from a different submodule, adjust imports.
-    # For google-generativeai, GenerationConfig is often passed as a dict or genai.types.GenerationConfig
-    # and EmbedContentConfig might be implicit or part of task_type.
     GENAI_AVAILABLE = True
     logging.info("Google Generative AI library imported successfully.")
 except ImportError:
     logging.warning("Google Generative AI library not found. Please install it: pip install google-generativeai")
     GENAI_AVAILABLE = False

-    # Dummy classes for graceful degradation
     class genai:
         Client = None

-    class types:  # Placeholder for types used in the original code
-        EmbedContentConfig = None  # Placeholder
-        GenerationConfig = None  # Placeholder
         SafetySetting = None
-        Candidate = type('Candidate', (), {'FinishReason': type('FinishReason', (), {'STOP': 'STOP'})})

     class HarmCategory:
         HARM_CATEGORY_UNSPECIFIED = "HARM_CATEGORY_UNSPECIFIED"
@@ -51,17 +48,13 @@ except ImportError:
         BLOCK_LOW_AND_ABOVE = "BLOCK_LOW_AND_ABOVE"
         BLOCK_MEDIUM_AND_ABOVE = "BLOCK_MEDIUM_AND_ABOVE"
         BLOCK_ONLY_HIGH = "BLOCK_ONLY_HIGH"

-    class generation_types:  # Dummy for BlockedPromptException
-        BlockedPromptException = type('BlockedPromptException', (Exception,), {})

 # --- Custom Exceptions ---
 class ValidationError(Exception):
     """Custom validation error for agent inputs"""
     pass

-class RateLimitError(Exception):
     """Placeholder for rate limit errors."""
     pass
@@ -71,19 +64,18 @@ class AgentNotReadyError(Exception):

 # --- Configuration Constants ---
 GEMINI_API_KEY = os.getenv('GEMINI_API_KEY', "")
-LLM_MODEL_NAME = "gemini-2.5-flash-preview-05-20"

-GEMINI_EMBEDDING_MODEL_NAME = "gemini-embedding-exp-03-07"  # Similarly, might need "models/text-embedding-004"

 GENERATION_CONFIG_PARAMS = {
     "temperature": 0.7,
     "top_p": 0.95,
     "top_k": 40,
-    "max_output_tokens": 8192,
     "candidate_count": 1,
 }

-DEFAULT_SAFETY_SETTINGS = []

 # Default RAG documents
 DEFAULT_RAG_DOCUMENTS = pd.DataFrame({
@@ -101,11 +93,10 @@ DEFAULT_RAG_DOCUMENTS = pd.DataFrame({
 client = None
 if GEMINI_API_KEY and GENAI_AVAILABLE:
     try:
-        # This is specific. If using google-generativeai, this would be genai.configure(api_key=...)
         client = genai.Client(api_key=GEMINI_API_KEY)
-        logging.info("Google GenAI client initialized successfully.")
     except Exception as e:
-        logging.error(f"Failed to initialize Google GenAI client: {e}")
         client = None
 else:
     if not GEMINI_API_KEY:
@@ -113,6 +104,19 @@ else:
     if not GENAI_AVAILABLE:
         logging.warning("Google GenAI library not available.")

 # --- Utility function to get DataFrame schema representation ---
 def get_df_schema_representation(df: pd.DataFrame, df_name: str) -> str:
@@ -175,10 +179,9 @@ class AdvancedRAGSystem:
         # Ensure 'text' column exists
         if 'text' not in self.documents_df.columns and not self.documents_df.empty:
             logging.warning("'text' column not found in RAG documents. RAG might not work.")
-            # Create an empty text column if df is not empty but lacks it, to prevent errors later
             self.documents_df['text'] = ""

-        self.embedding_model_name = embedding_model_name
         self.embeddings: Optional[np.ndarray] = None
         self.is_initialized = False
         logging.info(f"AdvancedRAGSystem initialized with {len(self.documents_df)} documents. Model: {self.embedding_model_name}")
@@ -194,34 +197,18 @@ class AdvancedRAGSystem:
         embed_config_payload = None
         if GENAI_AVAILABLE and hasattr(types, 'EmbedContentConfig'):
             embed_config_payload = types.EmbedContentConfig(task_type="RETRIEVAL_DOCUMENT")
-
         response = client.models.embed_content(
             model=f"models/{self.embedding_model_name}" if not self.embedding_model_name.startswith("models/") else self.embedding_model_name,
-            contents=text,
             config=embed_config_payload
         )

-        logging.info(f"Embedding response type: {type(response)}")
-        logging.info(f"Response attributes: {dir(response)}")

         if hasattr(response, 'embeddings') and isinstance(response.embeddings, list) and len(response.embeddings) > 0:
-            embedding_obj = response.embeddings[0]
-            # Try to extract values
-            if hasattr(embedding_obj, 'values'):
-                logging.info(f"Found 'values' attribute with type: {type(embedding_obj.values)}")
-                return np.array(embedding_obj.values)
-            elif hasattr(embedding_obj, 'embedding'):
-                logging.info(f"Found 'embedding' attribute with type: {type(embedding_obj.embedding)}")
-                return np.array(embedding_obj.embedding)
-            else:
-                logging.error(f"ContentEmbedding object has no 'values' or 'embedding' attribute")
-                logging.error(f"Available attributes: {[attr for attr in dir(embedding_obj) if not attr.startswith('_')]}")
-                return None
         else:
             logging.error(f"Unexpected response structure")
             return None
@@ -234,10 +221,10 @@ class AdvancedRAGSystem:
         if self.documents_df.empty or 'text' not in self.documents_df.columns:
             logging.warning("RAG documents DataFrame is empty or lacks 'text' column. Skipping embedding.")
             self.embeddings = np.array([])
-            self.is_initialized = True
             return

-        if not client and not (GENAI_AVAILABLE and os.getenv('GEMINI_API_KEY')):
             logging.error("GenAI client not available for RAG embedding initialization.")
             self.embeddings = np.array([])
             return
@@ -252,7 +239,6 @@ class AdvancedRAGSystem:
                 continue

             try:
-                # Use asyncio.to_thread for the synchronous embedding call
                 embedding_array = await asyncio.to_thread(self._embed_single_document_sync, text_to_embed)
                 if embedding_array is not None and embedding_array.size > 0:
                     embedded_docs_list.append(embedding_array)
@@ -260,33 +246,29 @@ class AdvancedRAGSystem:
                 logging.warning(f"Empty or failed embedding for RAG document at index {index}.")
             except Exception as e:
                 logging.error(f"Error embedding RAG document at index {index}: {e}")
-                continue

         if not embedded_docs_list:
             self.embeddings = np.array([])
             logging.warning("No RAG documents were successfully embedded.")
         else:
             try:
-                # Ensure all embeddings have the same shape before vstack
                 first_shape = embedded_docs_list[0].shape
                 if not all(emb.shape == first_shape for emb in embedded_docs_list):
                     logging.error("Inconsistent embedding shapes found. Cannot stack for RAG.")
-                    # Attempt to filter out malformed embeddings if possible, or fail
-                    # For now, we'll fail stacking if shapes are inconsistent.
                     self.embeddings = np.array([])
-                    return

                 self.embeddings = np.vstack(embedded_docs_list)
                 logging.info(f"Successfully embedded {len(embedded_docs_list)} RAG documents. Embeddings shape: {self.embeddings.shape}")
             except ValueError as ve:
-                logging.error(f"Error stacking embeddings: {ve}")
                 self.embeddings = np.array([])

         self.is_initialized = True

-
     def _calculate_cosine_similarity(self, embeddings_matrix: np.ndarray, query_vector: np.ndarray) -> np.ndarray:
-        if embeddings_matrix.ndim == 1:
             embeddings_matrix = embeddings_matrix.reshape(1, -1)
         if query_vector.ndim == 1:
             query_vector = query_vector.reshape(1, -1)
@@ -294,19 +276,14 @@ class AdvancedRAGSystem:
         if embeddings_matrix.size == 0 or query_vector.size == 0:
             return np.array([])

-        # Normalize embeddings_matrix rows
         norm_matrix = np.linalg.norm(embeddings_matrix, axis=1, keepdims=True)
-        # Add a small epsilon to avoid division by zero for zero vectors
         normalized_embeddings_matrix = np.divide(embeddings_matrix, norm_matrix + 1e-8, where=norm_matrix!=0)

-        # Normalize query_vector
         norm_query = np.linalg.norm(query_vector, axis=1, keepdims=True)
         normalized_query_vector = np.divide(query_vector, norm_query + 1e-8, where=norm_query!=0)

-        # Calculate dot product
         return np.dot(normalized_embeddings_matrix, normalized_query_vector.T).flatten()

-
     async def retrieve_relevant_info(self, query: str, top_k: int = 3, min_similarity: float = 0.3) -> str:
         if not self.is_initialized:
             logging.debug("RAG system not initialized. Cannot retrieve info.")
@@ -323,7 +300,7 @@ class AdvancedRAGSystem:
             return ""

         try:
-            query_vector = await asyncio.to_thread(self._embed_single_document_sync, query)
             if query_vector is None or query_vector.size == 0:
                 logging.warning("Query vector embedding failed or is empty for RAG.")
                 return ""
@@ -337,9 +314,7 @@ class AdvancedRAGSystem:
             logging.debug(f"No RAG documents met minimum similarity threshold of {min_similarity} for query: '{query[:50]}...'")
             return ""

-            # Get scores for relevant documents and sort
             relevant_scores = similarity_scores[relevant_indices]
-            # Argsort returns indices to sort relevant_scores; apply to relevant_indices
             sorted_relevant_indices_of_original = relevant_indices[np.argsort(relevant_scores)[::-1]]

             top_indices = sorted_relevant_indices_of_original[:top_k]
@@ -358,404 +333,546 @@
             logging.error(f"Error during RAG retrieval for query '{query[:50]}...': {e}", exc_info=True)
             return ""

-class EmployerBrandingAgent:
     def __init__(self,
                  all_dataframes: Optional[Dict[str, pd.DataFrame]] = None,
                  rag_documents_df: Optional[pd.DataFrame] = None,
                  llm_model_name: str = LLM_MODEL_NAME,
                  embedding_model_name: str = GEMINI_EMBEDDING_MODEL_NAME,
                  generation_config_dict: Optional[Dict] = None,
-                 safety_settings_list: Optional[List] = None):

-        self.all_dataframes = {k: v.copy() for k, v in (all_dataframes or {}).items()}

         _rag_docs_df = rag_documents_df if rag_documents_df is not None else DEFAULT_RAG_DOCUMENTS.copy()
         self.rag_system = AdvancedRAGSystem(_rag_docs_df, embedding_model_name)

         self.llm_model_name = llm_model_name
         self.generation_config_dict = generation_config_dict or GENERATION_CONFIG_PARAMS

-        self.safety_settings_list = []
-        if safety_settings_list and GENAI_AVAILABLE and hasattr(types, 'SafetySetting'):
-            for ss_dict in safety_settings_list:
-                try:
-                    # Assuming ss_dict is like {'category': HarmCategory.XYZ, 'threshold': HarmBlockThreshold.ABC}
-                    self.safety_settings_list.append(types.SafetySetting(category=ss_dict['category'], threshold=ss_dict['threshold']))
-                except Exception as e:
-                    logging.warning(f"Could not convert safety setting dict to SafetySetting object: {ss_dict} - {e}")
-        elif safety_settings_list:  # If not using types.SafetySetting, pass as is (e.g. for client.models)
-            self.safety_settings_list = safety_settings_list

-        ...
     async def initialize(self) -> bool:
-        """Initializes asynchronous components of the agent"""
         try:
-            if not client ...

-            await self.rag_system.initialize_embeddings()  # This sets rag_system.is_initialized
-            self.is_ready = self.rag_system.is_initialized  # Agent is ready if RAG is (even if RAG has no docs)
-            logging.info(f"EmployerBrandingAgent.initialize completed. RAG initialized: {self.rag_system.is_initialized}. Agent ready: {self.is_ready}")
-            return True
         except Exception as e:
-            logging.error(f"Error during EmployerBrandingAgent.initialize: {e}", exc_info=True)
             self.is_ready = False
             return False

     def _get_dataframes_summary(self) -> str:
         return get_all_schemas_representation(self.all_dataframes)

     def _build_system_prompt(self) -> str:
-        """
-        Builds a comprehensive and user-friendly system prompt for an Employer Branding AI Agent
-        tailored for HR professionals, emphasizing natural conversation and masking technical details.
-        """
         return textwrap.dedent("""
-            You are a friendly and insightful Employer Branding Analyst AI, ...

             ## Communication Style:
-            - **...
-            - **...
-            - **...

-            ## ...

-            ## ...
-            - **...
-            - **...
-            - **...
-            - **...

-            - **Specific instructions for `follower_stats` DataFrame (if available) - *For your internal understanding and processing only*:**
-            - When the user asks about follower numbers or gains, you'll likely need `follower_stats` for your internal analysis.
-            - Remember that date information (formatted as strings "YYYY-MM-DD") is often in the `category_name` column.
-            - To get monthly follower gains, you'll internally filter where `follower_count_type` is `"follower_gains_monthly"`.
-            - The actual numeric follower count for that period will be in another column (e.g., 'follower_count_organic' or 'follower_count_paid').
-            - *When you need to ask the user for clarification related to this data (e.g., about dates or types of followers), do so using general, HR-friendly questions as per the 'Communication Style' guidelines. For example, instead of mentioning `category_name` or `follower_count_type`, you might ask: "Are you interested in follower numbers for a specific month, or the overall trend for the year?" or "Are we looking at followers gained from our day-to-day content, or from specific promotional activities?"*

-            ## Response Structure Guidelines:
-            1. **Friendly Opening & Executive Summary**: Start with a brief, friendly acknowledgement, then 2-3 key takeaways in simple terms.
-            2. **Data Insights**: What the numbers tell us (with context and HR relevance).
-            3. **Recommendations**: Specific actions to take, perhaps prioritized by likely impact or ease of implementation.
-            4. **Next Steps / Moving Forward**: Clear, actionable follow-up suggestions, or an invitation for further questions.

-            ## When You Can't Help Directly:
-            - **Be transparent (but not technical)**: Clearly state what you can and cannot do based on the available data or your capabilities, without blaming the data.
-            - **Offer alternatives**: Suggest related analyses you *can* perform or other ways to approach their question.
-            - **Educate gently**: Explain (in simple terms) why certain analyses might require different types of information if it helps the user understand.
-            - **Guide next steps**: Help users understand how they might be able to get the information they need, if it's outside your current scope.

-            ## Key Reminders:
-            - **Never fabricate data** or assume information that isn't present in the provided schemas.
-            - **Always validate your internal assumptions** against the available data structure.
-            - **Focus on actionable insights** over merely impressive-sounding metrics.
-            - **Remember your audience**: Explain concepts clearly, assuming no prior analytics expertise.
-            - **Prioritize clarity and usefulness** over technical sophistication in your responses.
-            - **Always prioritize a helpful, human-like interaction.**

-            Your ultimate goal is to be a trusted partner, empowering HR professionals to confidently make data-driven employer branding decisions by providing clear, friendly, and actionable insights, regardless of their technical background.
         """).strip()
         if not self.is_ready:
             return "Agent is not ready. Please initialize."
-        if not client ...
             return "Error: AI service is not available. Check API configuration."

         try:
-            ...
-            initial_context_prompt = (
-                f"{system_prompt_text}\n\n"
-                f"## Available Data Overview:\n{data_summary_text}\n\n"
-                f"## Relevant Background Information (if any):\n{rag_context_text if rag_context_text else 'No specific background information retrieved for this query.'}\n\n"
-                f"Given this context, please respond to the user queries that follow in the chat history."
-            )
-            llm_messages.append({"role": "user", "parts": [{"text": initial_context_prompt}]})
-            # 2. Priming assistant message
-            llm_messages.append({"role": "model", "parts": [{"text": "Understood. I have reviewed the context and data overview. I am ready to assist with your Employer Branding analysis based on our conversation."}]})

-            # 3. Append the actual conversation history (already includes the current user query)
-            for entry in self.chat_history:
-                llm_messages.append({"role": entry["role"], "parts": [{"text": entry["content"]}]})

-            # --- Make the API call ---
-            response_text = ""
-            if self.llm_model_instance:  # Standard google-generativeai usage
-                logging.debug(f"Using google-generativeai.GenerativeModel.generate_content_async for LLM call. History length: {len(llm_messages)}")

-                safety_settings_payload = self.safety_settings_list

-                ...
-                    gen_config_payload = types.GenerationConfig(**self.generation_config_dict)
-                except Exception as e:
-                    logging.warning(f"Could not convert gen_config_dict to types.GenerationConfig: {e}")

-                ...
-                response_text = api_response.text

-                ...
-                contents = []
-                for msg in llm_messages:
-                    if msg["role"] == "user":
-                        contents.append(msg["parts"][0]["text"])
-                    elif msg["role"] == "model":
-                        # For model responses, we might need to handle differently
-                        # but for now, let's include them as context
-                        contents.append(f"Assistant: {msg['parts'][0]['text']}")

-                ...
-                for key, value in self.generation_config_dict.items():
-                    config_dict[key] = value

                 safety_settings.append(types.SafetySetting(
                     category=ss.get('category'),
                     threshold=ss.get('threshold')
                 ))
             else:
                 safety_settings.append(ss)

                 config = types.GenerateContentConfig(**config_dict)

-                ...
             else:
-                if hasattr(api_response, 'prompt_feedback') and api_response.prompt_feedback and api_response.prompt_feedback.block_reason:
-                    logging.warning(f"Prompt blocked by client.models: {api_response.prompt_feedback.block_reason}")
-                    return f"I'm sorry, your request was blocked. Reason: {api_response.prompt_feedback.block_reason_message or api_response.prompt_feedback.block_reason}"
-                if api_response.candidates and hasattr(api_response.candidates[0], 'finish_reason'):
-                    finish_reason = api_response.candidates[0].finish_reason
-                    if hasattr(types.Candidate, 'FinishReason') and finish_reason != types.Candidate.FinishReason.STOP:
-                        logging.warning(f"Content generation stopped by client.models due to: {finish_reason}. Safety: {getattr(api_response.candidates[0], 'safety_ratings', 'N/A')}")
-                        return f"I couldn't complete the response. Reason: {finish_reason}. Please try rephrasing."
-                return "I apologize, but I couldn't generate a response from client.models."
         else:
-            ...
         except Exception as e:
             error_message = str(e).lower()

-            # Check if it's a blocked prompt error by examining the error message
             if any(keyword in error_message for keyword in ['blocked', 'safety', 'filter', 'prohibited']):
-                logging.error(f"Blocked prompt: {e}")
-                return ...
             else:
-                logging.error(f"Error in ...: {e}", exc_info=True)
-                return f"I encountered an error while processing your request: {str(e)}"
     def _validate_query(self, query: str) -> bool:
         if not query or not isinstance(query, str) or len(query.strip()) < 3:
             logging.warning(f"Invalid query: too short or not a string. Query: '{query}'")
             return False
-        if len(query) > 3000:
             logging.warning(f"Invalid query: too long. Length: {len(query)}")
             return False
         return True
     async def process_query(self, user_query: str) -> str:
         """
-        ...
         """
         if not self._validate_query(user_query):
-            # This user_query is the one from Gradio input, also the last one in self.chat_history
             return "Please provide a valid query (3 to 3000 characters)."

         if not self.is_ready:
             logging.warning("process_query called but agent is not ready. Attempting re-initialization.")
-            # This is a fallback. Ideally, initialize is called once and confirmed.
             init_success = await self.initialize()
             if not init_success:
                 return "The agent is not properly initialized and could not be started. Please check configuration and logs."

-        ...
     def update_dataframes(self, new_dataframes: Dict[str, pd.DataFrame]):
-        """Updates the agent's DataFrames"""
-        self.all_dataframes = {k: v.copy() for k, v in new_dataframes.items()}
         logging.info(f"Agent DataFrames updated. Keys: {list(self.all_dataframes.keys())}")

-    # ...

     def clear_chat_history(self):
-        """Clears the agent's internal chat history"""
         self.chat_history = []
         logging.info("EmployerBrandingAgent internal chat history cleared.")

     def get_status(self) -> Dict[str, Any]:
         return {
             "is_ready": self.is_ready,
             "has_api_key": bool(GEMINI_API_KEY),
             "genai_available": GENAI_AVAILABLE,
-            "client_type": "genai.Client" if client else ...,
             "rag_initialized": self.rag_system.is_initialized,
             "num_dataframes": len(self.all_dataframes),
             "dataframe_keys": list(self.all_dataframes.keys()),
             "num_rag_documents": len(self.rag_system.documents_df) if self.rag_system.documents_df is not None else 0,
             "llm_model_name": self.llm_model_name,
-            "embedding_model_name": self.embedding_model_name
         }
 def create_agent_instance(dataframes: Optional[Dict[str, pd.DataFrame]] = None,
-                          rag_docs: Optional[pd.DataFrame] = None) -> EmployerBrandingAgent:
-    ...

-async def initialize_agent_async(agent: EmployerBrandingAgent) -> bool:
     logging.info("Initializing agent via async helper function.")
     return await agent.initialize()

-    ...
-        return

-    sample_dfs = {
-        "followers": pd.DataFrame({'date': pd.to_datetime(['2023-01-01']), 'count': [100]}),
-        "posts": pd.DataFrame({'title': ['My first post'], 'likes': [10]})
-    }

-    ...
-    agent = EmployerBrandingAgent(
-        all_dataframes=sample_dfs,
-        rag_documents_df=custom_rag,
-        llm_model_name=LLM_MODEL_NAME,
-        embedding_model_name=GEMINI_EMBEDDING_MODEL_NAME
-    )
-    print("Agent Status (pre-init):", agent.get_status())

-    init_success = await agent.initialize()
-    print(f"Agent Initialization Success: {init_success}")
-    print("Agent Status (post-init):", agent.get_status())

-    if not init_success:
-        print("Agent initialization failed. Cannot proceed with query test.")
-        return

-    # Simulate app.py setting history
-    test_query1 = "What are the key columns in my followers data?"
-    agent.chat_history = [{"role": "user", "content": test_query1}]  # app.py would do this

-    ...
-        print(f"- {item['role']}: {item['content'][:100]}...")

-    print("\n--- Test Complete ---")

-    asyncio.run(test_agent_logic())
+++ eb_agent_module.py (after)

 import asyncio
 import logging
 import numpy as np
+import textwrap
+from datetime import datetime
 from typing import Dict, List, Optional, Union, Any
 import traceback
+from pandasai import Agent, SmartDataframe
+from pandasai.llm import GoogleGemini
+from pandasai.responses.response_parser import ResponseParser
+from pandasai.middlewares.base import BaseMiddleware

 # Configure logging
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(module)s - %(message)s')

 try:
     from google import genai
+    from google.genai import types
     from google.genai import errors
     GENAI_AVAILABLE = True
     logging.info("Google Generative AI library imported successfully.")
 except ImportError:
     logging.warning("Google Generative AI library not found. Please install it: pip install google-generativeai")
     GENAI_AVAILABLE = False

+    # Dummy classes for graceful degradation
     class genai:
         Client = None
+
+    class types:
+        EmbedContentConfig = None
+        GenerationConfig = None
         SafetySetting = None
+        Candidate = type('Candidate', (), {'FinishReason': type('FinishReason', (), {'STOP': 'STOP'})})

     class HarmCategory:
         HARM_CATEGORY_UNSPECIFIED = "HARM_CATEGORY_UNSPECIFIED"

         BLOCK_LOW_AND_ABOVE = "BLOCK_LOW_AND_ABOVE"
         BLOCK_MEDIUM_AND_ABOVE = "BLOCK_MEDIUM_AND_ABOVE"
         BLOCK_ONLY_HIGH = "BLOCK_ONLY_HIGH"

 # --- Custom Exceptions ---
 class ValidationError(Exception):
     """Custom validation error for agent inputs"""
     pass

+class RateLimitError(Exception):
     """Placeholder for rate limit errors."""
     pass

 # --- Configuration Constants ---
 GEMINI_API_KEY = os.getenv('GEMINI_API_KEY', "")
+LLM_MODEL_NAME = "gemini-2.5-flash-preview-05-20"
+GEMINI_EMBEDDING_MODEL_NAME = "gemini-embedding-exp-03-07"

 GENERATION_CONFIG_PARAMS = {
     "temperature": 0.7,
     "top_p": 0.95,
     "top_k": 40,
+    "max_output_tokens": 8192,
     "candidate_count": 1,
 }

+DEFAULT_SAFETY_SETTINGS = []

 # Default RAG documents
 DEFAULT_RAG_DOCUMENTS = pd.DataFrame({

 client = None
 if GEMINI_API_KEY and GENAI_AVAILABLE:
     try:
         client = genai.Client(api_key=GEMINI_API_KEY)
+        logging.info("Google GenAI client initialized successfully.")
     except Exception as e:
+        logging.error(f"Failed to initialize Google GenAI client: {e}")
         client = None
 else:
     if not GEMINI_API_KEY:
     if not GENAI_AVAILABLE:
         logging.warning("Google GenAI library not available.")

+# --- Custom PandasAI Middleware for Better Integration ---
+class EmployerBrandingMiddleware(BaseMiddleware):
+    """Custom middleware to enhance PandasAI responses with HR context"""
+
+    def run(self, code: str, **kwargs) -> str:
+        """Add HR-friendly context to generated code"""
+        # Add comments to make code more understandable
+        enhanced_code = f"""
+# HR Analytics Query Processing
+# This code analyzes your LinkedIn employer branding data
+{code}
+"""
+        return enhanced_code
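Reviewer note: the middleware only prepends explanatory comments to whatever code PandasAI generates before it runs. A quick standalone check of that behavior (the query string below is a hypothetical stand-in for generated code, not anything this commit produces):

    mw = EmployerBrandingMiddleware()
    print(mw.run("result = df.groupby('month')['follower_count'].sum()"))
    # Prints the wrapped code:
    #
    # # HR Analytics Query Processing
    # # This code analyzes your LinkedIn employer branding data
    # result = df.groupby('month')['follower_count'].sum()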
 # --- Utility function to get DataFrame schema representation ---
 def get_df_schema_representation(df: pd.DataFrame, df_name: str) -> str:

     # Ensure 'text' column exists
     if 'text' not in self.documents_df.columns and not self.documents_df.empty:
         logging.warning("'text' column not found in RAG documents. RAG might not work.")
         self.documents_df['text'] = ""

+    self.embedding_model_name = embedding_model_name
     self.embeddings: Optional[np.ndarray] = None
     self.is_initialized = False
     logging.info(f"AdvancedRAGSystem initialized with {len(self.documents_df)} documents. Model: {self.embedding_model_name}")

         embed_config_payload = None
         if GENAI_AVAILABLE and hasattr(types, 'EmbedContentConfig'):
             embed_config_payload = types.EmbedContentConfig(task_type="RETRIEVAL_DOCUMENT")
+
         response = client.models.embed_content(
             model=f"models/{self.embedding_model_name}" if not self.embedding_model_name.startswith("models/") else self.embedding_model_name,
+            contents=text,  # Fix: Remove the list wrapper
             config=embed_config_payload
         )

+        # Fix: Update response parsing - use .embeddings directly (it's a list)
         if hasattr(response, 'embeddings') and isinstance(response.embeddings, list) and len(response.embeddings) > 0:
+            # Fix: Access embedding values directly from the list
+            embedding_values = response.embeddings[0]  # This is already the array/list of values
+            return np.array(embedding_values)
         else:
             logging.error(f"Unexpected response structure")
             return None
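Reviewer note: the new parser assumes `response.embeddings[0]` is already the vector values. The old code probed for `ContentEmbedding` objects instead, and some google-genai builds do return those. A defensive variant covering both shapes (a sketch under that assumption, not part of this commit) would be:

    # Sketch: tolerate both shapes of client.models.embed_content() results.
    first = response.embeddings[0]
    if hasattr(first, 'values'):      # ContentEmbedding-style object
        vec = np.array(first.values)
    else:                             # already a plain sequence of floats
        vec = np.array(first)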
         if self.documents_df.empty or 'text' not in self.documents_df.columns:
             logging.warning("RAG documents DataFrame is empty or lacks 'text' column. Skipping embedding.")
             self.embeddings = np.array([])
+            self.is_initialized = True
             return

+        if not client and not (GENAI_AVAILABLE and os.getenv('GEMINI_API_KEY')):
             logging.error("GenAI client not available for RAG embedding initialization.")
             self.embeddings = np.array([])
             return

                 continue

             try:
                 embedding_array = await asyncio.to_thread(self._embed_single_document_sync, text_to_embed)
                 if embedding_array is not None and embedding_array.size > 0:
                     embedded_docs_list.append(embedding_array)

                 logging.warning(f"Empty or failed embedding for RAG document at index {index}.")
             except Exception as e:
                 logging.error(f"Error embedding RAG document at index {index}: {e}")
+                continue

         if not embedded_docs_list:
             self.embeddings = np.array([])
             logging.warning("No RAG documents were successfully embedded.")
         else:
             try:
                 first_shape = embedded_docs_list[0].shape
                 if not all(emb.shape == first_shape for emb in embedded_docs_list):
                     logging.error("Inconsistent embedding shapes found. Cannot stack for RAG.")
                     self.embeddings = np.array([])
+                    return

                 self.embeddings = np.vstack(embedded_docs_list)
                 logging.info(f"Successfully embedded {len(embedded_docs_list)} RAG documents. Embeddings shape: {self.embeddings.shape}")
             except ValueError as ve:
+                logging.error(f"Error stacking embeddings: {ve}")
                 self.embeddings = np.array([])

         self.is_initialized = True

     def _calculate_cosine_similarity(self, embeddings_matrix: np.ndarray, query_vector: np.ndarray) -> np.ndarray:
+        if embeddings_matrix.ndim == 1:
             embeddings_matrix = embeddings_matrix.reshape(1, -1)
         if query_vector.ndim == 1:
             query_vector = query_vector.reshape(1, -1)

         if embeddings_matrix.size == 0 or query_vector.size == 0:
             return np.array([])

         norm_matrix = np.linalg.norm(embeddings_matrix, axis=1, keepdims=True)
         normalized_embeddings_matrix = np.divide(embeddings_matrix, norm_matrix + 1e-8, where=norm_matrix!=0)

         norm_query = np.linalg.norm(query_vector, axis=1, keepdims=True)
         normalized_query_vector = np.divide(query_vector, norm_query + 1e-8, where=norm_query!=0)

         return np.dot(normalized_embeddings_matrix, normalized_query_vector.T).flatten()
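Reviewer note: normalizing both sides turns the dot product into cosine similarity. The arithmetic is easy to verify on toy vectors with plain numpy:

    import numpy as np
    docs = np.array([[1.0, 0.0], [1.0, 1.0]])   # two 2-d "document" embeddings
    query = np.array([1.0, 0.0])                 # query embedding
    norm_d = docs / (np.linalg.norm(docs, axis=1, keepdims=True) + 1e-8)
    norm_q = query / (np.linalg.norm(query) + 1e-8)
    print(norm_d @ norm_q)   # -> [~1.0, ~0.7071]: same direction vs. 45 degrees apart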
     async def retrieve_relevant_info(self, query: str, top_k: int = 3, min_similarity: float = 0.3) -> str:
         if not self.is_initialized:
             logging.debug("RAG system not initialized. Cannot retrieve info.")

             return ""

         try:
+            query_vector = await asyncio.to_thread(self._embed_single_document_sync, query)
             if query_vector is None or query_vector.size == 0:
                 logging.warning("Query vector embedding failed or is empty for RAG.")
                 return ""

             logging.debug(f"No RAG documents met minimum similarity threshold of {min_similarity} for query: '{query[:50]}...'")
             return ""

             relevant_scores = similarity_scores[relevant_indices]
             sorted_relevant_indices_of_original = relevant_indices[np.argsort(relevant_scores)[::-1]]

             top_indices = sorted_relevant_indices_of_original[:top_k]

             logging.error(f"Error during RAG retrieval for query '{query[:50]}...': {e}", exc_info=True)
             return ""
+class EnhancedEmployerBrandingAgent:
     def __init__(self,
                  all_dataframes: Optional[Dict[str, pd.DataFrame]] = None,
                  rag_documents_df: Optional[pd.DataFrame] = None,
                  llm_model_name: str = LLM_MODEL_NAME,
                  embedding_model_name: str = GEMINI_EMBEDDING_MODEL_NAME,
                  generation_config_dict: Optional[Dict] = None,
+                 safety_settings_list: Optional[List] = None):

+        self.all_dataframes = {k: v.copy() for k, v in (all_dataframes or {}).items()}

         _rag_docs_df = rag_documents_df if rag_documents_df is not None else DEFAULT_RAG_DOCUMENTS.copy()
         self.rag_system = AdvancedRAGSystem(_rag_docs_df, embedding_model_name)

         self.llm_model_name = llm_model_name
         self.generation_config_dict = generation_config_dict or GENERATION_CONFIG_PARAMS
+        self.safety_settings_list = safety_settings_list or DEFAULT_SAFETY_SETTINGS

+        self.chat_history: List[Dict[str, str]] = []
+        self.is_ready = False

+        # Initialize PandasAI Agent
+        self.pandas_agent = None
+        self._initialize_pandas_agent()

+        logging.info(f"EnhancedEmployerBrandingAgent initialized. LLM: {self.llm_model_name}. RAG docs: {len(self.rag_system.documents_df)}. DataFrames: {list(self.all_dataframes.keys())}")
+    def _initialize_pandas_agent(self):
+        """Initialize PandasAI Agent with enhanced configuration"""
+        if not self.all_dataframes or not GEMINI_API_KEY:
+            logging.warning("Cannot initialize PandasAI agent: missing dataframes or API key")
+            return

+        try:
+            # Convert DataFrames to SmartDataframes with descriptive names
+            smart_dfs = []
+            for name, df in self.all_dataframes.items():
+                # Add metadata to help PandasAI understand the data better
+                df_description = self._generate_dataframe_description(name, df)
+                smart_df = SmartDataframe(
+                    df,
+                    name=name,
+                    description=df_description
+                )
+                smart_dfs.append(smart_df)

+            # Configure PandasAI with Gemini
+            pandas_llm = GoogleGemini(
+                api_token=GEMINI_API_KEY,
+                model=self.llm_model_name,
+                temperature=0.7,
+                top_p=0.95,
+                top_k=40,
+                max_output_tokens=4096
+            )

+            # Create agent with enhanced configuration
+            self.pandas_agent = Agent(
+                dfs=smart_dfs,
+                config={
+                    "llm": pandas_llm,
+                    "verbose": True,
+                    "enable_cache": True,
+                    "save_charts": True,
+                    "save_charts_path": "charts/",
+                    "custom_whitelisted_dependencies": ["matplotlib", "seaborn", "plotly"],
+                    "middlewares": [EmployerBrandingMiddleware()],
+                    "response_parser": ResponseParser,
+                    "max_retries": 3,
+                    "conversational": True
+                }
+            )

+            logging.info(f"PandasAI agent initialized successfully with {len(smart_dfs)} DataFrames")

+        except Exception as e:
+            logging.error(f"Failed to initialize PandasAI agent: {e}", exc_info=True)
+            self.pandas_agent = None
+    def _generate_dataframe_description(self, name: str, df: pd.DataFrame) -> str:
+        """Generate a descriptive summary for PandasAI to better understand the data"""
+        description_parts = [f"This is the '{name}' dataset containing {len(df)} records."]

+        # Add column descriptions based on common patterns
+        column_descriptions = []
+        for col in df.columns:
+            col_lower = col.lower()
+            if 'date' in col_lower:
+                column_descriptions.append(f"'{col}' contains date/time information")
+            elif 'count' in col_lower or 'number' in col_lower:
+                column_descriptions.append(f"'{col}' contains numerical count data")
+            elif 'rate' in col_lower or 'percentage' in col_lower:
+                column_descriptions.append(f"'{col}' contains rate/percentage metrics")
+            elif 'follower' in col_lower:
+                column_descriptions.append(f"'{col}' contains LinkedIn follower data")
+            elif 'engagement' in col_lower:
+                column_descriptions.append(f"'{col}' contains engagement metrics")
+            elif 'post' in col_lower:
+                column_descriptions.append(f"'{col}' contains post-related information")

+        if column_descriptions:
+            description_parts.append("Key columns: " + "; ".join(column_descriptions))

+        # Add specific context for employer branding
+        if name.lower() in ['follower_stats', 'followers']:
+            description_parts.append("This data tracks LinkedIn company page follower growth and demographics for employer branding analysis.")
+        elif name.lower() in ['posts', 'post_stats']:
+            description_parts.append("This data contains LinkedIn post performance metrics for employer branding content analysis.")
+        elif name.lower() in ['mentions', 'brand_mentions']:
+            description_parts.append("This data tracks brand mentions and sentiment for employer branding reputation analysis.")

+        return " ".join(description_parts)
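Reviewer note: the column heuristics match in order, so a name like `follower_count_organic` hits the 'count' branch before the 'follower' one. An example of the output on a hypothetical frame:

    import pandas as pd
    df = pd.DataFrame({'date': ['2023-01-01'], 'follower_count_organic': [120]})
    # _generate_dataframe_description('follower_stats', df) returns roughly:
    # "This is the 'follower_stats' dataset containing 1 records.
    #  Key columns: 'date' contains date/time information; 'follower_count_organic'
    #  contains numerical count data This data tracks LinkedIn company page follower
    #  growth and demographics for employer branding analysis."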
     async def initialize(self) -> bool:
+        """Initializes asynchronous components of the agent"""
         try:
+            if not client:  # Fix: Remove reference to llm_model_instance
+                logging.error("Cannot initialize agent: GenAI client not available/configured.")
+                return False

+            await self.rag_system.initialize_embeddings()

+            # Verify PandasAI agent is ready
+            pandas_ready = self.pandas_agent is not None
+            if not pandas_ready:
+                logging.warning("PandasAI agent not initialized, attempting re-initialization")
+                self._initialize_pandas_agent()
+                pandas_ready = self.pandas_agent is not None

+            self.is_ready = self.rag_system.is_initialized and pandas_ready
+            logging.info(f"EnhancedEmployerBrandingAgent.initialize completed. RAG: {self.rag_system.is_initialized}, PandasAI: {pandas_ready}, Agent ready: {self.is_ready}")
+            return self.is_ready

         except Exception as e:
+            logging.error(f"Error during EnhancedEmployerBrandingAgent.initialize: {e}", exc_info=True)
             self.is_ready = False
             return False

     def _get_dataframes_summary(self) -> str:
         return get_all_schemas_representation(self.all_dataframes)
     def _build_system_prompt(self) -> str:
+        """Enhanced system prompt that works with PandasAI integration"""
         return textwrap.dedent("""
+            You are a friendly and insightful Employer Branding Analyst AI, working as a dedicated partner for HR professionals to make LinkedIn data analysis accessible, actionable, and easy to understand.

+            ## Your Enhanced Capabilities:
+            You now have advanced data analysis capabilities through PandasAI integration, allowing you to:
+            - Directly query and analyze DataFrames with natural language
+            - Generate charts and visualizations automatically
+            - Perform complex statistical analysis on LinkedIn employer branding data
+            - Handle multi-DataFrame queries and joins seamlessly

+            ## Core Responsibilities:
+            1. **Intelligent Data Analysis**: Use your PandasAI integration to answer data questions directly and accurately
+            2. **Business Context Translation**: Convert technical analysis results into HR-friendly insights
+            3. **Actionable Recommendations**: Provide specific, implementable strategies based on data findings
+            4. **Educational Guidance**: Help users understand both the data insights and the LinkedIn analytics concepts

             ## Communication Style:
+            - **Natural and Conversational**: Maintain a warm, supportive tone as a helpful colleague
+            - **HR-Focused Language**: Avoid technical jargon; explain analytics terms in business context
+            - **Context-Rich Responses**: Always explain what metrics mean for employer branding strategy
+            - **Structured Insights**: Use clear formatting with headers, bullets, and logical flow

+            ## Data Analysis Approach:
+            When users ask data questions, you will:
+            1. **Leverage PandasAI**: Use your integrated data analysis capabilities to query the data directly
+            2. **Interpret Results**: Translate technical findings into business insights
+            3. **Add Context**: Combine data results with your RAG knowledge base for comprehensive answers
+            4. **Provide Recommendations**: Suggest specific actions based on the analysis

+            ## Response Structure:
+            1. **Executive Summary**: Key findings in business terms
+            2. **Data Insights**: What the analysis reveals (charts/visualizations when helpful)
+            3. **Business Impact**: What this means for employer branding strategy
+            4. **Recommendations**: Specific, prioritized action items
+            5. **Next Steps**: Follow-up suggestions or questions

+            ## Key Behaviors:
+            - **Data-Driven**: Always ground insights in actual data analysis when possible
+            - **Visual When Helpful**: Suggest or create charts that make data more understandable
+            - **Proactive**: Identify related insights the user might find valuable
+            - **Honest About Limitations**: Clearly state when data doesn't support certain analyses

+            Your goal remains to be a trusted partner, but now with powerful data analysis capabilities that enable deeper, more accurate insights for data-driven employer branding decisions.
         """).strip()
+    def _classify_query_type(self, query: str) -> str:
+        """Classify whether query needs data analysis, general advice, or both"""
+        data_keywords = [
+            'show', 'analyze', 'chart', 'graph', 'data', 'numbers', 'count', 'total',
+            'average', 'trend', 'compare', 'statistics', 'performance', 'metrics',
+            'followers', 'engagement', 'posts', 'growth', 'rate', 'percentage'
+        ]

+        advice_keywords = [
+            'recommend', 'suggest', 'advice', 'strategy', 'improve', 'optimize',
+            'best practice', 'should', 'how to', 'what to do', 'tips'
+        ]

+        query_lower = query.lower()
+        has_data_request = any(keyword in query_lower for keyword in data_keywords)
+        has_advice_request = any(keyword in query_lower for keyword in advice_keywords)

+        if has_data_request and has_advice_request:
+            return "hybrid"
+        elif has_data_request:
+            return "data"
+        elif has_advice_request:
+            return "advice"
+        else:
+            return "general"
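Reviewer note: the routing is deliberately coarse keyword matching. A few hypothetical queries show how it falls out:

    # Illustrative classifications from _classify_query_type:
    # "Show follower growth by month"         -> "data"    ('show', 'followers', 'growth')
    # "What should we do about slow hiring?"  -> "advice"  ('should')
    # "Analyze engagement and suggest fixes"  -> "hybrid"  (data + advice keywords)
    # "Hello there"                           -> "general" (no keyword match)
    # Caveat: plain substring matching has edge cases; 'strategy' contains
    # 'rate', so "How to improve our strategy?" classifies as "hybrid".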
+    async def _generate_pandas_response(self, query: str) -> tuple[str, bool]:
+        """Generate response using PandasAI for data queries"""
+        if not self.pandas_agent:
+            return "Data analysis not available - PandasAI agent not initialized.", False

+        try:
+            # Use PandasAI to analyze the data
+            logging.info(f"Processing data query with PandasAI: {query[:100]}...")
+            pandas_response = self.pandas_agent.chat(query)

+            # Check if response is meaningful
+            if pandas_response and str(pandas_response).strip():
+                return str(pandas_response), True
+            else:
+                return "I couldn't generate a meaningful analysis for this query.", False

+        except Exception as e:
+            logging.error(f"Error in PandasAI processing: {e}", exc_info=True)
+            return f"I encountered an error while analyzing the data: {str(e)}", False
+    async def _generate_enhanced_response(self, query: str, pandas_result: str = "", query_type: str = "general") -> str:
+        """Generate enhanced response combining PandasAI results with RAG context"""
         if not self.is_ready:
             return "Agent is not ready. Please initialize."
+        if not client:
             return "Error: AI service is not available. Check API configuration."

         try:
+            system_prompt = self._build_system_prompt()
+            data_summary = self._get_dataframes_summary()
+            rag_context = await self.rag_system.retrieve_relevant_info(query, top_k=2, min_similarity=0.25)

+            # Build enhanced prompt based on query type and available results
+            if query_type == "data" and pandas_result:
+                enhanced_prompt = f"""
+{system_prompt}

+## Data Analysis Context:
+{data_summary}

+## PandasAI Analysis Result:
+{pandas_result}

+## Additional Knowledge Context:
+{rag_context if rag_context else 'No additional context retrieved.'}

+## User Query:
+{query}

+Please interpret the data analysis result above and provide business insights in a friendly, HR-focused manner.
+Explain what the findings mean for employer branding strategy and provide actionable recommendations.
+"""
+            else:
+                enhanced_prompt = f"""
+{system_prompt}

+## Available Data Context:
+{data_summary}

+## Knowledge Base Context:
+{rag_context if rag_context else 'No specific background information retrieved.'}

+## User Query:
+{query}

+Please provide helpful insights and recommendations for this employer branding query.
+"""

+            # Fix: Use only genai.Client approach - remove all google-generativeai code
+            logging.debug(f"Using genai.Client for enhanced response generation")

+            # Prepare config
+            config_dict = dict(self.generation_config_dict) if self.generation_config_dict else {}

+            if self.safety_settings_list:
+                safety_settings = []
+                for ss in self.safety_settings_list:
+                    if isinstance(ss, dict):
+                        if GENAI_AVAILABLE and hasattr(types, 'SafetySetting'):
                             safety_settings.append(types.SafetySetting(
                                 category=ss.get('category'),
                                 threshold=ss.get('threshold')
                             ))
                         else:
                             safety_settings.append(ss)
+                    else:
+                        safety_settings.append(ss)
+                config_dict['safety_settings'] = safety_settings

+            if GENAI_AVAILABLE and hasattr(types, 'GenerateContentConfig'):
                 config = types.GenerateContentConfig(**config_dict)
+            else:
+                config = config_dict

+            model_path = f"models/{self.llm_model_name}" if not self.llm_model_name.startswith("models/") else self.llm_model_name

+            # Fix: Use synchronous call wrapped in asyncio.to_thread
+            api_response = await asyncio.to_thread(
+                client.models.generate_content,
+                model=model_path,
+                contents=enhanced_prompt,  # Fix: Pass single prompt string instead of complex message format
+                config=config
+            )

+            # Fix: Updated response parsing
+            if hasattr(api_response, 'candidates') and api_response.candidates:
+                candidate = api_response.candidates[0]
+                if hasattr(candidate, 'content') and candidate.content:
+                    if hasattr(candidate.content, 'parts') and candidate.content.parts:
+                        response_text_parts = [part.text for part in candidate.content.parts if hasattr(part, 'text')]
+                        response_text = "".join(response_text_parts).strip()
+                    else:
+                        response_text = str(candidate.content).strip()
                 else:
+                    response_text = ""
             else:
+                response_text = ""

+            if not response_text:
+                # Handle blocked or empty responses
+                if hasattr(api_response, 'prompt_feedback') and api_response.prompt_feedback:
+                    if hasattr(api_response.prompt_feedback, 'block_reason') and api_response.prompt_feedback.block_reason:
+                        logging.warning(f"Prompt blocked: {api_response.prompt_feedback.block_reason}")
+                        return f"I'm sorry, your request was blocked. Please try rephrasing your query."
+                return "I couldn't generate a response. Please try rephrasing your query."

+            return response_text

         except Exception as e:
             error_message = str(e).lower()

             if any(keyword in error_message for keyword in ['blocked', 'safety', 'filter', 'prohibited']):
+                logging.error(f"Blocked prompt: {e}")
+                return "I'm sorry, your request was blocked by the safety filter. Please rephrase your query."
             else:
+                logging.error(f"Error in _generate_enhanced_response: {e}", exc_info=True)
+                return f"I encountered an error while processing your request: {str(e)}"
     def _validate_query(self, query: str) -> bool:
+        """Validate user query input"""
         if not query or not isinstance(query, str) or len(query.strip()) < 3:
             logging.warning(f"Invalid query: too short or not a string. Query: '{query}'")
             return False
+        if len(query) > 3000:
             logging.warning(f"Invalid query: too long. Length: {len(query)}")
             return False
         return True
     async def process_query(self, user_query: str) -> str:
         """
+        Main method to process user queries with hybrid approach:
+        1. Classify query type (data/advice/hybrid)
+        2. Use PandasAI for data queries
+        3. Use enhanced LLM for interpretation and advice
+        4. Combine results for comprehensive responses
         """
         if not self._validate_query(user_query):
             return "Please provide a valid query (3 to 3000 characters)."

         if not self.is_ready:
             logging.warning("process_query called but agent is not ready. Attempting re-initialization.")
             init_success = await self.initialize()
             if not init_success:
                 return "The agent is not properly initialized and could not be started. Please check configuration and logs."

+        try:
+            # Classify the query type
+            query_type = self._classify_query_type(user_query)
+            logging.info(f"Query classified as: {query_type}")

+            pandas_result = ""
+            pandas_success = False

+            # For data-related queries, try PandasAI first
+            if query_type in ["data", "hybrid"] and self.pandas_agent:
+                logging.info("Attempting PandasAI analysis...")
+                pandas_result, pandas_success = await self._generate_pandas_response(user_query)

+                if pandas_success:
+                    logging.info("PandasAI analysis successful")
+                    # For pure data queries with successful analysis, we might return enhanced result
+                    if query_type == "data":
+                        enhanced_response = await self._generate_enhanced_response(
+                            user_query, pandas_result, query_type
+                        )
+                        return enhanced_response
+                else:
+                    logging.warning("PandasAI analysis failed, falling back to general response")

+            # For hybrid queries, advice queries, or when PandasAI fails
+            if query_type == "hybrid" and pandas_success:
+                # Combine PandasAI results with enhanced advice
+                enhanced_response = await self._generate_enhanced_response(
+                    user_query, pandas_result, query_type
+                )
+                return enhanced_response
+            else:
+                # General advice or fallback response
+                enhanced_response = await self._generate_enhanced_response(
+                    user_query, "", query_type
+                )
+                return enhanced_response

+        except Exception as e:
+            logging.error(f"Error in process_query: {e}", exc_info=True)
+            return f"I encountered an error while processing your request: {str(e)}"
     def update_dataframes(self, new_dataframes: Dict[str, pd.DataFrame]):
+        """Updates the agent's DataFrames and reinitializes PandasAI agent"""
+        self.all_dataframes = {k: v.copy() for k, v in new_dataframes.items()}
         logging.info(f"Agent DataFrames updated. Keys: {list(self.all_dataframes.keys())}")

+        # Reinitialize PandasAI agent with new data
+        self._initialize_pandas_agent()

+        # Note: RAG system uses static documents and doesn't need reinitialization

+    def update_rag_documents(self, new_rag_df: pd.DataFrame):
+        """Updates RAG documents and reinitializes embeddings"""
+        self.rag_system.documents_df = new_rag_df.copy()
+        logging.info(f"RAG documents updated. Count: {len(new_rag_df)}")
+        # Note: Embeddings will need to be reinitialized - call initialize() after this

     def clear_chat_history(self):
+        """Clears the agent's internal chat history"""
         self.chat_history = []
         logging.info("EmployerBrandingAgent internal chat history cleared.")
     def get_status(self) -> Dict[str, Any]:
+        """Returns comprehensive status information about the agent"""
         return {
             "is_ready": self.is_ready,
             "has_api_key": bool(GEMINI_API_KEY),
             "genai_available": GENAI_AVAILABLE,
+            "client_type": "genai.Client" if client else "None",  # Fix: Remove reference to llm_model_instance
             "rag_initialized": self.rag_system.is_initialized,
+            "pandas_agent_ready": self.pandas_agent is not None,
             "num_dataframes": len(self.all_dataframes),
             "dataframe_keys": list(self.all_dataframes.keys()),
             "num_rag_documents": len(self.rag_system.documents_df) if self.rag_system.documents_df is not None else 0,
             "llm_model_name": self.llm_model_name,
+            "embedding_model_name": self.rag_system.embedding_model_name,
+            "chat_history_length": len(self.chat_history)
         }
+    def get_available_analyses(self) -> List[str]:
+        """Returns list of suggested analyses based on available data"""
+        if not self.all_dataframes:
+            return ["No data available for analysis"]

+        suggestions = []
+        for df_name, df in self.all_dataframes.items():
+            if 'follower' in df_name.lower():
+                suggestions.extend([
+                    f"Show follower growth trends from {df_name}",
+                    f"Analyze follower demographics in {df_name}",
+                    f"Compare follower engagement rates"
+                ])
+            elif 'post' in df_name.lower():
+                suggestions.extend([
+                    f"Analyze post performance metrics from {df_name}",
+                    f"Show best performing content types",
+                    f"Compare engagement across post categories"
+                ])
+            elif 'mention' in df_name.lower():
+                suggestions.extend([
+                    f"Analyze brand mention sentiment from {df_name}",
+                    f"Show mention volume trends",
+                    f"Identify top mention sources"
+                ])

+        # Add general suggestions
+        suggestions.extend([
+            "What are the key employer branding trends?",
+            "How can I improve our LinkedIn presence?",
+            "What content strategy should we adopt?",
+            "How do we measure employer branding success?"
+        ])

+        return suggestions[:10]  # Limit to top 10 suggestions

+# --- Helper Functions for External Integration ---
 def create_agent_instance(dataframes: Optional[Dict[str, pd.DataFrame]] = None,
+                          rag_docs: Optional[pd.DataFrame] = None) -> EnhancedEmployerBrandingAgent:
+    """Factory function to create a new agent instance"""
+    logging.info("Creating new EnhancedEmployerBrandingAgent instance via helper function.")
+    return EnhancedEmployerBrandingAgent(all_dataframes=dataframes, rag_documents_df=rag_docs)

+async def initialize_agent_async(agent: EnhancedEmployerBrandingAgent) -> bool:
+    """Async helper to initialize an agent instance"""
     logging.info("Initializing agent via async helper function.")
     return await agent.initialize()
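Reviewer note: typical wiring from a host app, as a sketch. The sample frame below is hypothetical; a real caller passes its own LinkedIn exports and must have GEMINI_API_KEY set:

    import asyncio
    import pandas as pd

    # Hypothetical sample data standing in for real LinkedIn exports.
    dfs = {"follower_stats": pd.DataFrame({"date": ["2023-01-01"], "follower_count": [100]})}
    agent = create_agent_instance(dataframes=dfs)

    async def main():
        if await initialize_agent_async(agent):
            print(await agent.process_query("Show follower growth trends"))

    asyncio.run(main())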
+def validate_dataframes(dataframes: Dict[str, pd.DataFrame]) -> Dict[str, List[str]]:
+    """Validate dataframes for common issues and return validation report"""
+    validation_report = {}

+    for name, df in dataframes.items():
+        issues = []

+        if df.empty:
+            issues.append("DataFrame is empty")

+        # Check for required columns based on data type
+        if 'follower' in name.lower():
+            required_cols = ['date', 'follower_count']
+            missing_cols = [col for col in required_cols if col not in df.columns]
+            if missing_cols:
+                issues.append(f"Missing expected columns for follower data: {missing_cols}")

+        elif 'post' in name.lower():
+            required_cols = ['date', 'engagement']
+            missing_cols = [col for col in required_cols if col not in df.columns]
+            if missing_cols:
+                issues.append(f"Missing expected columns for post data: {missing_cols}")

+        # Check for data quality issues
+        if not df.empty:
+            null_percentages = (df.isnull().sum() / len(df) * 100).round(2)
+            high_null_cols = null_percentages[null_percentages > 50].index.tolist()
+            if high_null_cols:
+                issues.append(f"Columns with >50% null values: {high_null_cols}")

+        validation_report[name] = issues

+    return validation_report
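Reviewer note: a quick check of the validator on a deliberately incomplete frame (hypothetical data):

    import pandas as pd
    report = validate_dataframes({
        "post_stats": pd.DataFrame({"date": ["2023-01-01"], "likes": [10]})
    })
    print(report)
    # -> {'post_stats': ["Missing expected columns for post data: ['engagement']"]}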