Spaces:
Running
Running
Update eb_agent_module.py
Browse files- eb_agent_module.py +95 -49
eb_agent_module.py
CHANGED
@@ -187,18 +187,33 @@ class AdvancedRAGSystem:
|
|
187 |
embed_config_payload = None
|
188 |
if GENAI_AVAILABLE and hasattr(types, 'EmbedContentConfig'):
|
189 |
embed_config_payload = types.EmbedContentConfig(task_type="RETRIEVAL_DOCUMENT")
|
190 |
-
|
191 |
response = client.models.embed_content(
|
192 |
model=f"models/{self.embedding_model_name}" if not self.embedding_model_name.startswith("models/") else self.embedding_model_name,
|
193 |
-
contents=text,
|
194 |
config=embed_config_payload
|
195 |
)
|
196 |
|
197 |
-
# Fix:
|
198 |
if hasattr(response, 'embeddings') and isinstance(response.embeddings, list) and len(response.embeddings) > 0:
|
199 |
-
|
200 |
-
|
201 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
202 |
else:
|
203 |
logging.error(f"Unexpected response structure")
|
204 |
return None
|
@@ -258,6 +273,10 @@ class AdvancedRAGSystem:
|
|
258 |
self.is_initialized = True
|
259 |
|
260 |
def _calculate_cosine_similarity(self, embeddings_matrix: np.ndarray, query_vector: np.ndarray) -> np.ndarray:
|
|
|
|
|
|
|
|
|
261 |
if embeddings_matrix.ndim == 1:
|
262 |
embeddings_matrix = embeddings_matrix.reshape(1, -1)
|
263 |
if query_vector.ndim == 1:
|
@@ -268,7 +287,7 @@ class AdvancedRAGSystem:
|
|
268 |
|
269 |
norm_matrix = np.linalg.norm(embeddings_matrix, axis=1, keepdims=True)
|
270 |
normalized_embeddings_matrix = np.divide(embeddings_matrix, norm_matrix + 1e-8, where=norm_matrix!=0)
|
271 |
-
|
272 |
norm_query = np.linalg.norm(query_vector, axis=1, keepdims=True)
|
273 |
normalized_query_vector = np.divide(query_vector, norm_query + 1e-8, where=norm_query!=0)
|
274 |
|
@@ -681,6 +700,7 @@ class EmployerBrandingAgent:
|
|
681 |
and create a chart showing the total follower growth over time."""
|
682 |
|
683 |
# Execute the query
|
|
|
684 |
if len(self.pandas_dfs) == 1:
|
685 |
df = list(self.pandas_dfs.values())[0]
|
686 |
logging.info(f"Using single DataFrame for query with shape: {df.df.shape}")
|
@@ -689,59 +709,84 @@ class EmployerBrandingAgent:
|
|
689 |
dfs = list(self.pandas_dfs.values())
|
690 |
pandas_response = pai.chat(processed_query, *dfs)
|
691 |
|
692 |
-
# Enhanced response processing
|
693 |
response_text = ""
|
694 |
-
chart_info = ""
|
695 |
-
|
696 |
-
# Check for chart generation
|
697 |
chart_path = None
|
698 |
|
699 |
-
#
|
700 |
-
|
701 |
-
|
702 |
-
|
703 |
-
|
704 |
-
|
705 |
-
|
706 |
-
|
707 |
-
|
708 |
-
|
709 |
-
|
710 |
-
|
711 |
-
|
712 |
-
|
713 |
-
|
714 |
-
|
715 |
-
|
716 |
-
|
717 |
-
|
718 |
-
|
719 |
-
|
720 |
-
|
721 |
-
|
722 |
-
|
723 |
-
|
724 |
-
|
725 |
-
|
726 |
-
|
727 |
-
|
728 |
-
|
|
|
|
|
|
|
|
|
729 |
|
730 |
-
#
|
731 |
-
if pandas_response and str(pandas_response).strip():
|
732 |
-
response_text = str(pandas_response).strip()
|
733 |
else:
|
734 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
735 |
|
736 |
# Format final response
|
|
|
|
|
|
|
|
|
737 |
if chart_path and os.path.exists(chart_path):
|
738 |
chart_info = f"\n\n📊 **Chart Generated**: {os.path.basename(chart_path)}\nChart saved at: {chart_path}"
|
739 |
logging.info(f"Chart successfully generated: {chart_path}")
|
740 |
|
741 |
final_response = response_text + chart_info
|
742 |
-
|
743 |
-
|
744 |
-
return final_response, success
|
745 |
|
746 |
except Exception as e:
|
747 |
logging.error(f"Error in PandasAI processing: {e}", exc_info=True)
|
@@ -754,10 +799,11 @@ class EmployerBrandingAgent:
|
|
754 |
return "I encountered a date formatting issue. Please try asking for the data without specific date formatting, or ask me to show the raw data structure first.", False
|
755 |
elif "ambiguous" in error_str:
|
756 |
return "I encountered an ambiguous data type issue. Please try being more specific about which data you'd like to analyze (e.g., 'show monthly follower gains' vs 'show cumulative followers').", False
|
|
|
|
|
757 |
else:
|
758 |
return f"Error processing data query: {str(e)}", False
|
759 |
|
760 |
-
|
761 |
async def _generate_enhanced_response(self, query: str, pandas_result: str = "", query_type: str = "general") -> str:
|
762 |
"""Generate enhanced response combining PandasAI results with RAG context"""
|
763 |
if not self.is_ready:
|
|
|
187 |
embed_config_payload = None
|
188 |
if GENAI_AVAILABLE and hasattr(types, 'EmbedContentConfig'):
|
189 |
embed_config_payload = types.EmbedContentConfig(task_type="RETRIEVAL_DOCUMENT")
|
190 |
+
|
191 |
response = client.models.embed_content(
|
192 |
model=f"models/{self.embedding_model_name}" if not self.embedding_model_name.startswith("models/") else self.embedding_model_name,
|
193 |
+
contents=text,
|
194 |
config=embed_config_payload
|
195 |
)
|
196 |
|
197 |
+
# Fix: Handle ContentEmbedding objects properly
|
198 |
if hasattr(response, 'embeddings') and isinstance(response.embeddings, list) and len(response.embeddings) > 0:
|
199 |
+
embedding_obj = response.embeddings[0]
|
200 |
+
|
201 |
+
# Extract values from ContentEmbedding object
|
202 |
+
if hasattr(embedding_obj, 'values'):
|
203 |
+
embedding_values = embedding_obj.values
|
204 |
+
elif hasattr(embedding_obj, 'embedding'):
|
205 |
+
embedding_values = embedding_obj.embedding
|
206 |
+
elif isinstance(embedding_obj, (list, tuple)):
|
207 |
+
embedding_values = embedding_obj
|
208 |
+
else:
|
209 |
+
# Try to convert to list/array if it's a different object type
|
210 |
+
try:
|
211 |
+
embedding_values = list(embedding_obj)
|
212 |
+
except:
|
213 |
+
logging.error(f"Cannot extract embedding values from object type: {type(embedding_obj)}")
|
214 |
+
return None
|
215 |
+
|
216 |
+
return np.array(embedding_values, dtype=np.float32)
|
217 |
else:
|
218 |
logging.error(f"Unexpected response structure")
|
219 |
return None
|
|
|
273 |
self.is_initialized = True
|
274 |
|
275 |
def _calculate_cosine_similarity(self, embeddings_matrix: np.ndarray, query_vector: np.ndarray) -> np.ndarray:
|
276 |
+
# Ensure inputs are numpy arrays with proper dtype
|
277 |
+
embeddings_matrix = np.asarray(embeddings_matrix, dtype=np.float32)
|
278 |
+
query_vector = np.asarray(query_vector, dtype=np.float32)
|
279 |
+
|
280 |
if embeddings_matrix.ndim == 1:
|
281 |
embeddings_matrix = embeddings_matrix.reshape(1, -1)
|
282 |
if query_vector.ndim == 1:
|
|
|
287 |
|
288 |
norm_matrix = np.linalg.norm(embeddings_matrix, axis=1, keepdims=True)
|
289 |
normalized_embeddings_matrix = np.divide(embeddings_matrix, norm_matrix + 1e-8, where=norm_matrix!=0)
|
290 |
+
|
291 |
norm_query = np.linalg.norm(query_vector, axis=1, keepdims=True)
|
292 |
normalized_query_vector = np.divide(query_vector, norm_query + 1e-8, where=norm_query!=0)
|
293 |
|
|
|
700 |
and create a chart showing the total follower growth over time."""
|
701 |
|
702 |
# Execute the query
|
703 |
+
pandas_response = None
|
704 |
if len(self.pandas_dfs) == 1:
|
705 |
df = list(self.pandas_dfs.values())[0]
|
706 |
logging.info(f"Using single DataFrame for query with shape: {df.df.shape}")
|
|
|
709 |
dfs = list(self.pandas_dfs.values())
|
710 |
pandas_response = pai.chat(processed_query, *dfs)
|
711 |
|
712 |
+
# Enhanced response processing with better type handling
|
713 |
response_text = ""
|
|
|
|
|
|
|
714 |
chart_path = None
|
715 |
|
716 |
+
# Handle different response types from PandasAI
|
717 |
+
try:
|
718 |
+
# Case 1: Direct string response (file path)
|
719 |
+
if isinstance(pandas_response, str):
|
720 |
+
if pandas_response.endswith(('.png', '.jpg', '.jpeg', '.svg')):
|
721 |
+
chart_path = pandas_response
|
722 |
+
response_text = "Analysis completed with visualization"
|
723 |
+
else:
|
724 |
+
response_text = pandas_response
|
725 |
+
|
726 |
+
# Case 2: Chart object response
|
727 |
+
elif hasattr(pandas_response, 'value') and hasattr(pandas_response, '_get_image'):
|
728 |
+
# Handle PandasAI Chart response object
|
729 |
+
try:
|
730 |
+
# Try to get the chart path without calling show() which causes the error
|
731 |
+
if hasattr(pandas_response, 'value'):
|
732 |
+
if isinstance(pandas_response.value, str) and pandas_response.value.endswith(('.png', '.jpg', '.jpeg', '.svg')):
|
733 |
+
chart_path = pandas_response.value
|
734 |
+
response_text = "Analysis completed with visualization"
|
735 |
+
elif isinstance(pandas_response.value, dict):
|
736 |
+
# Handle dict response from Chart object
|
737 |
+
if 'path' in pandas_response.value:
|
738 |
+
chart_path = pandas_response.value['path']
|
739 |
+
response_text = "Analysis completed with visualization"
|
740 |
+
else:
|
741 |
+
response_text = "Chart generated but path not accessible"
|
742 |
+
except Exception as chart_error:
|
743 |
+
logging.warning(f"Error handling chart response: {chart_error}")
|
744 |
+
response_text = "Chart generated but encountered display issue"
|
745 |
+
|
746 |
+
# Case 3: Response with plot_path attribute
|
747 |
+
elif hasattr(pandas_response, 'plot_path') and pandas_response.plot_path:
|
748 |
+
chart_path = pandas_response.plot_path
|
749 |
+
response_text = getattr(pandas_response, 'text', "Analysis completed with visualization")
|
750 |
|
751 |
+
# Case 4: Other response types
|
|
|
|
|
752 |
else:
|
753 |
+
if pandas_response is not None:
|
754 |
+
response_text = str(pandas_response).strip()
|
755 |
+
|
756 |
+
except Exception as response_error:
|
757 |
+
logging.warning(f"Error processing PandasAI response: {response_error}")
|
758 |
+
response_text = "Analysis completed but encountered response processing issue"
|
759 |
+
|
760 |
+
# Fallback: Check charts directory for new files if no chart path found
|
761 |
+
if not chart_path and os.path.exists(self.charts_dir):
|
762 |
+
chart_files = []
|
763 |
+
for f in os.listdir(self.charts_dir):
|
764 |
+
if f.endswith(('.png', '.jpg', '.jpeg', '.svg')):
|
765 |
+
full_path = os.path.join(self.charts_dir, f)
|
766 |
+
chart_files.append((full_path, os.path.getmtime(full_path)))
|
767 |
+
|
768 |
+
if chart_files:
|
769 |
+
# Sort by modification time (newest first)
|
770 |
+
chart_files.sort(key=lambda x: x[1], reverse=True)
|
771 |
+
latest_chart_path, latest_time = chart_files[0]
|
772 |
+
|
773 |
+
# Check if created in last 60 seconds
|
774 |
+
import time
|
775 |
+
if time.time() - latest_time < 60:
|
776 |
+
chart_path = latest_chart_path
|
777 |
+
logging.info(f"Found recent chart: {chart_path}")
|
778 |
|
779 |
# Format final response
|
780 |
+
if not response_text:
|
781 |
+
response_text = "Analysis completed"
|
782 |
+
|
783 |
+
chart_info = ""
|
784 |
if chart_path and os.path.exists(chart_path):
|
785 |
chart_info = f"\n\n📊 **Chart Generated**: {os.path.basename(chart_path)}\nChart saved at: {chart_path}"
|
786 |
logging.info(f"Chart successfully generated: {chart_path}")
|
787 |
|
788 |
final_response = response_text + chart_info
|
789 |
+
return final_response, True
|
|
|
|
|
790 |
|
791 |
except Exception as e:
|
792 |
logging.error(f"Error in PandasAI processing: {e}", exc_info=True)
|
|
|
799 |
return "I encountered a date formatting issue. Please try asking for the data without specific date formatting, or ask me to show the raw data structure first.", False
|
800 |
elif "ambiguous" in error_str:
|
801 |
return "I encountered an ambiguous data type issue. Please try being more specific about which data you'd like to analyze (e.g., 'show monthly follower gains' vs 'show cumulative followers').", False
|
802 |
+
elif "startswith" in error_str or "dict" in error_str:
|
803 |
+
return "I encountered a response formatting issue. The analysis may have completed but I couldn't process the result properly. Please try rephrasing your query.", False
|
804 |
else:
|
805 |
return f"Error processing data query: {str(e)}", False
|
806 |
|
|
|
807 |
async def _generate_enhanced_response(self, query: str, pandas_result: str = "", query_type: str = "general") -> str:
|
808 |
"""Generate enhanced response combining PandasAI results with RAG context"""
|
809 |
if not self.is_ready:
|