GuglielmoTor committed on
Commit
7be0087
·
verified ·
1 Parent(s): cf1cd44

Update eb_agent_module.py

Browse files
Files changed (1) hide show
  1. eb_agent_module.py +184 -87
eb_agent_module.py CHANGED
@@ -355,41 +355,48 @@ class EmployerBrandingAgent:
355
  logging.info(f"EnhancedEmployerBrandingAgent initialized. LLM: {self.llm_model_name}. RAG docs: {len(self.rag_system.documents_df)}. DataFrames: {list(self.all_dataframes.keys())}")
356
 
357
  def _initialize_pandas_agent(self):
358
- """Initialize PandasAI with enhanced configuration"""
359
  if not self.all_dataframes or not GEMINI_API_KEY:
360
  logging.warning("Cannot initialize PandasAI agent: missing dataframes or API key")
361
  return
362
 
363
  self._preprocess_dataframes_for_pandas_ai()
364
-
365
  try:
366
  # Configure LiteLLM with Gemini
367
  llm = LiteLLM(
368
- model="gemini/gemini-2.5-flash-preview-05-20", # Use gemini/ prefix for Gemini API
369
  api_key=GEMINI_API_KEY
370
  )
371
 
372
- # Set PandasAI configuration
373
  pai.config.set({
374
- "llm": llm,
375
- "temperature": 0.7,
376
- "verbose": True,
377
- "enable_cache": True,
378
- "save_charts": True, # Enable chart saving
379
- "save_charts_path": "./charts", # Directory to save charts
380
- "open_charts": False, # Don't auto-open charts in browser
381
- "custom_whitelisted_dependencies": ["matplotlib", "seaborn", "plotly"] # Allow plotting libraries
 
 
 
 
382
  })
383
 
384
- # Store dataframes for chat queries (we'll use them directly)
385
  self.pandas_dfs = {}
386
  for name, df in self.all_dataframes.items():
387
- # Convert to PandasAI DataFrame with description
 
 
 
388
  df_description = self._generate_dataframe_description(name, df)
389
  pandas_df = pai.DataFrame(df, description=df_description)
390
  self.pandas_dfs[name] = pandas_df
391
 
392
- self.pandas_agent = True # Flag to indicate PandasAI is ready
393
  logging.info(f"PandasAI initialized successfully with {len(self.pandas_dfs)} DataFrames")
394
 
395
  except Exception as e:
@@ -398,7 +405,7 @@ class EmployerBrandingAgent:
398
  self.pandas_dfs = {}
399
 
400
  def _generate_dataframe_description(self, name: str, df: pd.DataFrame) -> str:
401
- """Generate a descriptive summary for PandasAI to better understand the data"""
402
  description_parts = [f"This is the '{name}' dataset containing {len(df)} records."]
403
 
404
  # Add column descriptions based on common patterns
@@ -421,12 +428,21 @@ class EmployerBrandingAgent:
421
  if column_descriptions:
422
  description_parts.append("Key columns: " + "; ".join(column_descriptions))
423
 
424
- # Add specific context for employer branding
425
- # Special handling for follower_stats
426
  if name.lower() in ['follower_stats', 'followers']:
427
- description_parts.append("This data tracks LinkedIn company page follower growth and demographics. For monthly growth data, use the 'extracted_date' column for date-based queries instead of trying to cast 'category_name' as a date.")
428
- if 'extracted_date' in df.columns:
429
- description_parts.append("The 'extracted_date' column contains properly formatted dates (YYYY-MM-DD) extracted from category_name for follower_gains_monthly records.")
 
 
 
 
 
 
 
 
 
 
430
  elif name.lower() in ['posts', 'post_stats']:
431
  description_parts.append("This data contains LinkedIn post performance metrics for employer branding content analysis.")
432
  elif name.lower() in ['mentions', 'brand_mentions']:
@@ -464,7 +480,7 @@ class EmployerBrandingAgent:
464
  return get_all_schemas_representation(self.all_dataframes)
465
 
466
  def _preprocess_dataframes_for_pandas_ai(self):
467
- """Preprocess dataframes to handle date casting issues before PandasAI analysis"""
468
  if not self.all_dataframes:
469
  return
470
 
@@ -489,10 +505,54 @@ class EmployerBrandingAgent:
489
  # Add extracted_date column for cleaner date operations
490
  df_copy['extracted_date'] = df_copy.apply(extract_date_from_category, axis=1)
491
 
492
- # Update the dataframe in our collection
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
493
  self.all_dataframes[name] = df_copy
494
 
495
- logging.info(f"Preprocessed {name} dataframe for date handling")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
496
 
497
  def _build_system_prompt(self) -> str:
498
  """Enhanced system prompt that works with PandasAI integration"""
@@ -582,73 +642,110 @@ class EmployerBrandingAgent:
582
 
583
  # Replace the _generate_pandas_response method and everything after it with this properly indented code:
584
 
585
- async def _generate_pandas_response(self, query: str) -> tuple[str, bool]:
586
- """Generate response using PandasAI for data queries"""
587
- if not self.pandas_agent or not hasattr(self, 'pandas_dfs'):
588
- return "Data analysis not available - PandasAI not initialized.", False
 
 
 
589
 
590
- try:
591
- logging.info(f"Processing data query with PandasAI: {query[:100]}...")
592
-
593
- # Clear any existing matplotlib figures to avoid conflicts
594
- import matplotlib.pyplot as plt
595
- plt.clf()
596
- plt.close('all')
597
-
598
- # Use the first available dataframe for single-df queries
599
- if len(self.pandas_dfs) == 1:
600
- df = list(self.pandas_dfs.values())[0]
601
- logging.info(f"Using single DataFrame for query with shape: {df.df.shape}")
602
- pandas_response = df.chat(query)
603
- else:
604
- # For multiple dataframes, use pai.chat with all dfs
605
- dfs = list(self.pandas_dfs.values())
606
- pandas_response = pai.chat(query, *dfs)
607
-
608
- # Handle different response types
609
- response_text = ""
610
- chart_info = ""
611
-
612
- # Check if response is a plot path or contains plot information
613
- if isinstance(pandas_response, str) and pandas_response.endswith(('.png', '.jpg', '.jpeg', '.svg')):
614
- # Response is a chart path
615
- chart_info = f"\n\n📊 **Chart Generated**: {os.path.basename(pandas_response)}\nChart saved at: {pandas_response}"
616
- response_text = "Analysis completed with visualization"
617
- logging.info(f"Chart generated: {pandas_response}")
618
- elif hasattr(pandas_response, 'plot_path') and pandas_response.plot_path:
619
- # Response object has plot path
620
- chart_info = f"\n\n📊 **Chart Generated**: {os.path.basename(pandas_response.plot_path)}\nChart saved at: {pandas_response.plot_path}"
621
- response_text = getattr(pandas_response, 'text', str(pandas_response))
622
- logging.info(f"Chart generated: {pandas_response.plot_path}")
623
- else:
624
- # Check for any new chart files in the charts directory
625
- if os.path.exists(self.charts_dir):
626
- chart_files = [f for f in os.listdir(self.charts_dir) if f.endswith(('.png', '.jpg', '.jpeg', '.svg'))]
627
- if chart_files:
628
- # Get the most recent chart file
629
- chart_files.sort(key=lambda x: os.path.getmtime(os.path.join(self.charts_dir, x)), reverse=True)
630
- latest_chart = chart_files[0]
631
- chart_path = os.path.join(self.charts_dir, latest_chart)
632
- # Check if this chart was created in the last 30 seconds (likely from this query)
633
- import time
634
- if time.time() - os.path.getmtime(chart_path) < 30:
635
- chart_info = f"\n\n📊 **Chart Generated**: {latest_chart}\nChart saved at: {chart_path}"
636
- logging.info(f"Chart generated: {chart_path}")
 
 
 
 
 
 
 
 
 
637
 
638
- # Handle text response
639
- if pandas_response and str(pandas_response).strip():
640
- response_text = str(pandas_response).strip()
641
- else:
642
- response_text = "Analysis completed"
 
 
 
 
 
643
 
644
- final_response = response_text + chart_info
645
- return final_response, True
 
 
 
646
 
647
- except Exception as e:
648
- logging.error(f"Error in PandasAI processing: {e}", exc_info=True)
649
- # Try to provide a more helpful error message
650
- if "Invalid output" in str(e) and "plot save path" in str(e):
651
- return "I tried to create a visualization but encountered a formatting issue. Please try rephrasing your request or ask for specific data without requesting a chart.", False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
652
  return f"Error processing data query: {str(e)}", False
653
 
654
  async def _generate_enhanced_response(self, query: str, pandas_result: str = "", query_type: str = "general") -> str:
 
355
  logging.info(f"EnhancedEmployerBrandingAgent initialized. LLM: {self.llm_model_name}. RAG docs: {len(self.rag_system.documents_df)}. DataFrames: {list(self.all_dataframes.keys())}")
356
 
357
  def _initialize_pandas_agent(self):
358
+ """Initialize PandasAI with enhanced configuration for chart generation"""
359
  if not self.all_dataframes or not GEMINI_API_KEY:
360
  logging.warning("Cannot initialize PandasAI agent: missing dataframes or API key")
361
  return
362
 
363
  self._preprocess_dataframes_for_pandas_ai()
364
+
365
  try:
366
  # Configure LiteLLM with Gemini
367
  llm = LiteLLM(
368
+ model="gemini/gemini-2.5-flash-preview-05-20",
369
  api_key=GEMINI_API_KEY
370
  )
371
 
372
+ # Enhanced PandasAI configuration for better chart generation
373
  pai.config.set({
374
+ "llm": llm,
375
+ "temperature": 0.3, # Lower temperature for more consistent results
376
+ "verbose": True,
377
+ "enable_cache": False, # Disable cache to avoid stale results
378
+ "save_charts": True,
379
+ "save_charts_path": "./charts",
380
+ "open_charts": False,
381
+ "custom_whitelisted_dependencies": [
382
+ "matplotlib", "seaborn", "plotly", "pandas", "numpy"
383
+ ],
384
+ "max_retries": 3, # Add retry logic
385
+ "use_error_correction_framework": True # Enable error correction
386
  })
387
 
388
+ # Store dataframes for chat queries
389
  self.pandas_dfs = {}
390
  for name, df in self.all_dataframes.items():
391
+ # Skip empty dataframes
392
+ if df.empty:
393
+ continue
394
+
395
  df_description = self._generate_dataframe_description(name, df)
396
  pandas_df = pai.DataFrame(df, description=df_description)
397
  self.pandas_dfs[name] = pandas_df
398
 
399
+ self.pandas_agent = True
400
  logging.info(f"PandasAI initialized successfully with {len(self.pandas_dfs)} DataFrames")
401
 
402
  except Exception as e:
 
405
  self.pandas_dfs = {}
406
 
407
  def _generate_dataframe_description(self, name: str, df: pd.DataFrame) -> str:
408
+ """Enhanced dataframe description for better PandasAI understanding"""
409
  description_parts = [f"This is the '{name}' dataset containing {len(df)} records."]
410
 
411
  # Add column descriptions based on common patterns
 
428
  if column_descriptions:
429
  description_parts.append("Key columns: " + "; ".join(column_descriptions))
430
 
431
+ # Enhanced context for specific datasets
 
432
  if name.lower() in ['follower_stats', 'followers']:
433
+ description_parts.append("""
434
+ This data tracks LinkedIn company page follower growth and demographics.
435
+ For monthly growth analysis, use records where follower_count_type='follower_gains_monthly'.
436
+ The 'extracted_date' column contains properly formatted dates for time series analysis.
437
+ Use 'year_month' or 'month_name' columns for better date display in charts.
438
+ For cumulative analysis, use records where follower_count_type='follower_count_cumulative'.
439
+ """)
440
+ elif name.lower().endswith('_monthly_gains'):
441
+ description_parts.append("""
442
+ This is a filtered dataset containing only monthly follower gains data.
443
+ All records have valid dates and are sorted chronologically.
444
+ Use this for creating time series charts of monthly growth patterns.
445
+ """)
446
  elif name.lower() in ['posts', 'post_stats']:
447
  description_parts.append("This data contains LinkedIn post performance metrics for employer branding content analysis.")
448
  elif name.lower() in ['mentions', 'brand_mentions']:
 
480
  return get_all_schemas_representation(self.all_dataframes)
481
 
482
  def _preprocess_dataframes_for_pandas_ai(self):
483
+ """Enhanced preprocessing to handle date casting issues and ensure chart generation"""
484
  if not self.all_dataframes:
485
  return
486
 
 
505
  # Add extracted_date column for cleaner date operations
506
  df_copy['extracted_date'] = df_copy.apply(extract_date_from_category, axis=1)
507
 
508
+ # Convert extracted_date to proper datetime type and handle nulls
509
+ df_copy['extracted_date'] = pd.to_datetime(df_copy['extracted_date'], errors='coerce')
510
+
511
+ # Create additional helper columns for better analysis
512
+ monthly_mask = df_copy['follower_count_type'] == 'follower_gains_monthly'
513
+ df_copy.loc[monthly_mask, 'date_for_analysis'] = df_copy.loc[monthly_mask, 'extracted_date']
514
+ df_copy.loc[monthly_mask, 'year_month'] = df_copy.loc[monthly_mask, 'extracted_date'].dt.strftime('%Y-%m')
515
+ df_copy.loc[monthly_mask, 'month_name'] = df_copy.loc[monthly_mask, 'extracted_date'].dt.strftime('%B %Y')
516
+
517
+ # Ensure follower_count is numeric and handle nulls
518
+ if 'follower_count' in df_copy.columns:
519
+ df_copy['follower_count'] = pd.to_numeric(df_copy['follower_count'], errors='coerce')
520
+ df_copy['follower_count'] = df_copy['follower_count'].fillna(0)
521
+
522
+ # Create separate monthly gains dataframe for easier analysis
523
+ monthly_gains = df_copy[df_copy['follower_count_type'] == 'follower_gains_monthly'].copy()
524
+ if not monthly_gains.empty:
525
+ monthly_gains = monthly_gains.dropna(subset=['extracted_date'])
526
+ monthly_gains = monthly_gains.sort_values('extracted_date')
527
+ # Store as separate dataframe
528
+ self.all_dataframes[f'{name}_monthly_gains'] = monthly_gains
529
+
530
+ # Update the main dataframe
531
  self.all_dataframes[name] = df_copy
532
 
533
+ logging.info(f"Preprocessed {name} dataframe for date handling. Monthly records: {len(monthly_gains) if not monthly_gains.empty else 0}")
534
+
535
+ # General preprocessing for all dataframes
536
+ df_processed = self.all_dataframes[name].copy()
537
+
538
+ # Handle common data quality issues
539
+ # Convert object columns that should be numeric
540
+ for col in df_processed.columns:
541
+ if df_processed[col].dtype == 'object':
542
+ # Try to convert to numeric if it looks like numbers
543
+ if df_processed[col].astype(str).str.match(r'^\d+\.?\d*$').any():
544
+ df_processed[col] = pd.to_numeric(df_processed[col], errors='coerce')
545
+
546
+ # Fill nulls in numeric columns with 0 (for charting)
547
+ numeric_columns = df_processed.select_dtypes(include=[np.number]).columns
548
+ df_processed[numeric_columns] = df_processed[numeric_columns].fillna(0)
549
+
550
+ # Fill nulls in text columns with empty string
551
+ text_columns = df_processed.select_dtypes(include=['object']).columns
552
+ df_processed[text_columns] = df_processed[text_columns].fillna('')
553
+
554
+ self.all_dataframes[name] = df_processed
555
+
556
 
557
  def _build_system_prompt(self) -> str:
558
  """Enhanced system prompt that works with PandasAI integration"""
 
642
 
643
  # Replace the _generate_pandas_response method and everything after it with this properly indented code:
644
 
645
+ async def _generate_pandas_response(self, query: str) -> tuple[str, bool]:
646
+ """Generate response using PandasAI with enhanced error handling"""
647
+ if not self.pandas_agent or not hasattr(self, 'pandas_dfs'):
648
+ return "Data analysis not available - PandasAI not initialized.", False
649
+
650
+ try:
651
+ logging.info(f"Processing data query with PandasAI: {query[:100]}...")
652
 
653
+ # Clear any existing matplotlib figures
654
+ import matplotlib.pyplot as plt
655
+ plt.clf()
656
+ plt.close('all')
657
+
658
+ # Enhanced query processing based on content
659
+ processed_query = query
660
+
661
+ # Add helpful context for common chart requests
662
+ if any(word in query.lower() for word in ['chart', 'graph', 'plot', 'visualize']):
663
+ if 'monthly' in query.lower() and 'follower' in query.lower():
664
+ processed_query += """.
665
+ Use the monthly gains data (follower_count_type='follower_gains_monthly')
666
+ and use the extracted_date or month_name column for the x-axis.
667
+ Make sure to filter out any null dates and sort by date.
668
+ Create a clear line chart showing the trend over time."""
669
+ elif 'cumulative' in query.lower() and 'follower' in query.lower():
670
+ processed_query += """.
671
+ Use the cumulative data (follower_count_type='follower_count_cumulative')
672
+ and create a chart showing the total follower growth over time."""
673
+
674
+ # Execute the query
675
+ if len(self.pandas_dfs) == 1:
676
+ df = list(self.pandas_dfs.values())[0]
677
+ logging.info(f"Using single DataFrame for query with shape: {df.df.shape}")
678
+ pandas_response = df.chat(processed_query)
679
+ else:
680
+ dfs = list(self.pandas_dfs.values())
681
+ pandas_response = pai.chat(processed_query, *dfs)
682
+
683
+ # Enhanced response processing
684
+ response_text = ""
685
+ chart_info = ""
686
+
687
+ # Check for chart generation
688
+ chart_path = None
689
+
690
+ # Method 1: Direct path response
691
+ if isinstance(pandas_response, str) and pandas_response.endswith(('.png', '.jpg', '.jpeg', '.svg')):
692
+ chart_path = pandas_response
693
+ response_text = "Analysis completed with visualization"
694
+
695
+ # Method 2: Response object with plot path
696
+ elif hasattr(pandas_response, 'plot_path') and pandas_response.plot_path:
697
+ chart_path = pandas_response.plot_path
698
+ response_text = getattr(pandas_response, 'text', str(pandas_response))
699
+
700
+ # Method 3: Check charts directory for new files
701
+ else:
702
+ if os.path.exists(self.charts_dir):
703
+ # Get all chart files sorted by modification time
704
+ chart_files = []
705
+ for f in os.listdir(self.charts_dir):
706
+ if f.endswith(('.png', '.jpg', '.jpeg', '.svg')):
707
+ full_path = os.path.join(self.charts_dir, f)
708
+ chart_files.append((full_path, os.path.getmtime(full_path)))
709
 
710
+ if chart_files:
711
+ # Sort by modification time (newest first)
712
+ chart_files.sort(key=lambda x: x[1], reverse=True)
713
+ latest_chart_path, latest_time = chart_files[0]
714
+
715
+ # Check if created in last 60 seconds
716
+ import time
717
+ if time.time() - latest_time < 60:
718
+ chart_path = latest_chart_path
719
+ logging.info(f"Found recent chart: {chart_path}")
720
 
721
+ # Handle text response
722
+ if pandas_response and str(pandas_response).strip():
723
+ response_text = str(pandas_response).strip()
724
+ else:
725
+ response_text = "Analysis completed"
726
 
727
+ # Format final response
728
+ if chart_path and os.path.exists(chart_path):
729
+ chart_info = f"\n\n📊 **Chart Generated**: {os.path.basename(chart_path)}\nChart saved at: {chart_path}"
730
+ logging.info(f"Chart successfully generated: {chart_path}")
731
+
732
+ final_response = response_text + chart_info
733
+ success = True
734
+
735
+ return final_response, success
736
+
737
+ except Exception as e:
738
+ logging.error(f"Error in PandasAI processing: {e}", exc_info=True)
739
+
740
+ # Enhanced error handling
741
+ error_str = str(e).lower()
742
+ if "matplotlib" in error_str and "none" in error_str:
743
+ return "I encountered a data visualization error. This might be due to missing or null values in your data. Please try asking for the raw data first, or specify which specific columns you'd like to analyze.", False
744
+ elif "strftime" in error_str:
745
+ return "I encountered a date formatting issue. Please try asking for the data without specific date formatting, or ask me to show the raw data structure first.", False
746
+ elif "ambiguous" in error_str:
747
+ return "I encountered an ambiguous data type issue. Please try being more specific about which data you'd like to analyze (e.g., 'show monthly follower gains' vs 'show cumulative followers').", False
748
+ else:
749
  return f"Error processing data query: {str(e)}", False
750
 
751
  async def _generate_enhanced_response(self, query: str, pandas_result: str = "", query_type: str = "general") -> str: