Spaces:
Running
Running
Update eb_agent_module.py
Browse files- eb_agent_module.py +147 -62
eb_agent_module.py
CHANGED
@@ -177,6 +177,44 @@ class EmployerBrandingAgent:
|
|
177 |
# Initialize PandasAI Agent
|
178 |
self.pandas_agent = None
|
179 |
self._initialize_pandas_agent()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
180 |
|
181 |
def _initialize_pandas_agent(self):
|
182 |
"""Initialize PandasAI with enhanced configuration for chart generation"""
|
@@ -185,6 +223,7 @@ class EmployerBrandingAgent:
|
|
185 |
return
|
186 |
|
187 |
self._preprocess_dataframes_for_pandas_ai()
|
|
|
188 |
|
189 |
try:
|
190 |
# Configure LiteLLM with Gemini
|
@@ -229,48 +268,37 @@ class EmployerBrandingAgent:
|
|
229 |
self.pandas_dfs = {}
|
230 |
|
231 |
def _generate_dataframe_description(self, name: str, df: pd.DataFrame) -> str:
|
232 |
-
"""Enhanced dataframe description
|
233 |
description_parts = [f"This is the '{name}' dataset containing {len(df)} records."]
|
234 |
|
235 |
-
# Add
|
236 |
-
column_descriptions = []
|
237 |
-
for col in df.columns:
|
238 |
-
col_lower = col.lower()
|
239 |
-
if 'date' in col_lower:
|
240 |
-
column_descriptions.append(f"'{col}' contains date/time information")
|
241 |
-
elif 'count' in col_lower or 'number' in col_lower:
|
242 |
-
column_descriptions.append(f"'{col}' contains numerical count data")
|
243 |
-
elif 'rate' in col_lower or 'percentage' in col_lower:
|
244 |
-
column_descriptions.append(f"'{col}' contains rate/percentage metrics")
|
245 |
-
elif 'follower' in col_lower:
|
246 |
-
column_descriptions.append(f"'{col}' contains LinkedIn follower data")
|
247 |
-
elif 'engagement' in col_lower:
|
248 |
-
column_descriptions.append(f"'{col}' contains engagement metrics")
|
249 |
-
elif 'post' in col_lower:
|
250 |
-
column_descriptions.append(f"'{col}' contains post-related information")
|
251 |
-
|
252 |
-
if column_descriptions:
|
253 |
-
description_parts.append("Key columns: " + "; ".join(column_descriptions))
|
254 |
-
|
255 |
-
# Enhanced context for specific datasets
|
256 |
if name.lower() in ['follower_stats', 'followers']:
|
257 |
description_parts.append("""
|
258 |
-
|
259 |
-
|
260 |
-
|
261 |
-
|
262 |
-
For
|
263 |
-
|
264 |
-
|
265 |
-
description_parts.append("""
|
266 |
-
This is a filtered dataset containing only monthly follower gains data.
|
267 |
-
All records have valid dates and are sorted chronologically.
|
268 |
-
Use this for creating time series charts of monthly growth patterns.
|
269 |
""")
|
270 |
-
|
271 |
-
|
272 |
-
|
273 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
274 |
|
275 |
return " ".join(description_parts)
|
276 |
|
@@ -473,7 +501,7 @@ class EmployerBrandingAgent:
|
|
473 |
# Replace the _generate_pandas_response method and everything after it with this properly indented code:
|
474 |
|
475 |
async def _generate_pandas_response(self, query: str) -> tuple[str, bool]:
|
476 |
-
"""Generate response using PandasAI with enhanced error handling"""
|
477 |
if not self.pandas_agent or not hasattr(self, 'pandas_dfs'):
|
478 |
return "Data analysis not available - PandasAI not initialized.", False
|
479 |
|
@@ -485,31 +513,39 @@ class EmployerBrandingAgent:
|
|
485 |
plt.clf()
|
486 |
plt.close('all')
|
487 |
|
488 |
-
# Enhanced query
|
489 |
-
processed_query = query
|
490 |
-
|
491 |
-
|
492 |
-
|
493 |
-
if 'monthly' in query.lower() and 'follower' in query.lower():
|
494 |
-
processed_query += """.
|
495 |
-
Use the monthly gains data (follower_count_type='follower_gains_monthly')
|
496 |
-
and use the extracted_date or month_name column for the x-axis.
|
497 |
-
Make sure to filter out any null dates and sort by date.
|
498 |
-
Create a clear line chart showing the trend over time."""
|
499 |
-
elif 'cumulative' in query.lower() and 'follower' in query.lower():
|
500 |
-
processed_query += """.
|
501 |
-
Use the cumulative data (follower_count_type='follower_count_cumulative')
|
502 |
-
and create a chart showing the total follower growth over time."""
|
503 |
-
|
504 |
-
# Execute the query
|
505 |
pandas_response = None
|
506 |
-
|
507 |
-
|
508 |
-
|
509 |
-
|
510 |
-
|
511 |
-
|
512 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
513 |
|
514 |
# Enhanced response processing with better type handling
|
515 |
response_text = ""
|
@@ -606,6 +642,55 @@ class EmployerBrandingAgent:
|
|
606 |
else:
|
607 |
return f"Error processing data query: {str(e)}", False
|
608 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
609 |
async def _generate_enhanced_response(self, query: str, pandas_result: str = "", query_type: str = "general") -> str:
|
610 |
"""Generate enhanced response combining PandasAI results with RAG context"""
|
611 |
if not self.is_ready:
|
|
|
177 |
# Initialize PandasAI Agent
|
178 |
self.pandas_agent = None
|
179 |
self._initialize_pandas_agent()
|
180 |
+
|
181 |
+
def _validate_and_log_data(self):
|
182 |
+
"""Validate data quality and log findings"""
|
183 |
+
logging.info("=== DATA VALIDATION REPORT ===")
|
184 |
+
|
185 |
+
for name, df in self.all_dataframes.items():
|
186 |
+
logging.info(f"\nDataFrame: {name}")
|
187 |
+
logging.info(f"Shape: {df.shape}")
|
188 |
+
logging.info(f"Columns: {list(df.columns)}")
|
189 |
+
|
190 |
+
# Check for date columns and their ranges
|
191 |
+
date_cols = [col for col in df.columns if 'date' in col.lower()]
|
192 |
+
for date_col in date_cols:
|
193 |
+
if not df[date_col].empty:
|
194 |
+
try:
|
195 |
+
date_series = pd.to_datetime(df[date_col], errors='coerce')
|
196 |
+
valid_dates = date_series.dropna()
|
197 |
+
if not valid_dates.empty:
|
198 |
+
min_date = valid_dates.min()
|
199 |
+
max_date = valid_dates.max()
|
200 |
+
logging.info(f" {date_col}: {min_date} to {max_date}")
|
201 |
+
|
202 |
+
# Specifically check for 2025 data
|
203 |
+
dates_2025 = valid_dates[valid_dates.dt.year == 2025]
|
204 |
+
if not dates_2025.empty:
|
205 |
+
logging.info(f" Found {len(dates_2025)} records in 2025")
|
206 |
+
except Exception as e:
|
207 |
+
logging.warning(f" Could not parse dates in {date_col}: {e}")
|
208 |
+
|
209 |
+
# Check follower data specifically
|
210 |
+
if 'follower' in name.lower():
|
211 |
+
if 'follower_count_type' in df.columns:
|
212 |
+
type_counts = df['follower_count_type'].value_counts()
|
213 |
+
logging.info(f" Follower count types: {dict(type_counts)}")
|
214 |
+
|
215 |
+
if 'follower_count' in df.columns:
|
216 |
+
follower_stats = df['follower_count'].describe()
|
217 |
+
logging.info(f" Follower count stats: {follower_stats}")
|
218 |
|
219 |
def _initialize_pandas_agent(self):
|
220 |
"""Initialize PandasAI with enhanced configuration for chart generation"""
|
|
|
223 |
return
|
224 |
|
225 |
self._preprocess_dataframes_for_pandas_ai()
|
226 |
+
self._validate_and_log_data()
|
227 |
|
228 |
try:
|
229 |
# Configure LiteLLM with Gemini
|
|
|
268 |
self.pandas_dfs = {}
|
269 |
|
270 |
def _generate_dataframe_description(self, name: str, df: pd.DataFrame) -> str:
|
271 |
+
"""Enhanced dataframe description with better data context"""
|
272 |
description_parts = [f"This is the '{name}' dataset containing {len(df)} records."]
|
273 |
|
274 |
+
# Add specific context for follower data
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
275 |
if name.lower() in ['follower_stats', 'followers']:
|
276 |
description_parts.append("""
|
277 |
+
CRITICAL DATA STRUCTURE INFO:
|
278 |
+
- Records with follower_count_type='follower_gains_monthly' contain monthly new follower counts
|
279 |
+
- Records with follower_count_type='follower_count_cumulative' contain total follower counts
|
280 |
+
- The 'extracted_date' column contains properly parsed dates for time analysis
|
281 |
+
- For monthly gains analysis, ALWAYS filter by follower_count_type='follower_gains_monthly'
|
282 |
+
- For growth trends, use extracted_date for chronological ordering
|
283 |
+
- The follower_count column contains the actual numeric values to analyze
|
|
|
|
|
|
|
|
|
284 |
""")
|
285 |
+
|
286 |
+
# Add date range info if available
|
287 |
+
if 'extracted_date' in df.columns:
|
288 |
+
try:
|
289 |
+
date_col = pd.to_datetime(df['extracted_date'], errors='coerce')
|
290 |
+
valid_dates = date_col.dropna()
|
291 |
+
if not valid_dates.empty:
|
292 |
+
min_date = valid_dates.min()
|
293 |
+
max_date = valid_dates.max()
|
294 |
+
description_parts.append(f"Date range: {min_date.strftime('%Y-%m-%d')} to {max_date.strftime('%Y-%m-%d')}")
|
295 |
+
|
296 |
+
# Highlight 2025 data
|
297 |
+
dates_2025 = valid_dates[valid_dates.dt.year >= 2025]
|
298 |
+
if not dates_2025.empty:
|
299 |
+
description_parts.append(f"Contains {len(dates_2025)} records from 2025 onwards")
|
300 |
+
except Exception as e:
|
301 |
+
logging.warning(f"Could not analyze date range: {e}")
|
302 |
|
303 |
return " ".join(description_parts)
|
304 |
|
|
|
501 |
# Replace the _generate_pandas_response method and everything after it with this properly indented code:
|
502 |
|
503 |
async def _generate_pandas_response(self, query: str) -> tuple[str, bool]:
|
504 |
+
"""Generate response using PandasAI with enhanced error handling and data validation"""
|
505 |
if not self.pandas_agent or not hasattr(self, 'pandas_dfs'):
|
506 |
return "Data analysis not available - PandasAI not initialized.", False
|
507 |
|
|
|
513 |
plt.clf()
|
514 |
plt.close('all')
|
515 |
|
516 |
+
# Enhanced query preprocessing
|
517 |
+
processed_query = self._enhance_query_for_pandas(query)
|
518 |
+
logging.info(f"Enhanced query: {processed_query[:200]}...")
|
519 |
+
|
520 |
+
# Execute the query with better error handling
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
521 |
pandas_response = None
|
522 |
+
try:
|
523 |
+
if len(self.pandas_dfs) == 1:
|
524 |
+
df = list(self.pandas_dfs.values())[0]
|
525 |
+
logging.info(f"Using single DataFrame for query with shape: {df.df.shape}")
|
526 |
+
pandas_response = df.chat(processed_query)
|
527 |
+
else:
|
528 |
+
dfs = list(self.pandas_dfs.values())
|
529 |
+
pandas_response = pai.chat(processed_query, *dfs)
|
530 |
+
except Exception as pandas_error:
|
531 |
+
logging.error(f"PandasAI execution error: {pandas_error}")
|
532 |
+
|
533 |
+
# Try a simpler version of the query
|
534 |
+
simple_query = self._simplify_query_for_retry(query)
|
535 |
+
if simple_query != query:
|
536 |
+
logging.info(f"Retrying with simplified query: {simple_query}")
|
537 |
+
try:
|
538 |
+
if len(self.pandas_dfs) == 1:
|
539 |
+
df = list(self.pandas_dfs.values())[0]
|
540 |
+
pandas_response = df.chat(simple_query)
|
541 |
+
else:
|
542 |
+
dfs = list(self.pandas_dfs.values())
|
543 |
+
pandas_response = pai.chat(simple_query, *dfs)
|
544 |
+
except Exception as retry_error:
|
545 |
+
logging.error(f"Retry also failed: {retry_error}")
|
546 |
+
return f"Data analysis failed: {str(pandas_error)}", False
|
547 |
+
else:
|
548 |
+
return f"Data analysis failed: {str(pandas_error)}", False
|
549 |
|
550 |
# Enhanced response processing with better type handling
|
551 |
response_text = ""
|
|
|
642 |
else:
|
643 |
return f"Error processing data query: {str(e)}", False
|
644 |
|
645 |
+
|
646 |
+
def _enhance_query_for_pandas(self, query: str) -> str:
|
647 |
+
"""Enhance query with specific data context and instructions"""
|
648 |
+
enhanced_parts = [query]
|
649 |
+
|
650 |
+
# Add specific instructions for follower queries
|
651 |
+
if 'follower' in query.lower() and ('gain' in query.lower() or 'growth' in query.lower()):
|
652 |
+
enhanced_parts.append("""
|
653 |
+
IMPORTANT INSTRUCTIONS:
|
654 |
+
- Use only data where follower_count_type='follower_gains_monthly' for monthly gains analysis
|
655 |
+
- Filter out any rows where extracted_date is null or NaT
|
656 |
+
- Sort results by extracted_date in ascending order
|
657 |
+
- For 2025 data, make sure to include all months from January 2025 onwards
|
658 |
+
- Use extracted_date for time series and month_name for better chart labels
|
659 |
+
- Sum the follower_count values to get total gains
|
660 |
+
""")
|
661 |
+
|
662 |
+
if 'plot' in query.lower() or 'chart' in query.lower():
|
663 |
+
enhanced_parts.append("""
|
664 |
+
CHART REQUIREMENTS:
|
665 |
+
- Create a clear, well-labeled chart
|
666 |
+
- Use appropriate chart type (line chart for time series, bar chart for comparisons)
|
667 |
+
- Include proper axis labels and title
|
668 |
+
- Format dates nicely on x-axis if applicable
|
669 |
+
- Save the chart and return the path
|
670 |
+
""")
|
671 |
+
|
672 |
+
if '2025' in query:
|
673 |
+
enhanced_parts.append("- Focus specifically on data from 2025 onwards")
|
674 |
+
|
675 |
+
return " ".join(enhanced_parts)
|
676 |
+
|
677 |
+
def _simplify_query_for_retry(self, query: str) -> str:
|
678 |
+
"""Create a simpler version of the query for retry attempts"""
|
679 |
+
# Remove complex requirements and focus on core request
|
680 |
+
simple_patterns = {
|
681 |
+
r'plot.*followers.*per.*month': 'show follower gains by month',
|
682 |
+
r'how many.*followers.*gain.*since.*2025': 'sum follower gains from 2025',
|
683 |
+
r'chart.*growth': 'show follower growth over time',
|
684 |
+
}
|
685 |
+
|
686 |
+
query_lower = query.lower()
|
687 |
+
for pattern, replacement in simple_patterns.items():
|
688 |
+
import re
|
689 |
+
if re.search(pattern, query_lower):
|
690 |
+
return replacement
|
691 |
+
|
692 |
+
return query
|
693 |
+
|
694 |
async def _generate_enhanced_response(self, query: str, pandas_result: str = "", query_type: str = "general") -> str:
|
695 |
"""Generate enhanced response combining PandasAI results with RAG context"""
|
696 |
if not self.is_ready:
|