Spaces:
Running
Running
Update eb_agent_module.py
Browse files- eb_agent_module.py +184 -87
eb_agent_module.py
CHANGED
@@ -355,41 +355,48 @@ class EmployerBrandingAgent:
|
|
355 |
logging.info(f"EnhancedEmployerBrandingAgent initialized. LLM: {self.llm_model_name}. RAG docs: {len(self.rag_system.documents_df)}. DataFrames: {list(self.all_dataframes.keys())}")
|
356 |
|
357 |
def _initialize_pandas_agent(self):
|
358 |
-
"""Initialize PandasAI with enhanced configuration"""
|
359 |
if not self.all_dataframes or not GEMINI_API_KEY:
|
360 |
logging.warning("Cannot initialize PandasAI agent: missing dataframes or API key")
|
361 |
return
|
362 |
|
363 |
self._preprocess_dataframes_for_pandas_ai()
|
364 |
-
|
365 |
try:
|
366 |
# Configure LiteLLM with Gemini
|
367 |
llm = LiteLLM(
|
368 |
-
model="gemini/gemini-2.5-flash-preview-05-20",
|
369 |
api_key=GEMINI_API_KEY
|
370 |
)
|
371 |
|
372 |
-
#
|
373 |
pai.config.set({
|
374 |
-
|
375 |
-
|
376 |
-
|
377 |
-
|
378 |
-
|
379 |
-
|
380 |
-
|
381 |
-
|
|
|
|
|
|
|
|
|
382 |
})
|
383 |
|
384 |
-
# Store dataframes for chat queries
|
385 |
self.pandas_dfs = {}
|
386 |
for name, df in self.all_dataframes.items():
|
387 |
-
#
|
|
|
|
|
|
|
388 |
df_description = self._generate_dataframe_description(name, df)
|
389 |
pandas_df = pai.DataFrame(df, description=df_description)
|
390 |
self.pandas_dfs[name] = pandas_df
|
391 |
|
392 |
-
self.pandas_agent = True
|
393 |
logging.info(f"PandasAI initialized successfully with {len(self.pandas_dfs)} DataFrames")
|
394 |
|
395 |
except Exception as e:
|
@@ -398,7 +405,7 @@ class EmployerBrandingAgent:
|
|
398 |
self.pandas_dfs = {}
|
399 |
|
400 |
def _generate_dataframe_description(self, name: str, df: pd.DataFrame) -> str:
|
401 |
-
"""
|
402 |
description_parts = [f"This is the '{name}' dataset containing {len(df)} records."]
|
403 |
|
404 |
# Add column descriptions based on common patterns
|
@@ -421,12 +428,21 @@ class EmployerBrandingAgent:
|
|
421 |
if column_descriptions:
|
422 |
description_parts.append("Key columns: " + "; ".join(column_descriptions))
|
423 |
|
424 |
-
#
|
425 |
-
# Special handling for follower_stats
|
426 |
if name.lower() in ['follower_stats', 'followers']:
|
427 |
-
description_parts.append("
|
428 |
-
|
429 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
430 |
elif name.lower() in ['posts', 'post_stats']:
|
431 |
description_parts.append("This data contains LinkedIn post performance metrics for employer branding content analysis.")
|
432 |
elif name.lower() in ['mentions', 'brand_mentions']:
|
@@ -464,7 +480,7 @@ class EmployerBrandingAgent:
|
|
464 |
return get_all_schemas_representation(self.all_dataframes)
|
465 |
|
466 |
def _preprocess_dataframes_for_pandas_ai(self):
|
467 |
-
"""
|
468 |
if not self.all_dataframes:
|
469 |
return
|
470 |
|
@@ -489,10 +505,54 @@ class EmployerBrandingAgent:
|
|
489 |
# Add extracted_date column for cleaner date operations
|
490 |
df_copy['extracted_date'] = df_copy.apply(extract_date_from_category, axis=1)
|
491 |
|
492 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
493 |
self.all_dataframes[name] = df_copy
|
494 |
|
495 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
496 |
|
497 |
def _build_system_prompt(self) -> str:
|
498 |
"""Enhanced system prompt that works with PandasAI integration"""
|
@@ -582,73 +642,110 @@ class EmployerBrandingAgent:
|
|
582 |
|
583 |
# Replace the _generate_pandas_response method and everything after it with this properly indented code:
|
584 |
|
585 |
-
|
586 |
-
|
587 |
-
|
588 |
-
|
|
|
|
|
|
|
589 |
|
590 |
-
|
591 |
-
|
592 |
-
|
593 |
-
|
594 |
-
|
595 |
-
|
596 |
-
|
597 |
-
|
598 |
-
|
599 |
-
|
600 |
-
|
601 |
-
|
602 |
-
|
603 |
-
|
604 |
-
|
605 |
-
|
606 |
-
|
607 |
-
|
608 |
-
|
609 |
-
|
610 |
-
|
611 |
-
|
612 |
-
|
613 |
-
|
614 |
-
|
615 |
-
|
616 |
-
|
617 |
-
|
618 |
-
|
619 |
-
|
620 |
-
|
621 |
-
|
622 |
-
|
623 |
-
|
624 |
-
|
625 |
-
|
626 |
-
|
627 |
-
|
628 |
-
|
629 |
-
|
630 |
-
|
631 |
-
|
632 |
-
|
633 |
-
|
634 |
-
|
635 |
-
|
636 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
637 |
|
638 |
-
|
639 |
-
|
640 |
-
|
641 |
-
|
642 |
-
|
|
|
|
|
|
|
|
|
|
|
643 |
|
644 |
-
|
645 |
-
|
|
|
|
|
|
|
646 |
|
647 |
-
|
648 |
-
|
649 |
-
|
650 |
-
|
651 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
652 |
return f"Error processing data query: {str(e)}", False
|
653 |
|
654 |
async def _generate_enhanced_response(self, query: str, pandas_result: str = "", query_type: str = "general") -> str:
|
|
|
355 |
logging.info(f"EnhancedEmployerBrandingAgent initialized. LLM: {self.llm_model_name}. RAG docs: {len(self.rag_system.documents_df)}. DataFrames: {list(self.all_dataframes.keys())}")
|
356 |
|
357 |
def _initialize_pandas_agent(self):
|
358 |
+
"""Initialize PandasAI with enhanced configuration for chart generation"""
|
359 |
if not self.all_dataframes or not GEMINI_API_KEY:
|
360 |
logging.warning("Cannot initialize PandasAI agent: missing dataframes or API key")
|
361 |
return
|
362 |
|
363 |
self._preprocess_dataframes_for_pandas_ai()
|
364 |
+
|
365 |
try:
|
366 |
# Configure LiteLLM with Gemini
|
367 |
llm = LiteLLM(
|
368 |
+
model="gemini/gemini-2.5-flash-preview-05-20",
|
369 |
api_key=GEMINI_API_KEY
|
370 |
)
|
371 |
|
372 |
+
# Enhanced PandasAI configuration for better chart generation
|
373 |
pai.config.set({
|
374 |
+
"llm": llm,
|
375 |
+
"temperature": 0.3, # Lower temperature for more consistent results
|
376 |
+
"verbose": True,
|
377 |
+
"enable_cache": False, # Disable cache to avoid stale results
|
378 |
+
"save_charts": True,
|
379 |
+
"save_charts_path": "./charts",
|
380 |
+
"open_charts": False,
|
381 |
+
"custom_whitelisted_dependencies": [
|
382 |
+
"matplotlib", "seaborn", "plotly", "pandas", "numpy"
|
383 |
+
],
|
384 |
+
"max_retries": 3, # Add retry logic
|
385 |
+
"use_error_correction_framework": True # Enable error correction
|
386 |
})
|
387 |
|
388 |
+
# Store dataframes for chat queries
|
389 |
self.pandas_dfs = {}
|
390 |
for name, df in self.all_dataframes.items():
|
391 |
+
# Skip empty dataframes
|
392 |
+
if df.empty:
|
393 |
+
continue
|
394 |
+
|
395 |
df_description = self._generate_dataframe_description(name, df)
|
396 |
pandas_df = pai.DataFrame(df, description=df_description)
|
397 |
self.pandas_dfs[name] = pandas_df
|
398 |
|
399 |
+
self.pandas_agent = True
|
400 |
logging.info(f"PandasAI initialized successfully with {len(self.pandas_dfs)} DataFrames")
|
401 |
|
402 |
except Exception as e:
|
|
|
405 |
self.pandas_dfs = {}
|
406 |
|
407 |
def _generate_dataframe_description(self, name: str, df: pd.DataFrame) -> str:
|
408 |
+
"""Enhanced dataframe description for better PandasAI understanding"""
|
409 |
description_parts = [f"This is the '{name}' dataset containing {len(df)} records."]
|
410 |
|
411 |
# Add column descriptions based on common patterns
|
|
|
428 |
if column_descriptions:
|
429 |
description_parts.append("Key columns: " + "; ".join(column_descriptions))
|
430 |
|
431 |
+
# Enhanced context for specific datasets
|
|
|
432 |
if name.lower() in ['follower_stats', 'followers']:
|
433 |
+
description_parts.append("""
|
434 |
+
This data tracks LinkedIn company page follower growth and demographics.
|
435 |
+
For monthly growth analysis, use records where follower_count_type='follower_gains_monthly'.
|
436 |
+
The 'extracted_date' column contains properly formatted dates for time series analysis.
|
437 |
+
Use 'year_month' or 'month_name' columns for better date display in charts.
|
438 |
+
For cumulative analysis, use records where follower_count_type='follower_count_cumulative'.
|
439 |
+
""")
|
440 |
+
elif name.lower().endswith('_monthly_gains'):
|
441 |
+
description_parts.append("""
|
442 |
+
This is a filtered dataset containing only monthly follower gains data.
|
443 |
+
All records have valid dates and are sorted chronologically.
|
444 |
+
Use this for creating time series charts of monthly growth patterns.
|
445 |
+
""")
|
446 |
elif name.lower() in ['posts', 'post_stats']:
|
447 |
description_parts.append("This data contains LinkedIn post performance metrics for employer branding content analysis.")
|
448 |
elif name.lower() in ['mentions', 'brand_mentions']:
|
|
|
480 |
return get_all_schemas_representation(self.all_dataframes)
|
481 |
|
482 |
def _preprocess_dataframes_for_pandas_ai(self):
|
483 |
+
"""Enhanced preprocessing to handle date casting issues and ensure chart generation"""
|
484 |
if not self.all_dataframes:
|
485 |
return
|
486 |
|
|
|
505 |
# Add extracted_date column for cleaner date operations
|
506 |
df_copy['extracted_date'] = df_copy.apply(extract_date_from_category, axis=1)
|
507 |
|
508 |
+
# Convert extracted_date to proper datetime type and handle nulls
|
509 |
+
df_copy['extracted_date'] = pd.to_datetime(df_copy['extracted_date'], errors='coerce')
|
510 |
+
|
511 |
+
# Create additional helper columns for better analysis
|
512 |
+
monthly_mask = df_copy['follower_count_type'] == 'follower_gains_monthly'
|
513 |
+
df_copy.loc[monthly_mask, 'date_for_analysis'] = df_copy.loc[monthly_mask, 'extracted_date']
|
514 |
+
df_copy.loc[monthly_mask, 'year_month'] = df_copy.loc[monthly_mask, 'extracted_date'].dt.strftime('%Y-%m')
|
515 |
+
df_copy.loc[monthly_mask, 'month_name'] = df_copy.loc[monthly_mask, 'extracted_date'].dt.strftime('%B %Y')
|
516 |
+
|
517 |
+
# Ensure follower_count is numeric and handle nulls
|
518 |
+
if 'follower_count' in df_copy.columns:
|
519 |
+
df_copy['follower_count'] = pd.to_numeric(df_copy['follower_count'], errors='coerce')
|
520 |
+
df_copy['follower_count'] = df_copy['follower_count'].fillna(0)
|
521 |
+
|
522 |
+
# Create separate monthly gains dataframe for easier analysis
|
523 |
+
monthly_gains = df_copy[df_copy['follower_count_type'] == 'follower_gains_monthly'].copy()
|
524 |
+
if not monthly_gains.empty:
|
525 |
+
monthly_gains = monthly_gains.dropna(subset=['extracted_date'])
|
526 |
+
monthly_gains = monthly_gains.sort_values('extracted_date')
|
527 |
+
# Store as separate dataframe
|
528 |
+
self.all_dataframes[f'{name}_monthly_gains'] = monthly_gains
|
529 |
+
|
530 |
+
# Update the main dataframe
|
531 |
self.all_dataframes[name] = df_copy
|
532 |
|
533 |
+
logging.info(f"Preprocessed {name} dataframe for date handling. Monthly records: {len(monthly_gains) if not monthly_gains.empty else 0}")
|
534 |
+
|
535 |
+
# General preprocessing for all dataframes
|
536 |
+
df_processed = self.all_dataframes[name].copy()
|
537 |
+
|
538 |
+
# Handle common data quality issues
|
539 |
+
# Convert object columns that should be numeric
|
540 |
+
for col in df_processed.columns:
|
541 |
+
if df_processed[col].dtype == 'object':
|
542 |
+
# Try to convert to numeric if it looks like numbers
|
543 |
+
if df_processed[col].astype(str).str.match(r'^\d+\.?\d*$').any():
|
544 |
+
df_processed[col] = pd.to_numeric(df_processed[col], errors='coerce')
|
545 |
+
|
546 |
+
# Fill nulls in numeric columns with 0 (for charting)
|
547 |
+
numeric_columns = df_processed.select_dtypes(include=[np.number]).columns
|
548 |
+
df_processed[numeric_columns] = df_processed[numeric_columns].fillna(0)
|
549 |
+
|
550 |
+
# Fill nulls in text columns with empty string
|
551 |
+
text_columns = df_processed.select_dtypes(include=['object']).columns
|
552 |
+
df_processed[text_columns] = df_processed[text_columns].fillna('')
|
553 |
+
|
554 |
+
self.all_dataframes[name] = df_processed
|
555 |
+
|
556 |
|
557 |
def _build_system_prompt(self) -> str:
|
558 |
"""Enhanced system prompt that works with PandasAI integration"""
|
|
|
642 |
|
643 |
# Replace the _generate_pandas_response method and everything after it with this properly indented code:
|
644 |
|
645 |
+
async def _generate_pandas_response(self, query: str) -> tuple[str, bool]:
    """Run a natural-language data query through PandasAI and report any chart produced.

    Args:
        query: The user's natural-language question about the loaded dataframes.

    Returns:
        A ``(response_text, success)`` tuple. ``response_text`` is the PandasAI
        answer, optionally suffixed with the path of a freshly generated chart;
        ``success`` is False when PandasAI is uninitialized or the query failed.

    NOTE(review): reconstructed from a diff-view scrape — original indentation was
    lost; the placement of the "Handle text response" section inside the final
    ``else`` branch is inferred (methods 1 and 2 already assign response_text)
    and should be confirmed against the real file.
    """
    # Guard: _initialize_pandas_agent sets self.pandas_agent / self.pandas_dfs;
    # without them there is nothing to query.
    if not self.pandas_agent or not hasattr(self, 'pandas_dfs'):
        return "Data analysis not available - PandasAI not initialized.", False

    try:
        logging.info(f"Processing data query with PandasAI: {query[:100]}...")

        # Clear any existing matplotlib figures so a stale figure from a prior
        # query cannot be mistaken for this query's output.
        import matplotlib.pyplot as plt
        plt.clf()
        plt.close('all')

        # Enhanced query processing based on content.
        processed_query = query

        # Add helpful context for common chart requests: steer PandasAI toward
        # the correct follower_count_type slice and date columns (these column
        # names are produced by _preprocess_dataframes_for_pandas_ai).
        if any(word in query.lower() for word in ['chart', 'graph', 'plot', 'visualize']):
            if 'monthly' in query.lower() and 'follower' in query.lower():
                processed_query += """.
                Use the monthly gains data (follower_count_type='follower_gains_monthly')
                and use the extracted_date or month_name column for the x-axis.
                Make sure to filter out any null dates and sort by date.
                Create a clear line chart showing the trend over time."""
            elif 'cumulative' in query.lower() and 'follower' in query.lower():
                processed_query += """.
                Use the cumulative data (follower_count_type='follower_count_cumulative')
                and create a chart showing the total follower growth over time."""

        # Execute the query: single-DataFrame queries go through DataFrame.chat,
        # multi-DataFrame queries through the module-level pai.chat.
        if len(self.pandas_dfs) == 1:
            df = list(self.pandas_dfs.values())[0]
            # df is a pai.DataFrame wrapper; .df is presumably the underlying
            # pandas frame — TODO confirm against the installed PandasAI version.
            logging.info(f"Using single DataFrame for query with shape: {df.df.shape}")
            pandas_response = df.chat(processed_query)
        else:
            dfs = list(self.pandas_dfs.values())
            pandas_response = pai.chat(processed_query, *dfs)

        # Enhanced response processing.
        response_text = ""
        chart_info = ""

        # Chart detection: PandasAI signals a generated chart in several ways,
        # so try three strategies in order of reliability.
        chart_path = None

        # Method 1: the response itself is a path string to an image file.
        if isinstance(pandas_response, str) and pandas_response.endswith(('.png', '.jpg', '.jpeg', '.svg')):
            chart_path = pandas_response
            response_text = "Analysis completed with visualization"

        # Method 2: a response object carrying a plot_path attribute.
        elif hasattr(pandas_response, 'plot_path') and pandas_response.plot_path:
            chart_path = pandas_response.plot_path
            response_text = getattr(pandas_response, 'text', str(pandas_response))

        # Method 3: fall back to scanning the charts directory for a file
        # written during this query (heuristic: modified within 60 seconds).
        else:
            if os.path.exists(self.charts_dir):
                # Collect all chart files with their modification times.
                chart_files = []
                for f in os.listdir(self.charts_dir):
                    if f.endswith(('.png', '.jpg', '.jpeg', '.svg')):
                        full_path = os.path.join(self.charts_dir, f)
                        chart_files.append((full_path, os.path.getmtime(full_path)))

                if chart_files:
                    # Sort by modification time (newest first).
                    chart_files.sort(key=lambda x: x[1], reverse=True)
                    latest_chart_path, latest_time = chart_files[0]

                    # Only accept a chart created in the last 60 seconds —
                    # older files belong to previous queries.
                    import time
                    if time.time() - latest_time < 60:
                        chart_path = latest_chart_path
                        logging.info(f"Found recent chart: {chart_path}")

            # Handle text response (methods 1/2 already set response_text).
            if pandas_response and str(pandas_response).strip():
                response_text = str(pandas_response).strip()
            else:
                response_text = "Analysis completed"

        # Format final response: append chart details only if the file really
        # exists on disk (guards against stale/bogus paths from PandasAI).
        if chart_path and os.path.exists(chart_path):
            chart_info = f"\n\n📊 **Chart Generated**: {os.path.basename(chart_path)}\nChart saved at: {chart_path}"
            logging.info(f"Chart successfully generated: {chart_path}")

        final_response = response_text + chart_info
        success = True

        return final_response, success

    except Exception as e:
        logging.error(f"Error in PandasAI processing: {e}", exc_info=True)

        # Enhanced error handling: map known PandasAI failure signatures to
        # actionable user-facing guidance instead of a raw traceback message.
        error_str = str(e).lower()
        if "matplotlib" in error_str and "none" in error_str:
            return "I encountered a data visualization error. This might be due to missing or null values in your data. Please try asking for the raw data first, or specify which specific columns you'd like to analyze.", False
        elif "strftime" in error_str:
            return "I encountered a date formatting issue. Please try asking for the data without specific date formatting, or ask me to show the raw data structure first.", False
        elif "ambiguous" in error_str:
            return "I encountered an ambiguous data type issue. Please try being more specific about which data you'd like to analyze (e.g., 'show monthly follower gains' vs 'show cumulative followers').", False
        else:
            return f"Error processing data query: {str(e)}", False
|
750 |
|
751 |
async def _generate_enhanced_response(self, query: str, pandas_result: str = "", query_type: str = "general") -> str:
|