GuglielmoTor committed on
Commit
e7eee4a
·
verified ·
1 Parent(s): cd58acb

Update eb_agent_module.py

Browse files
Files changed (1) hide show
  1. eb_agent_module.py +95 -49
eb_agent_module.py CHANGED
@@ -187,18 +187,33 @@ class AdvancedRAGSystem:
187
  embed_config_payload = None
188
  if GENAI_AVAILABLE and hasattr(types, 'EmbedContentConfig'):
189
  embed_config_payload = types.EmbedContentConfig(task_type="RETRIEVAL_DOCUMENT")
190
-
191
  response = client.models.embed_content(
192
  model=f"models/{self.embedding_model_name}" if not self.embedding_model_name.startswith("models/") else self.embedding_model_name,
193
- contents=text, # Fix: Remove the list wrapper
194
  config=embed_config_payload
195
  )
196
 
197
- # Fix: Update response parsing - use .embeddings directly (it's a list)
198
  if hasattr(response, 'embeddings') and isinstance(response.embeddings, list) and len(response.embeddings) > 0:
199
- # Fix: Access embedding values directly from the list
200
- embedding_values = response.embeddings[0] # This is already the array/list of values
201
- return np.array(embedding_values)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
202
  else:
203
  logging.error(f"Unexpected response structure")
204
  return None
@@ -258,6 +273,10 @@ class AdvancedRAGSystem:
258
  self.is_initialized = True
259
 
260
  def _calculate_cosine_similarity(self, embeddings_matrix: np.ndarray, query_vector: np.ndarray) -> np.ndarray:
 
 
 
 
261
  if embeddings_matrix.ndim == 1:
262
  embeddings_matrix = embeddings_matrix.reshape(1, -1)
263
  if query_vector.ndim == 1:
@@ -268,7 +287,7 @@ class AdvancedRAGSystem:
268
 
269
  norm_matrix = np.linalg.norm(embeddings_matrix, axis=1, keepdims=True)
270
  normalized_embeddings_matrix = np.divide(embeddings_matrix, norm_matrix + 1e-8, where=norm_matrix!=0)
271
-
272
  norm_query = np.linalg.norm(query_vector, axis=1, keepdims=True)
273
  normalized_query_vector = np.divide(query_vector, norm_query + 1e-8, where=norm_query!=0)
274
 
@@ -681,6 +700,7 @@ class EmployerBrandingAgent:
681
  and create a chart showing the total follower growth over time."""
682
 
683
  # Execute the query
 
684
  if len(self.pandas_dfs) == 1:
685
  df = list(self.pandas_dfs.values())[0]
686
  logging.info(f"Using single DataFrame for query with shape: {df.df.shape}")
@@ -689,59 +709,84 @@ class EmployerBrandingAgent:
689
  dfs = list(self.pandas_dfs.values())
690
  pandas_response = pai.chat(processed_query, *dfs)
691
 
692
- # Enhanced response processing
693
  response_text = ""
694
- chart_info = ""
695
-
696
- # Check for chart generation
697
  chart_path = None
698
 
699
- # Method 1: Direct path response
700
- if isinstance(pandas_response, str) and pandas_response.endswith(('.png', '.jpg', '.jpeg', '.svg')):
701
- chart_path = pandas_response
702
- response_text = "Analysis completed with visualization"
703
-
704
- # Method 2: Response object with plot path
705
- elif hasattr(pandas_response, 'plot_path') and pandas_response.plot_path:
706
- chart_path = pandas_response.plot_path
707
- response_text = getattr(pandas_response, 'text', str(pandas_response))
708
-
709
- # Method 3: Check charts directory for new files
710
- else:
711
- if os.path.exists(self.charts_dir):
712
- # Get all chart files sorted by modification time
713
- chart_files = []
714
- for f in os.listdir(self.charts_dir):
715
- if f.endswith(('.png', '.jpg', '.jpeg', '.svg')):
716
- full_path = os.path.join(self.charts_dir, f)
717
- chart_files.append((full_path, os.path.getmtime(full_path)))
718
-
719
- if chart_files:
720
- # Sort by modification time (newest first)
721
- chart_files.sort(key=lambda x: x[1], reverse=True)
722
- latest_chart_path, latest_time = chart_files[0]
723
-
724
- # Check if created in last 60 seconds
725
- import time
726
- if time.time() - latest_time < 60:
727
- chart_path = latest_chart_path
728
- logging.info(f"Found recent chart: {chart_path}")
 
 
 
 
729
 
730
- # Handle text response
731
- if pandas_response and str(pandas_response).strip():
732
- response_text = str(pandas_response).strip()
733
  else:
734
- response_text = "Analysis completed"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
735
 
736
  # Format final response
 
 
 
 
737
  if chart_path and os.path.exists(chart_path):
738
  chart_info = f"\n\n📊 **Chart Generated**: {os.path.basename(chart_path)}\nChart saved at: {chart_path}"
739
  logging.info(f"Chart successfully generated: {chart_path}")
740
 
741
  final_response = response_text + chart_info
742
- success = True
743
-
744
- return final_response, success
745
 
746
  except Exception as e:
747
  logging.error(f"Error in PandasAI processing: {e}", exc_info=True)
@@ -754,10 +799,11 @@ class EmployerBrandingAgent:
754
  return "I encountered a date formatting issue. Please try asking for the data without specific date formatting, or ask me to show the raw data structure first.", False
755
  elif "ambiguous" in error_str:
756
  return "I encountered an ambiguous data type issue. Please try being more specific about which data you'd like to analyze (e.g., 'show monthly follower gains' vs 'show cumulative followers').", False
 
 
757
  else:
758
  return f"Error processing data query: {str(e)}", False
759
 
760
-
761
  async def _generate_enhanced_response(self, query: str, pandas_result: str = "", query_type: str = "general") -> str:
762
  """Generate enhanced response combining PandasAI results with RAG context"""
763
  if not self.is_ready:
 
187
  embed_config_payload = None
188
  if GENAI_AVAILABLE and hasattr(types, 'EmbedContentConfig'):
189
  embed_config_payload = types.EmbedContentConfig(task_type="RETRIEVAL_DOCUMENT")
190
+
191
  response = client.models.embed_content(
192
  model=f"models/{self.embedding_model_name}" if not self.embedding_model_name.startswith("models/") else self.embedding_model_name,
193
+ contents=text,
194
  config=embed_config_payload
195
  )
196
 
197
+ # Fix: Handle ContentEmbedding objects properly
198
  if hasattr(response, 'embeddings') and isinstance(response.embeddings, list) and len(response.embeddings) > 0:
199
+ embedding_obj = response.embeddings[0]
200
+
201
+ # Extract values from ContentEmbedding object
202
+ if hasattr(embedding_obj, 'values'):
203
+ embedding_values = embedding_obj.values
204
+ elif hasattr(embedding_obj, 'embedding'):
205
+ embedding_values = embedding_obj.embedding
206
+ elif isinstance(embedding_obj, (list, tuple)):
207
+ embedding_values = embedding_obj
208
+ else:
209
+ # Try to convert to list/array if it's a different object type
210
+ try:
211
+ embedding_values = list(embedding_obj)
212
+ except:
213
+ logging.error(f"Cannot extract embedding values from object type: {type(embedding_obj)}")
214
+ return None
215
+
216
+ return np.array(embedding_values, dtype=np.float32)
217
  else:
218
  logging.error(f"Unexpected response structure")
219
  return None
 
273
  self.is_initialized = True
274
 
275
  def _calculate_cosine_similarity(self, embeddings_matrix: np.ndarray, query_vector: np.ndarray) -> np.ndarray:
276
+ # Ensure inputs are numpy arrays with proper dtype
277
+ embeddings_matrix = np.asarray(embeddings_matrix, dtype=np.float32)
278
+ query_vector = np.asarray(query_vector, dtype=np.float32)
279
+
280
  if embeddings_matrix.ndim == 1:
281
  embeddings_matrix = embeddings_matrix.reshape(1, -1)
282
  if query_vector.ndim == 1:
 
287
 
288
  norm_matrix = np.linalg.norm(embeddings_matrix, axis=1, keepdims=True)
289
  normalized_embeddings_matrix = np.divide(embeddings_matrix, norm_matrix + 1e-8, where=norm_matrix!=0)
290
+
291
  norm_query = np.linalg.norm(query_vector, axis=1, keepdims=True)
292
  normalized_query_vector = np.divide(query_vector, norm_query + 1e-8, where=norm_query!=0)
293
 
 
700
  and create a chart showing the total follower growth over time."""
701
 
702
  # Execute the query
703
+ pandas_response = None
704
  if len(self.pandas_dfs) == 1:
705
  df = list(self.pandas_dfs.values())[0]
706
  logging.info(f"Using single DataFrame for query with shape: {df.df.shape}")
 
709
  dfs = list(self.pandas_dfs.values())
710
  pandas_response = pai.chat(processed_query, *dfs)
711
 
712
+ # Enhanced response processing with better type handling
713
  response_text = ""
 
 
 
714
  chart_path = None
715
 
716
+ # Handle different response types from PandasAI
717
+ try:
718
+ # Case 1: Direct string response (file path)
719
+ if isinstance(pandas_response, str):
720
+ if pandas_response.endswith(('.png', '.jpg', '.jpeg', '.svg')):
721
+ chart_path = pandas_response
722
+ response_text = "Analysis completed with visualization"
723
+ else:
724
+ response_text = pandas_response
725
+
726
+ # Case 2: Chart object response
727
+ elif hasattr(pandas_response, 'value') and hasattr(pandas_response, '_get_image'):
728
+ # Handle PandasAI Chart response object
729
+ try:
730
+ # Try to get the chart path without calling show() which causes the error
731
+ if hasattr(pandas_response, 'value'):
732
+ if isinstance(pandas_response.value, str) and pandas_response.value.endswith(('.png', '.jpg', '.jpeg', '.svg')):
733
+ chart_path = pandas_response.value
734
+ response_text = "Analysis completed with visualization"
735
+ elif isinstance(pandas_response.value, dict):
736
+ # Handle dict response from Chart object
737
+ if 'path' in pandas_response.value:
738
+ chart_path = pandas_response.value['path']
739
+ response_text = "Analysis completed with visualization"
740
+ else:
741
+ response_text = "Chart generated but path not accessible"
742
+ except Exception as chart_error:
743
+ logging.warning(f"Error handling chart response: {chart_error}")
744
+ response_text = "Chart generated but encountered display issue"
745
+
746
+ # Case 3: Response with plot_path attribute
747
+ elif hasattr(pandas_response, 'plot_path') and pandas_response.plot_path:
748
+ chart_path = pandas_response.plot_path
749
+ response_text = getattr(pandas_response, 'text', "Analysis completed with visualization")
750
 
751
+ # Case 4: Other response types
 
 
752
  else:
753
+ if pandas_response is not None:
754
+ response_text = str(pandas_response).strip()
755
+
756
+ except Exception as response_error:
757
+ logging.warning(f"Error processing PandasAI response: {response_error}")
758
+ response_text = "Analysis completed but encountered response processing issue"
759
+
760
+ # Fallback: Check charts directory for new files if no chart path found
761
+ if not chart_path and os.path.exists(self.charts_dir):
762
+ chart_files = []
763
+ for f in os.listdir(self.charts_dir):
764
+ if f.endswith(('.png', '.jpg', '.jpeg', '.svg')):
765
+ full_path = os.path.join(self.charts_dir, f)
766
+ chart_files.append((full_path, os.path.getmtime(full_path)))
767
+
768
+ if chart_files:
769
+ # Sort by modification time (newest first)
770
+ chart_files.sort(key=lambda x: x[1], reverse=True)
771
+ latest_chart_path, latest_time = chart_files[0]
772
+
773
+ # Check if created in last 60 seconds
774
+ import time
775
+ if time.time() - latest_time < 60:
776
+ chart_path = latest_chart_path
777
+ logging.info(f"Found recent chart: {chart_path}")
778
 
779
  # Format final response
780
+ if not response_text:
781
+ response_text = "Analysis completed"
782
+
783
+ chart_info = ""
784
  if chart_path and os.path.exists(chart_path):
785
  chart_info = f"\n\n📊 **Chart Generated**: {os.path.basename(chart_path)}\nChart saved at: {chart_path}"
786
  logging.info(f"Chart successfully generated: {chart_path}")
787
 
788
  final_response = response_text + chart_info
789
+ return final_response, True
 
 
790
 
791
  except Exception as e:
792
  logging.error(f"Error in PandasAI processing: {e}", exc_info=True)
 
799
  return "I encountered a date formatting issue. Please try asking for the data without specific date formatting, or ask me to show the raw data structure first.", False
800
  elif "ambiguous" in error_str:
801
  return "I encountered an ambiguous data type issue. Please try being more specific about which data you'd like to analyze (e.g., 'show monthly follower gains' vs 'show cumulative followers').", False
802
+ elif "startswith" in error_str or "dict" in error_str:
803
+ return "I encountered a response formatting issue. The analysis may have completed but I couldn't process the result properly. Please try rephrasing your query.", False
804
  else:
805
  return f"Error processing data query: {str(e)}", False
806
 
 
807
  async def _generate_enhanced_response(self, query: str, pandas_result: str = "", query_type: str = "general") -> str:
808
  """Generate enhanced response combining PandasAI results with RAG context"""
809
  if not self.is_ready: