Update eb_agent_module.py
eb_agent_module.py  CHANGED  (+524 -408)
@@ -4,529 +4,645 @@ import os
 import asyncio
 import logging
 import numpy as np
-import textwrap
-from datetime import datetime #

 try:
     from google import genai
-    from google.genai import types #
 except ImportError:
-    logging.
     SafetySetting = None
     HARM_CATEGORY_UNSPECIFIED = "HARM_CATEGORY_UNSPECIFIED"
     HARM_CATEGORY_HATE_SPEECH = "HARM_CATEGORY_HATE_SPEECH"
     HARM_CATEGORY_HARASSMENT = "HARM_CATEGORY_HARASSMENT"
     HARM_CATEGORY_SEXUALLY_EXPLICIT = "HARM_CATEGORY_SEXUALLY_EXPLICIT"
     HARM_CATEGORY_DANGEROUS_CONTENT = "HARM_CATEGORY_DANGEROUS_CONTENT"
     BLOCK_NONE = "BLOCK_NONE"
     BLOCK_LOW_AND_ABOVE = "BLOCK_LOW_AND_ABOVE"
     BLOCK_MEDIUM_AND_ABOVE = "BLOCK_MEDIUM_AND_ABOVE"
-    BLOCK_ONLY_HIGH = "BLOCK_ONLY_HIGH"

 # --- Custom Exceptions ---
 class ValidationError(Exception):
     """Custom validation error for agent inputs"""
     pass

-class RateLimitError(Exception):
     """Placeholder for rate limit errors."""
     pass

 # --- Configuration Constants ---
 GEMINI_API_KEY = os.getenv('GEMINI_API_KEY', "")
-LLM_MODEL_NAME = "gemini-1.5-flash-latest"
-GEMINI_EMBEDDING_MODEL_NAME = "text-embedding-004"

 GENERATION_CONFIG_PARAMS = {
     "temperature": 0.7,
     "top_p": 0.95,
     "top_k": 40,
-    "max_output_tokens": 8192,
     "candidate_count": 1,
 }

-#
-DEFAULT_SAFETY_SETTINGS = []
-logging.info("Default safety settings are now empty (no explicit client-side safety settings).")

     'text': [
         "Employer branding focuses on how an organization is perceived as an employer by potential and current employees.",
         "Key metrics for employer branding include employee engagement, candidate quality, and retention rates.",
         "LinkedIn is a crucial platform for showcasing company culture and attracting talent.",
-        "Analyzing follower demographics and post engagement helps refine employer branding strategies."
     ]
 })

 # --- Client Initialization ---
 client = None
-if GEMINI_API_KEY and
     try:
         client = genai.Client(api_key=GEMINI_API_KEY)
-        logging.info("Google GenAI client initialized successfully.")
     except Exception as e:
-        logging.error(f"Failed to initialize Google GenAI client: {e}")
 else:
 class AdvancedRAGSystem:
     def __init__(self, documents_df: pd.DataFrame, embedding_model_name: str):
-        self.documents_df = documents_df.copy()
-        self.

         if not client:
             raise ConnectionError("GenAI client not initialized for RAG embedding.")
         if not text or not isinstance(text, str):

     async def initialize_embeddings(self):
-        if self.documents_df.empty:
-            logging.
             self.embeddings = np.array([])
             return
             logging.error("GenAI client not available for RAG embedding initialization.")
             self.embeddings = np.array([])
             return

         logging.info(f"Starting RAG document embedding for {len(self.documents_df)} documents...")
         embedded_docs_list = []
         for index, row in self.documents_df.iterrows():
-            text_to_embed = row.get('text')
             if not text_to_embed or not isinstance(text_to_embed, str):
-                logging.warning(f"Skipping document at index {index} due to invalid text
                 continue
             try:
                 embedding_array = await asyncio.to_thread(self._embed_single_document_sync, text_to_embed)
             except Exception as e:
-                logging.error(f"Error embedding document

         if not embedded_docs_list:
             self.embeddings = np.array([])
-            logging.warning("No documents were successfully embedded
         else:
             try:
                 self.embeddings = np.vstack(embedded_docs_list)
-                logging.info(f"Successfully embedded {len(embedded_docs_list)} documents
             except ValueError as ve:
-                logging.error(f"Error stacking embeddings: {ve}
                 self.embeddings = np.array([])

     def _calculate_cosine_similarity(self, embeddings_matrix: np.ndarray, query_vector: np.ndarray) -> np.ndarray:
         norm_matrix = np.linalg.norm(embeddings_matrix, axis=1, keepdims=True)

     async def retrieve_relevant_info(self, query: str, top_k: int = 3, min_similarity: float = 0.3) -> str:
-        if
-            logging.debug("RAG system not initialized
             return ""
         if not query or not isinstance(query, str):
             logging.debug("Empty or invalid query for RAG retrieval.")
             return ""
             logging.error("GenAI client not available for RAG query embedding.")
             return ""

         try:
-            query_vector = await asyncio.to_thread(self._embed_single_document_sync, query)
-            if query_vector.ndim == 0 or query_vector.size == 0:
-                logging.warning(f"Query vector embedding failed or is empty for query: {str(query)[:50]}")
-                return ""
-            try:
                 similarity_scores = self._calculate_cosine_similarity(self.embeddings, query_vector)
-                if similarity_scores.size == 0:
-                relevant_indices_after_threshold = np.where(similarity_scores >= min_similarity)[0]
-                if len(relevant_indices_after_threshold) == 0:
-                    logging.debug(f"No documents met the minimum similarity threshold of {min_similarity} for query: {query[:50]}")
                     return ""
             context = "\n\n---\n\n".join(context_parts)
-            logging.debug(f"Retrieved RAG context for query '{
             return context
         except Exception as e:
-            logging.error(f"Error during RAG retrieval
             return ""
 class EmployerBrandingAgent:
     def __init__(self,
-                 all_dataframes:
-                 rag_documents_df: pd.DataFrame,
-                 llm_model_name: str,
-                 embedding_model_name: str,
-                 generation_config_dict:
-                 safety_settings_list: list
-        self.all_dataframes = {k:
         self.llm_model_name = llm_model_name
-        self.generation_config_dict = generation_config_dict
-        # If an empty list is passed, it means no specific safety settings are enforced by the client.
-        self.safety_settings_list = safety_settings_list if safety_settings_list is not None else []
-        self.embedding_model_name = embedding_model_name
-        self.rag_system = AdvancedRAGSystem(rag_documents_df, self.embedding_model_name)
-        self.force_sandbox = force_sandbox
-        logging.info(f"EmployerBrandingAgent initialized. LLM: {self.llm_model_name}, Embedding: {self.embedding_model_name}. Safety settings count: {len(self.safety_settings_list)}")

-    def _get_date_range(self, df: pd.DataFrame) -> str:
-        for col in df.columns:
-            if pd.api.types.is_datetime64_any_dtype(df[col]):
-                try:
-                    min_date = df[col].min()
-                    max_date = df[col].max()
-                    if pd.notna(min_date) and pd.notna(max_date):
-                        return f"{min_date.strftime('%Y-%m-%d')} to {max_date.strftime('%Y-%m-%d')}"
-                except Exception: pass
-        return "N/A"

-    def _calculate_growth_rate(self, df: pd.DataFrame) -> str:
-        logging.debug("_calculate_growth_rate is a placeholder.")
-        return "Growth rate calculation not implemented."
-    def _analyze_engagement_trends(self, df: pd.DataFrame) -> str:
-        logging.debug("_analyze_engagement_trends is a placeholder.")
-        return "Engagement trend analysis not implemented."
-    def _analyze_demographics(self, df: pd.DataFrame) -> str:
-        logging.debug("_analyze_demographics is a placeholder.")
-        return "Demographic analysis not implemented."
-    def _analyze_post_performance(self, df: pd.DataFrame) -> str:
-        logging.debug("_analyze_post_performance is a placeholder.")
-        return "Post performance analysis not implemented."
-    def _extract_content_themes(self, df: pd.DataFrame) -> str:
-        logging.debug("_extract_content_themes is a placeholder.")
-        return "Content theme extraction not implemented."
-    def _find_optimal_times(self, df: pd.DataFrame) -> str:
-        logging.debug("_find_optimal_times is a placeholder.")
-        return "Optimal posting time analysis not implemented."

-    def _calculate_key_metrics(self, df: pd.DataFrame, df_type: str) -> dict:
-        metrics = {}
-        if 'follower' in df_type.lower():
-            metrics.update({'follower_growth_rate': self._calculate_growth_rate(df), 'engagement_trends': self._analyze_engagement_trends(df), 'demographic_distribution': self._analyze_demographics(df)})
-        elif 'post' in df_type.lower():
-            metrics.update({'post_performance': self._analyze_post_performance(df), 'content_themes': self._extract_content_themes(df), 'optimal_posting_times': self._find_optimal_times(df)})
-        elif 'mention' in df_type.lower():
-            metrics['mention_volume_trend'] = "Mention volume trend not implemented."
-            metrics['mention_sentiment_overview'] = "Mention sentiment overview not implemented."
-        if not metrics:
-            logging.debug(f"No specific key metrics defined for df_type: {df_type}")
-            return {"info": "Standard metrics applicable."}
-        return metrics

-    def _calculate_data_freshness(self, df: pd.DataFrame) -> str:
-        for col in df.columns:
-            if pd.api.types.is_datetime64_any_dtype(df[col]):
-                try:
-                    max_date = df[col].max()
-                    if pd.notna(max_date):
-                        days_diff = (datetime.now(max_date.tzinfo if max_date.tzinfo else None) - max_date).days
-                        return f"Data up to {max_date.strftime('%Y-%m-%d')} ({days_diff} days old)"
-                except Exception: pass
-        return "Freshness N/A (no clear date column)"
-    def _check_data_consistency(self, df: pd.DataFrame) -> str:
-        logging.debug("_check_data_consistency is a placeholder.")
-        return "Consistency checks not implemented."
-    def _identify_accuracy_issues(self, df: pd.DataFrame) -> str:
-        logging.debug("_identify_accuracy_issues is a placeholder.")
-        return "Accuracy issue identification not implemented."

-    def _assess_data_quality(self, df: pd.DataFrame) -> dict:
-        completeness = (1 - (df.isnull().sum().sum() / (len(df) * len(df.columns)))) if len(df) > 0 and len(df.columns) > 0 else 0
-        return {'completeness_score': f"{completeness:.2%}", 'freshness_info': self._calculate_data_freshness(df), 'consistency_check': self._check_data_consistency(df), 'accuracy_flags_summary': self._identify_accuracy_issues(df), 'sample_size_notes': f"{len(df)} records. {'Adequate for basic analysis.' if len(df) >= 100 else 'Limited sample size; insights may be indicative.'}"}

-    def _identify_patterns(self, df: pd.DataFrame, key: str) -> str:
-        logging.debug(f"_identify_patterns for {key} is a placeholder.")
-        return "Pattern identification not implemented."

-    def _format_df_analysis(self, df_key: str, analysis: dict) -> str:
-        formatted_parts = [f"\n--- DataFrame: df_{df_key} ---", f"  Shape: {analysis['shape']}", f"  Date Range: {analysis['date_range']}", "  Key Metrics:"]
-        for metric, value in analysis['key_metrics'].items(): formatted_parts.append(f"    - {metric.replace('_', ' ').title()}: {value}")
-        formatted_parts.append("  Data Quality Assessment:")
-        for aspect, value in analysis['data_quality'].items(): formatted_parts.append(f"    - {aspect.replace('_', ' ').title()}: {value}")
-        formatted_parts.append(f"  Notable Patterns: {analysis['notable_patterns']}")
-        return "\n".join(formatted_parts)

-    def _get_enhanced_schemas_representation(self) -> str:
-        schema_descriptions = ["=== DETAILED LINKEDIN DATA OVERVIEW ==="]
-        if not self.all_dataframes:
-            schema_descriptions.append("No dataframes available for analysis.")
-            return "\n".join(schema_descriptions)
-        for key, df in self.all_dataframes.items():
-            if df.empty:
-                schema_descriptions.append(f"\n--- DataFrame: df_{key} ---\nStatus: Empty. No analysis possible.")
-                continue
-            analysis = {'shape': df.shape, 'date_range': self._get_date_range(df), 'key_metrics': self._calculate_key_metrics(df, key), 'data_quality': self._assess_data_quality(df), 'notable_patterns': self._identify_patterns(df, key)}
-            schema_descriptions.append(self._format_df_analysis(key, analysis))
-        return "\n".join(schema_descriptions)

-    def _extract_query_intent(self, query: str) -> str:
-        logging.debug("_extract_query_intent is a placeholder.")
-        if "compare" in query.lower() or "benchmark" in query.lower(): return "comparison"
-        if "trend" in query.lower(): return "trend_analysis"
-        return "general"

-    async def _get_business_context(self, intent: str) -> str:
-        logging.debug("_get_business_context is a placeholder.")
-        if intent == "comparison": return "Company is focused on outperforming competitors in tech hiring."
-        return "Company aims to improve overall employer brand perception."

-    async def _get_industry_benchmarks(self, intent: str) -> str:
-        logging.debug("_get_industry_benchmarks is a placeholder.")
-        if intent == "trend_analysis": return "Typical follower growth in this sector is 5-10% MoM."
-        return "Average engagement rate for similar companies is 2-3%."

-    async def _enhance_rag_context(self, query: str, base_context: str) -> str:
-        intent = self._extract_query_intent(query)
-        business_context_val = await self._get_business_context(intent)
-        benchmarks_val = await self._get_industry_benchmarks(intent)
-        enhanced_context = f"""{base_context}
---- ADDITIONAL CONTEXT FOR YOUR ANALYSIS ---
-Business Focus: {business_context_val}
-Relevant Benchmarks: {benchmarks_val}"""
-        return enhanced_context

-    async def _build_prompt_for_current_turn(self, raw_user_query: str) -> str:
-        prompt_parts = ["You are an expert Employer Branding Analyst...", "--- DETAILED DATA OVERVIEW ---", self.schemas_representation]
-        if self.rag_system.embeddings is not None and self.rag_system.embeddings.size > 0:
-            base_rag_context = await self.rag_system.retrieve_relevant_info(raw_user_query)
-            if base_rag_context:
-                enhanced_rag_context = await self._enhance_rag_context(raw_user_query, base_rag_context)
-                prompt_parts.extend(["--- RELEVANT CONTEXTUAL INFORMATION (from documents & business knowledge) ---", enhanced_rag_context])
-        prompt_parts.extend(["--- USER REQUEST ---", f"Based on all the information above, please respond to the following user query:\n{raw_user_query}"])
-        final_prompt = "\n".join(prompt_parts)
-        logging.debug(f"Built prompt for current turn (first 300 chars): {final_prompt[:300]}")
-        return final_prompt

-    async def _process_structured_query(self, prompt: str) -> dict:
-        logging.debug("_process_structured_query is a placeholder.")
-        return {"Key Findings": ["Placeholder finding 1"], "Performance Metrics": ["Placeholder metric"], "Actionable Recommendations": {"Immediate Actions (0-30 days)": ["Placeholder action"]}, "Risk Assessment": ["Placeholder risk"], "Success Metrics to Track": ["Placeholder KPI"]}

-    async def _generate_hr_insights(self, query: str, context: str) -> str:
-        insight_prompt = f"As an expert HR analytics consultant...\n{context}\nUser Query: {query}\nPlease provide insights in this structured format:\n## Key Findings\n- ...\n..."
-        if not client: return "Error: AI client not configured for generating HR insights."
-        api_call_contents = [{"role": "user", "parts": [{"text": insight_prompt}]}]
-        if
-        for
         try:
-        elif

         try:
         except Exception as e:
-            logging.error(f"Error

     def _validate_query(self, query: str) -> bool:
-        if not query or len(query.strip()) < 3:
         return True

-    def
-        # self.
-                    api_safety_settings_objects.append(types.SafetySetting(category=ss_item['category'], threshold=ss_item['threshold']))
-                except Exception as e_ss_core:
-                    logging.warning(f"Could not create SafetySetting object from {ss_item} in core: {e_ss_core}. Using raw item.")
-                    api_safety_settings_objects.append(ss_item)
-        elif self.safety_settings_list:  # Fallback if types.SafetySetting not available but list is not empty
-            api_safety_settings_objects = self.safety_settings_list

-        api_generation_config_obj = None
-        if types and hasattr(types, 'GenerateContentConfig'):
-            api_generation_config_obj = types.GenerateContentConfig(**self.generation_config_dict, safety_settings=api_safety_settings_objects)
-        else:  # Fallback if types.GenerateContentConfig is not available
-            logging.error("GenerateContentConfig type not available. API call might fail.")
-            api_generation_config_obj = {**self.generation_config_dict, "safety_settings": api_safety_settings_objects}

-        response = await asyncio.to_thread(client.models.generate_content, model=self.llm_model_name, contents=api_call_contents, config=api_generation_config_obj)
-        if not response.candidates:
-            block_reason = response.prompt_feedback.block_reason if response.prompt_feedback else "Unknown"
-            block_message = response.prompt_feedback.block_reason_message if response.prompt_feedback else ""
-            error_message = f"The AI's response was blocked. Reason: {block_reason}." + (f" Details: {block_message}" if block_message else "")
-            return error_message
-        return response.text.strip()

-    async def _process_query_with_timeout(self, raw_user_query_this_turn: str, timeout_seconds: int = 60) -> str:
-        try: return await asyncio.wait_for(self._core_query_processing(raw_user_query_this_turn), timeout=timeout_seconds)
-        except asyncio.TimeoutError:
-            logging.error(f"Query processing timed out for {timeout_seconds} seconds...")
-            return "I'm sorry, but your request took too long..."

-    async def process_query(self, raw_user_query_this_turn: str) -> str:
-        if not client: return "Error: The AI Agent is not available..."
-        if not self._validate_query(raw_user_query_this_turn): return self._get_query_help_message()
-        readiness_check = await self._check_system_readiness()
-        if not readiness_check['ready']: return f"System not ready: {readiness_check['reason']}"
-        max_retries = 2
-        for attempt in range(max_retries + 1):
-            try:
-                response_text = await self._process_query_with_timeout(raw_user_query_this_turn)
-                if "The AI's response was blocked" in response_text: return response_text
-                logging.info(f"Successfully received AI response (attempt {attempt+1}): {response_text[:100]}")
-                return response_text
-            except RateLimitError as rle:
-                if attempt == max_retries: return "The AI service is currently busy..."
-                await asyncio.sleep(2 ** attempt)
-            except ValidationError as ve: return f"Query validation failed: {str(ve)}"
-            except Exception as e:
-                if attempt == max_retries: return self._get_fallback_response(raw_user_query_this_turn)
-        return self._get_fallback_response(raw_user_query_this_turn)

-    def _classify_query_type(self, query: str) -> str:
-        query_lower = query.lower()
-        if any(word in query_lower for word in ['trend', 'growth', 'change', 'time']): return 'trend_analysis'
-        elif any(word in query_lower for word in ['compare', 'benchmark', 'versus']): return 'comparative_analysis'
-        elif any(word in query_lower for word in ['predict', 'forecast', 'future']): return 'predictive_analysis'
-        elif any(word in query_lower for word in ['recommend', 'suggest', 'improve', 'advice', 'help me with']): return 'recommendation_engine'
-        elif any(word in query_lower for word in ['what is', 'explain', 'define']): return 'definition_explanation'
-        else: return 'general_inquiry'

     def clear_chat_history(self):
         self.chat_history = []
-        logging.info("EmployerBrandingAgent chat history cleared

-    def
 import asyncio
 import logging
 import numpy as np
+import textwrap # Not used, but kept from original
+from datetime import datetime # Not used, but kept from original
+from typing import Dict, List, Optional, Union, Any
+import traceback
+
+# Configure logging
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(module)s - %(message)s')

 try:
     from google import genai
+    from google.genai import types # Assuming this provides necessary types like SafetySetting, HarmCategory etc.
+    # If GenerationConfig or EmbedContentConfig are from a different submodule, adjust imports.
+    # For google-generativeai, GenerationConfig is often passed as a dict or genai.types.GenerationConfig
+    # and EmbedContentConfig might be implicit or part of task_type.
+    GENAI_AVAILABLE = True
+    logging.info("Google Generative AI library imported successfully.")
 except ImportError:
+    logging.warning("Google Generative AI library not found. Please install it: pip install google-generativeai")
+    GENAI_AVAILABLE = False
+
+    # Dummy classes for graceful degradation (simplified)
+    class genai:
+        Client = None
+        # If using google-generativeai, these would be different:
+        # GenerativeModel = None
+        # def configure(*args, **kwargs): pass
+        # def embed_content(*args, **kwargs): return {}
+
+    class types: # Placeholder for types used in the original code
+        EmbedContentConfig = None # Placeholder
+        GenerationConfig = None # Placeholder
         SafetySetting = None
+        Candidate = type('Candidate', (), {'FinishReason': type('FinishReason', (), {'STOP': 'STOP'})}) # Dummy for FinishReason
+
+        class HarmCategory:
             HARM_CATEGORY_UNSPECIFIED = "HARM_CATEGORY_UNSPECIFIED"
             HARM_CATEGORY_HATE_SPEECH = "HARM_CATEGORY_HATE_SPEECH"
             HARM_CATEGORY_HARASSMENT = "HARM_CATEGORY_HARASSMENT"
             HARM_CATEGORY_SEXUALLY_EXPLICIT = "HARM_CATEGORY_SEXUALLY_EXPLICIT"
             HARM_CATEGORY_DANGEROUS_CONTENT = "HARM_CATEGORY_DANGEROUS_CONTENT"
+
+        class HarmBlockThreshold:
             BLOCK_NONE = "BLOCK_NONE"
             BLOCK_LOW_AND_ABOVE = "BLOCK_LOW_AND_ABOVE"
             BLOCK_MEDIUM_AND_ABOVE = "BLOCK_MEDIUM_AND_ABOVE"
+            BLOCK_ONLY_HIGH = "BLOCK_ONLY_HIGH"
+
+        class generation_types: # Dummy for BlockedPromptException
+            BlockedPromptException = type('BlockedPromptException', (Exception,), {})

 # --- Custom Exceptions ---
 class ValidationError(Exception):
     """Custom validation error for agent inputs"""
     pass

+class RateLimitError(Exception): # Not used, but kept
     """Placeholder for rate limit errors."""
     pass

+class AgentNotReadyError(Exception):
+    """Agent is not properly initialized"""
+    pass

 # --- Configuration Constants ---
 GEMINI_API_KEY = os.getenv('GEMINI_API_KEY', "")
+LLM_MODEL_NAME = "gemini-1.5-flash-latest" # For google-generativeai, the model name is used directly.
+# For client.models.generate_content, it might need "models/gemini-1.5-flash-latest"
+GEMINI_EMBEDDING_MODEL_NAME = "text-embedding-004" # Similarly, might need "models/text-embedding-004"

 GENERATION_CONFIG_PARAMS = {
     "temperature": 0.7,
     "top_p": 0.95,
     "top_k": 40,
+    "max_output_tokens": 8192, # Ensure this is supported
     "candidate_count": 1,
 }

+DEFAULT_SAFETY_SETTINGS = [] # Users can populate this with {'category': HarmCategory.HARM_CATEGORY_X, 'threshold': HarmBlockThreshold.BLOCK_Y}

+# Default RAG documents
+DEFAULT_RAG_DOCUMENTS = pd.DataFrame({
     'text': [
         "Employer branding focuses on how an organization is perceived as an employer by potential and current employees.",
         "Key metrics for employer branding include employee engagement, candidate quality, and retention rates.",
         "LinkedIn is a crucial platform for showcasing company culture and attracting talent.",
+        "Analyzing follower demographics and post engagement helps refine employer branding strategies.",
+        "Content strategy should align with company values to attract the right talent.",
+        "Employee advocacy programs can significantly boost employer brand reach and authenticity."
     ]
 })

 # --- Client Initialization ---
 client = None
+if GEMINI_API_KEY and GENAI_AVAILABLE:
     try:
+        # This call style is specific to the google-genai client. If using google-generativeai, this would be genai.configure(api_key=...)
         client = genai.Client(api_key=GEMINI_API_KEY)
+        logging.info("Google GenAI client initialized successfully (using genai.Client).")
     except Exception as e:
+        logging.error(f"Failed to initialize Google GenAI client (using genai.Client): {e}")
+        client = None
 else:
+    if not GEMINI_API_KEY:
+        logging.warning("GEMINI_API_KEY environment variable not set.")
+    if not GENAI_AVAILABLE:
+        logging.warning("Google GenAI library not available.")

+# --- Utility function to get DataFrame schema representation ---
+def get_df_schema_representation(df: pd.DataFrame, df_name: str) -> str:
+    """Generates a string representation of a DataFrame's schema and a small sample."""
+    if not isinstance(df, pd.DataFrame):
+        return f"Item '{df_name}' is not a DataFrame.\n"
+    if df.empty:
+        return f"DataFrame '{df_name}': Empty\n"
+
+    schema_parts = [f"DataFrame '{df_name}':"]
+    schema_parts.append(f"  Shape: {df.shape}")
+    schema_parts.append("  Columns:")
+    for col in df.columns:
+        col_type = str(df[col].dtype)
+        null_count = df[col].isnull().sum()
+        unique_count = df[col].nunique()
+        schema_parts.append(f"    - {col} (Type: {col_type}, Nulls: {null_count}/{len(df)}, Uniques: {unique_count})")
+
+    if not df.empty:
+        schema_parts.append("  Sample Data (first 2 rows):")
+        try:
+            sample_df_str = df.head(2).to_string(index=True, max_colwidth=50) # Show index for context
+            indented_sample_df = "\n".join(["    " + line for line in sample_df_str.split('\n')])
+            schema_parts.append(indented_sample_df)
+        except Exception as e:
+            schema_parts.append(f"    Could not generate sample data: {e}")
+
+    return "\n".join(schema_parts) + "\n"

+def get_all_schemas_representation(dataframes: Dict[str, pd.DataFrame]) -> str:
+    """Generates a string representation of all DataFrame schemas."""
+    if not dataframes:
+        return "No DataFrames available to the agent."
+
+    full_representation = ["=== Available DataFrame Schemas for Analysis ==="]
+    for name, df_instance in dataframes.items():
+        full_representation.append(get_df_schema_representation(df_instance, name))
+    return "\n".join(full_representation)
 class AdvancedRAGSystem:
     def __init__(self, documents_df: pd.DataFrame, embedding_model_name: str):
+        self.documents_df = documents_df.copy() if not documents_df.empty else DEFAULT_RAG_DOCUMENTS.copy()
+        # Ensure 'text' column exists
+        if 'text' not in self.documents_df.columns and not self.documents_df.empty:
+            logging.warning("'text' column not found in RAG documents. RAG might not work.")
+            # Create an empty text column if df is not empty but lacks it, to prevent errors later
+            self.documents_df['text'] = ""
+
+        self.embedding_model_name = embedding_model_name # e.g., "models/text-embedding-004" or just "text-embedding-004"
+        self.embeddings: Optional[np.ndarray] = None
+        self.is_initialized = False
+        logging.info(f"AdvancedRAGSystem initialized with {len(self.documents_df)} documents. Model: {self.embedding_model_name}")

+    def _embed_single_document_sync(self, text: str) -> Optional[np.ndarray]:
         if not client:
             raise ConnectionError("GenAI client not initialized for RAG embedding.")
         if not text or not isinstance(text, str):
+            logging.warning("Cannot embed empty or non-string text for RAG.")
+            return None

+        try:
+            # Standard google-generativeai call:
+            # embedding_response = genai.embed_content(
+            #     model=self.embedding_model_name, # e.g., "models/text-embedding-004"
+            #     content=text,
+            #     task_type="RETRIEVAL_DOCUMENT" # or "SEMANTIC_SIMILARITY"
+            # )
+            # return np.array(embedding_response['embedding'])

+            # Using the provided client.models.embed_content structure:
+            # This might require specific types for config.
+            embed_config_payload = None
+            if GENAI_AVAILABLE and hasattr(types, 'EmbedContentConfig'): # Assuming types.EmbedContentConfig is relevant
+                # The task_type for EmbedContentConfig might differ, e.g., "SEMANTIC_SIMILARITY" or "RETRIEVAL_DOCUMENT"
+                embed_config_payload = types.EmbedContentConfig(task_type="RETRIEVAL_DOCUMENT")

+            response = client.models.embed_content( # This is the user's original call structure
+                model=f"models/{self.embedding_model_name}" if not self.embedding_model_name.startswith("models/") else self.embedding_model_name,
+                contents=text, # Original used 'contents', genai.embed_content uses 'content'
+                config=embed_config_payload # Original passed 'config'
+            )

+            # Adapt response parsing based on actual client.models.embed_content behavior
+            if hasattr(response, 'embeddings') and isinstance(response.embeddings, list) and len(response.embeddings) > 0:
+                # This structure `response.embeddings[0]` seems specific.
+                # Standard genai.embed_content returns a dict `{'embedding': [values]}`
+                return np.array(response.embeddings[0])
+            elif hasattr(response, 'embedding'): # Common for genai.embed_content
+                return np.array(response.embedding)
+            else:
+                logging.error(f"Unexpected embedding response format: {response}")
+                return None
+        except Exception as e:
+            logging.error(f"Error in _embed_single_document_sync for text '{text[:50]}...': {e}", exc_info=True)
+            raise

     async def initialize_embeddings(self):
+        if self.documents_df.empty or 'text' not in self.documents_df.columns:
+            logging.warning("RAG documents DataFrame is empty or lacks 'text' column. Skipping embedding.")
             self.embeddings = np.array([])
+            self.is_initialized = True # Initialized, but with no embeddings
             return
+
+        if not client and not (GENAI_AVAILABLE and os.getenv('GEMINI_API_KEY')): # Check if standard genai can be used
             logging.error("GenAI client not available for RAG embedding initialization.")
             self.embeddings = np.array([])
             return

         logging.info(f"Starting RAG document embedding for {len(self.documents_df)} documents...")
         embedded_docs_list = []
+
         for index, row in self.documents_df.iterrows():
+            text_to_embed = row.get('text', '')
             if not text_to_embed or not isinstance(text_to_embed, str):
+                logging.warning(f"Skipping RAG document at index {index} due to invalid/empty text.")
                 continue
+
             try:
+                # Use asyncio.to_thread for the synchronous embedding call
                 embedding_array = await asyncio.to_thread(self._embed_single_document_sync, text_to_embed)
+                if embedding_array is not None and embedding_array.size > 0:
+                    embedded_docs_list.append(embedding_array)
+                else:
+                    logging.warning(f"Empty or failed embedding for RAG document at index {index}.")
             except Exception as e:
+                logging.error(f"Error embedding RAG document at index {index}: {e}")
+                continue # Continue with other documents

         if not embedded_docs_list:
             self.embeddings = np.array([])
+            logging.warning("No RAG documents were successfully embedded.")
         else:
             try:
+                # Ensure all embeddings have the same shape before vstack
+                first_shape = embedded_docs_list[0].shape
+                if not all(emb.shape == first_shape for emb in embedded_docs_list):
+                    logging.error("Inconsistent embedding shapes found. Cannot stack for RAG.")
+                    # Attempt to filter out malformed embeddings if possible, or fail
+                    # For now, we'll fail stacking if shapes are inconsistent.
+                    self.embeddings = np.array([])
+                    return # Exit if shapes are inconsistent
+
                 self.embeddings = np.vstack(embedded_docs_list)
+                logging.info(f"Successfully embedded {len(embedded_docs_list)} RAG documents. Embeddings shape: {self.embeddings.shape}")
             except ValueError as ve:
+                logging.error(f"Error stacking embeddings (likely due to inconsistent shapes): {ve}")
                 self.embeddings = np.array([])
+
+        self.is_initialized = True

     def _calculate_cosine_similarity(self, embeddings_matrix: np.ndarray, query_vector: np.ndarray) -> np.ndarray:
+        if embeddings_matrix.ndim == 1: # Handle case of single document embedding
+            embeddings_matrix = embeddings_matrix.reshape(1, -1)
+        if query_vector.ndim == 1:
+            query_vector = query_vector.reshape(1, -1)
+
+        if embeddings_matrix.size == 0 or query_vector.size == 0:
+            return np.array([])
+
+        # Normalize embeddings_matrix rows
         norm_matrix = np.linalg.norm(embeddings_matrix, axis=1, keepdims=True)
+        # Add a small epsilon to avoid division by zero for zero vectors
+        normalized_embeddings_matrix = np.divide(embeddings_matrix, norm_matrix + 1e-8, where=norm_matrix != 0)
+
+        # Normalize query_vector
+        norm_query = np.linalg.norm(query_vector, axis=1, keepdims=True)
+        normalized_query_vector = np.divide(query_vector, norm_query + 1e-8, where=norm_query != 0)
+
+        # Calculate dot product
+        return np.dot(normalized_embeddings_matrix, normalized_query_vector.T).flatten()

     async def retrieve_relevant_info(self, query: str, top_k: int = 3, min_similarity: float = 0.3) -> str:
+        if not self.is_initialized:
+            logging.debug("RAG system not initialized. Cannot retrieve info.")
+            return ""
+        if self.embeddings is None or self.embeddings.size == 0:
+            logging.debug("RAG embeddings not available. Cannot retrieve info.")
             return ""
         if not query or not isinstance(query, str):
             logging.debug("Empty or invalid query for RAG retrieval.")
             return ""
+
+        if not client and not (GENAI_AVAILABLE and os.getenv('GEMINI_API_KEY')):
             logging.error("GenAI client not available for RAG query embedding.")
             return ""

         try:
+            query_vector = await asyncio.to_thread(self._embed_single_document_sync, query) # Embed query
+            if query_vector is None or query_vector.size == 0:
+                logging.warning("Query vector embedding failed or is empty for RAG.")
+                return ""

             similarity_scores = self._calculate_cosine_similarity(self.embeddings, query_vector)
+            if similarity_scores.size == 0:
                 return ""
+
+            relevant_indices = np.where(similarity_scores >= min_similarity)[0]
+            if len(relevant_indices) == 0:
+                logging.debug(f"No RAG documents met minimum similarity threshold of {min_similarity} for query: '{query[:50]}...'")
+                return ""
+
+            # Get scores for relevant documents and sort
+            relevant_scores = similarity_scores[relevant_indices]
+            # Argsort returns indices to sort relevant_scores; apply to relevant_indices
+            sorted_relevant_indices_of_original = relevant_indices[np.argsort(relevant_scores)[::-1]]
+
+            top_indices = sorted_relevant_indices_of_original[:top_k]
+
+            context_parts = []
+            if 'text' in self.documents_df.columns:
+                for i in top_indices:
+                    if 0 <= i < len(self.documents_df):
+                        context_parts.append(self.documents_df.iloc[i]['text'])

             context = "\n\n---\n\n".join(context_parts)
+            logging.debug(f"Retrieved RAG context with {len(context_parts)} documents for query: '{query[:50]}...'")
             return context
+
         except Exception as e:
+            logging.error(f"Error during RAG retrieval for query '{query[:50]}...': {e}", exc_info=True)
             return ""
 class EmployerBrandingAgent:
     def __init__(self,
+                 all_dataframes: Optional[Dict[str, pd.DataFrame]] = None,
+                 rag_documents_df: Optional[pd.DataFrame] = None,
+                 llm_model_name: str = LLM_MODEL_NAME,
+                 embedding_model_name: str = GEMINI_EMBEDDING_MODEL_NAME,
+                 generation_config_dict: Optional[Dict] = None,
+                 safety_settings_list: Optional[List] = None): # safety_settings_list expects list of dicts or SafetySetting objects
+
+        self.all_dataframes = {k: v.copy() for k, v in (all_dataframes or {}).items()} # Deep copy
+
+        _rag_docs_df = rag_documents_df if rag_documents_df is not None else DEFAULT_RAG_DOCUMENTS.copy()
+        self.rag_system = AdvancedRAGSystem(_rag_docs_df, embedding_model_name)
+
         self.llm_model_name = llm_model_name
+        self.generation_config_dict = generation_config_dict or GENERATION_CONFIG_PARAMS
+        self.embedding_model_name = embedding_model_name # Kept on the agent as well; get_status() reports it

+        # Ensure safety settings are in the correct format if using google-generativeai directly
+        self.safety_settings_list = []
+        if safety_settings_list and GENAI_AVAILABLE and hasattr(types, 'SafetySetting'):
+            for ss_dict in safety_settings_list:
                 try:
+                    # Assuming ss_dict is like {'category': HarmCategory.XYZ, 'threshold': HarmBlockThreshold.ABC}
+                    self.safety_settings_list.append(types.SafetySetting(category=ss_dict['category'], threshold=ss_dict['threshold']))
+                except Exception as e:
+                    logging.warning(f"Could not convert safety setting dict to SafetySetting object: {ss_dict} - {e}")
+        elif safety_settings_list: # If not using types.SafetySetting, pass as is (e.g. for client.models)
+            self.safety_settings_list = safety_settings_list

+        self.chat_history: List[Dict[str, str]] = [] # Stores {"role": "user/model", "content": "..."}
+        self.is_ready = False
+        self.llm_model_instance = None # For google-generativeai

+        if GENAI_AVAILABLE and client is None and GEMINI_API_KEY: # If genai.Client failed but standard genai can be used
+            try:
+                genai.configure(api_key=GEMINI_API_KEY)
+                self.llm_model_instance = genai.GenerativeModel(self.llm_model_name)
+                logging.info(f"Initialized GenerativeModel '{self.llm_model_name}' via google-generativeai.")
+            except Exception as e:
+                logging.error(f"Failed to initialize google-generativeai.GenerativeModel: {e}")

+        logging.info(f"EmployerBrandingAgent initialized. LLM: {self.llm_model_name}. RAG docs: {len(self.rag_system.documents_df)}. DataFrames: {list(self.all_dataframes.keys())}")

+    async def initialize(self) -> bool:
+        """Initializes asynchronous components of the agent, primarily RAG embeddings."""
         try:
+            if not client and not self.llm_model_instance: # Check if any LLM access is configured
+                logging.error("Cannot initialize agent: GenAI client (genai.Client or google.generativeai) not available/configured.")
+                return False
+
+            await self.rag_system.initialize_embeddings() # This sets rag_system.is_initialized
+            self.is_ready = self.rag_system.is_initialized # Agent is ready if RAG is (even if RAG has no docs)
+            logging.info(f"EmployerBrandingAgent.initialize completed. RAG initialized: {self.rag_system.is_initialized}. Agent ready: {self.is_ready}")
+            return True
         except Exception as e:
+            logging.error(f"Error during EmployerBrandingAgent.initialize: {e}", exc_info=True)
+            self.is_ready = False
+            return False

+    def _get_dataframes_summary(self) -> str:
+        return get_all_schemas_representation(self.all_dataframes)

+    def _build_system_prompt(self) -> str:
+        # This prompt provides overall guidance to the LLM.
+        return textwrap.dedent("""
+            You are an expert Employer Branding Analyst AI. Your primary function is to analyze LinkedIn data provided (follower statistics, post performance, mentions) and offer actionable insights, data-driven recommendations, and if requested, Python Pandas code snippets for further analysis.
+
+            When providing insights or recommendations:
+            - Be specific and base your conclusions on the data summaries and context provided.
+            - Structure responses clearly, perhaps using bullet points for key findings or actions.
+            - Focus on practical advice that can help improve employer branding efforts.
+
+            When asked to generate Pandas code:
+            - Assume the data is available in pandas DataFrames named exactly as in the 'Available DataFrame Schemas' section (e.g., `df_follower_stats`, `df_posts`).
+            - Generate executable Python code using pandas.
+            - Ensure the code is directly relevant to the user's query and the available data.
+            - Briefly explain what the code does.
+            - If a query implies data not present in the schemas, state that and do not attempt to fabricate code for it.
+            - Do not generate code that modifies DataFrames in place unless explicitly asked. Prefer returning new DataFrames or Series.
+            - Handle potential errors in data (e.g., missing values if relevant to the operation) gracefully if simple to do so.
+            - Output the code in a single, copy-pasteable block.
+
+            Always refer to the provided DataFrame schemas to understand available columns and data types. Do not hallucinate columns or data.
+            If a query is ambiguous or requires data not present, ask for clarification or state the limitation.
+        """).strip()

+    async def _generate_response(self, current_user_query: str) -> str:
+        """
+        Generates a response from the LLM based on the current query, system prompts,
+        data summaries, RAG context, and the agent's chat history.
+        Assumes self.chat_history is already populated by app.py and includes the current_user_query as the last entry.
+        """
+        if not self.is_ready:
+            return "Agent is not ready. Please initialize."
+        if not client and not self.llm_model_instance:
+            return "Error: AI service is not available. Check API configuration."
+
+        try:
+            system_prompt_text = self._build_system_prompt()
+            data_summary_text = self._get_dataframes_summary()
+            rag_context_text = await self.rag_system.retrieve_relevant_info(current_user_query, top_k=2, min_similarity=0.25) # Fine-tuned RAG params
+
+            # Construct the messages for the LLM API call.
+            # The history (self.chat_history) is set by app.py and includes the current user query.
+            llm_messages = []
+
+            # 1. System-level instructions and context (as a first "user" turn)
+            initial_context_prompt = (
+                f"{system_prompt_text}\n\n"
+                f"## Available Data Overview:\n{data_summary_text}\n\n"
+                f"## Relevant Background Information (if any):\n{rag_context_text if rag_context_text else 'No specific background information retrieved for this query.'}\n\n"
+                f"Given this context, please respond to the user queries that follow in the chat history."
+            )
+            llm_messages.append({"role": "user", "parts": [{"text": initial_context_prompt}]})
+            # 2. Priming assistant message
+            llm_messages.append({"role": "model", "parts": [{"text": "Understood. I have reviewed the context and data overview. I am ready to assist with your Employer Branding analysis based on our conversation."}]})
+
+            # 3. Append the actual conversation history (already includes the current user query)
+            for entry in self.chat_history: # self.chat_history is set by app.py
+                llm_messages.append({"role": entry["role"], "parts": [{"text": entry["content"]}]})
+
+            # Prepare generation config and safety settings for the API
+            gen_config_payload = self.generation_config_dict
+            safety_settings_payload = self.safety_settings_list # Already formatted if types.SafetySetting used
+
+            if GENAI_AVAILABLE and hasattr(types, 'GenerationConfig') and not isinstance(self.generation_config_dict, types.GenerationConfig):
                 try:
+                    gen_config_payload = types.GenerationConfig(**self.generation_config_dict)
+                except Exception as e:
+                    logging.warning(f"Could not convert gen_config_dict to types.GenerationConfig: {e}")

+            # --- Make the API call ---
+            response_text = ""
+            if self.llm_model_instance: # Standard google-generativeai usage
+                logging.debug(f"Using google-generativeai.GenerativeModel.generate_content_async for LLM call. History length: {len(llm_messages)}")
+                api_response = await self.llm_model_instance.generate_content_async(
+                    contents=llm_messages,
+                    generation_config=gen_config_payload,
+                    safety_settings=safety_settings_payload
+                )
+                response_text = api_response.text # Simplification, assumes single-part text response
+            elif client: # User's original client.models.generate_content structure
+                logging.debug(f"Using client.models.generate_content for LLM call. History length: {len(llm_messages)}")
+                # This call needs to be async or wrapped; asyncio.to_thread is used as in the original
+                model_path = f"models/{self.llm_model_name}" if not self.llm_model_name.startswith("models/") else self.llm_model_name
+                api_response = await asyncio.to_thread(
+                    client.models.generate_content,
+                    model=model_path,
+                    contents=llm_messages,
+                    generation_config=gen_config_payload, # Ensure this is the correct type for client.models
+                    safety_settings=safety_settings_payload # Ensure this is the correct type
+                )
+                # Parse response from client.models structure
+                if api_response.candidates and api_response.candidates[0].content and api_response.candidates[0].content.parts:
+                    response_text_parts = [part.text for part in api_response.candidates[0].content.parts if hasattr(part, 'text')]
+                    response_text = "".join(response_text_parts).strip()
+                else: # Handle blocked or empty responses from client.models
+                    if hasattr(api_response, 'prompt_feedback') and api_response.prompt_feedback and api_response.prompt_feedback.block_reason:
+                        logging.warning(f"Prompt blocked by client.models: {api_response.prompt_feedback.block_reason}")
+                        return f"I'm sorry, your request was blocked. Reason: {api_response.prompt_feedback.block_reason_message or api_response.prompt_feedback.block_reason}"
+                    if api_response.candidates and api_response.candidates[0].finish_reason != types.Candidate.FinishReason.STOP: # Assuming types.Candidate.FinishReason.STOP is valid
+                        logging.warning(f"Content generation stopped by client.models due to: {api_response.candidates[0].finish_reason}. Safety: {api_response.candidates[0].safety_ratings if hasattr(api_response.candidates[0], 'safety_ratings') else 'N/A'}")
+                        return f"I couldn't complete the response. Reason: {api_response.candidates[0].finish_reason}. Please try rephrasing."
+                    return "I apologize, but I couldn't generate a response from client.models."
+            else:
+                raise ConnectionError("No valid LLM client or model instance available.")

+            return response_text.strip()

+        except types.generation_types.BlockedPromptException as bpe: # Specific exception for google-generativeai
+            logging.error(f"BlockedPromptException from LLM: {bpe}", exc_info=True)
+            return f"I'm sorry, your request was blocked by the safety filter. Please rephrase your query. Details: {bpe}"
+        except Exception as e:
+            logging.error(f"Error in _generate_response: {e}", exc_info=True)
+            return f"I encountered an error while processing your request: {type(e).__name__} - {str(e)}"

     def _validate_query(self, query: str) -> bool:
+        if not query or not isinstance(query, str) or len(query.strip()) < 3:
+            logging.warning(f"Invalid query: too short or not a string. Query: '{query}'")
+            return False
+        if len(query) > 3000: # Increased limit slightly
+            logging.warning(f"Invalid query: too long. Length: {len(query)}")
+            return False
         return True

+    async def process_query(self, user_query: str) -> str:
+        """
+        Processes the user's query.
+        It relies on self.chat_history being set externally (by app.py) to include the full
+        conversation context, including the current user_query as the last "user" message.
+        This method then calls _generate_response to get the AI's reply.
+        It does NOT modify self.chat_history itself; app.py is responsible for that based on Gradio state.
+        """
+        if not self._validate_query(user_query):
+            # This user_query is the one from the Gradio input, also the last one in self.chat_history
+            return "Please provide a valid query (3 to 3000 characters)."
+
+        if not self.is_ready:
+            logging.warning("process_query called but agent is not ready. Attempting re-initialization.")
+            # This is a fallback. Ideally, initialize is called once and confirmed.
+            init_success = await self.initialize()
+            if not init_success:
+                return "The agent is not properly initialized and could not be started. Please check configuration and logs."
+
+        # user_query is the current text from the input box.
+        # self.chat_history (set by app.py) should already contain this user_query as the last message.
+        # We pass user_query to _generate_response primarily for RAG context retrieval for the current turn.
+        response_text = await self._generate_response(user_query)
+        return response_text

+    def update_dataframes(self, new_dataframes: Dict[str, pd.DataFrame]):
+        """Updates the agent's DataFrames. Does not automatically re-initialize RAG or LLM."""
+        self.all_dataframes = {k: v.copy() for k, v in new_dataframes.items()} # Deep copy
+        logging.info(f"Agent DataFrames updated. Keys: {list(self.all_dataframes.keys())}")
+        # Note: If RAG documents depend on these DataFrames, RAG might need re-initialization.
+        # For now, RAG uses a static document set.

     def clear_chat_history(self):
+        """Clears the agent's internal chat history. App.py should also clear Gradio state."""
         self.chat_history = []
+        logging.info("EmployerBrandingAgent internal chat history cleared.")

+    def get_status(self) -> Dict[str, Any]:
+        return {
+            "is_ready": self.is_ready,
+            "has_api_key": bool(GEMINI_API_KEY),
+            "genai_available": GENAI_AVAILABLE,
+            "client_type": "genai.Client" if client else ("google-generativeai" if self.llm_model_instance else "None"),
+            "rag_initialized": self.rag_system.is_initialized,
+            "num_dataframes": len(self.all_dataframes),
+            "dataframe_keys": list(self.all_dataframes.keys()),
+            "num_rag_documents": len(self.rag_system.documents_df) if self.rag_system.documents_df is not None else 0,
+            "llm_model_name": self.llm_model_name,
+            "embedding_model_name": self.embedding_model_name
+        }

+# --- Functions for Gradio integration (if needed directly, but app.py handles instantiation) ---
+def create_agent_instance(dataframes: Optional[Dict[str, pd.DataFrame]] = None,
+                          rag_docs: Optional[pd.DataFrame] = None) -> EmployerBrandingAgent:
+    logging.info("Creating new EmployerBrandingAgent instance via helper function.")
+    return EmployerBrandingAgent(all_dataframes=dataframes, rag_documents_df=rag_docs)

+async def initialize_agent_async(agent: EmployerBrandingAgent) -> bool:
+    logging.info("Initializing agent via async helper function.")
+    return await agent.initialize()

+if __name__ == "__main__":
+    async def test_agent_logic():
+        print("--- Testing Employer Branding Agent ---")
+        if not GEMINI_API_KEY:
+            print("GEMINI_API_KEY not set. Skipping live API tests.")
+            return
+
+        sample_dfs = {
+            "followers": pd.DataFrame({'date': pd.to_datetime(['2023-01-01']), 'count': [100]}),
+            "posts": pd.DataFrame({'title': ['My first post'], 'likes': [10]})
+        }
+
+        # Test RAG document loading
+        custom_rag = pd.DataFrame({'text': ["Custom RAG context about LinkedIn engagement."]})
+
+        agent = EmployerBrandingAgent(
+            all_dataframes=sample_dfs,
+            rag_documents_df=custom_rag,
+            llm_model_name=LLM_MODEL_NAME,
+            embedding_model_name=GEMINI_EMBEDDING_MODEL_NAME
+        )
+        print("Agent Status (pre-init):", agent.get_status())
+
+        init_success = await agent.initialize()
+        print(f"Agent Initialization Success: {init_success}")
+        print("Agent Status (post-init):", agent.get_status())
+
+        if not init_success:
+            print("Agent initialization failed. Cannot proceed with query test.")
+            return
+
+        # Simulate app.py setting history
+        test_query1 = "What are the key columns in my followers data?"
+        agent.chat_history = [{"role": "user", "content": test_query1}] # app.py would do this
+
+        print(f"\nProcessing Query 1: '{test_query1}'")
+        response1 = await agent.process_query(user_query=test_query1) # Pass current query for RAG etc.
+        print(f"Agent Response 1:\n{response1}")
+
+        # Simulate app.py updating history for next turn
+        agent.chat_history.append({"role": "model", "content": response1})
+
+        test_query2 = "Generate pandas code to get the total follower count."
+        agent.chat_history.append({"role": "user", "content": test_query2})
+
+        print(f"\nProcessing Query 2: '{test_query2}'")
+        response2 = await agent.process_query(user_query=test_query2)
+        print(f"Agent Response 2:\n{response2}")
+
+        agent.chat_history.append({"role": "model", "content": response2})
+        print("\nFinal Agent Chat History (internal):")
+        for item in agent.chat_history:
+            print(f"- {item['role']}: {item['content'][:100]}...")
+
+        print("\n--- Test Complete ---")
+
+    asyncio.run(test_agent_logic())
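
Usage sketch: a minimal example of how a host application, such as the app.py the comments refer to, might drive this module. It follows the history-management convention described in the process_query docstring (the host appends the user turn to chat_history before calling process_query). The dict-form safety setting uses the string fallback constants defined at the top of the file, and the name host_main is illustrative, not part of the commit.

import asyncio
import pandas as pd

from eb_agent_module import EmployerBrandingAgent

async def host_main():
    # Hypothetical dict-form safety setting; the constructor converts entries like
    # this into types.SafetySetting objects when the real google.genai types exist.
    safety = [{'category': "HARM_CATEGORY_HARASSMENT",
               'threshold': "BLOCK_MEDIUM_AND_ABOVE"}]

    dfs = {"posts": pd.DataFrame({"title": ["Launch post"], "likes": [42]})}
    agent = EmployerBrandingAgent(all_dataframes=dfs, safety_settings_list=safety)

    if not await agent.initialize():
        raise SystemExit("Agent failed to initialize; check GEMINI_API_KEY and logs.")

    # The host owns the history: append the user turn first, then pass the same
    # text to process_query (used there for validation and RAG retrieval).
    query = "Which post drove the most engagement?"
    agent.chat_history.append({"role": "user", "content": query})
    reply = await agent.process_query(query)
    agent.chat_history.append({"role": "model", "content": reply})
    print(reply)

asyncio.run(host_main())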