import pandas as pd
import json
import os
import asyncio
import logging
import numpy as np
import textwrap
try:
    from google import genai
    from google.genai import types as genai_types  # For GenerateContentConfig, SafetySetting, etc.
    from google.genai.types import HarmCategory, HarmBlockThreshold  # Specific enums
except ImportError:
    logging.error("Google GenAI library not found. Please install it: pip install google-genai", exc_info=True)
    # Define dummy classes/variables if the import fails, so app.py can still try to run
    # (app.py also has its own EB_AGENT_AVAILABLE check).
    class genai:  # type: ignore
        Client = None
    class genai_types:  # type: ignore
        EmbedContentConfig = None
        GenerateContentConfig = None
        SafetySetting = None
    class HarmCategory:  # type: ignore
        HARM_CATEGORY_HATE_SPEECH = "HARM_CATEGORY_HATE_SPEECH"
        HARM_CATEGORY_HARASSMENT = "HARM_CATEGORY_HARASSMENT"
        HARM_CATEGORY_SEXUALLY_EXPLICIT = "HARM_CATEGORY_SEXUALLY_EXPLICIT"
        HARM_CATEGORY_DANGEROUS_CONTENT = "HARM_CATEGORY_DANGEROUS_CONTENT"
    class HarmBlockThreshold:  # type: ignore
        BLOCK_LOW_AND_ABOVE = "BLOCK_LOW_AND_ABOVE"
        BLOCK_MEDIUM_AND_ABOVE = "BLOCK_MEDIUM_AND_ABOVE"
        BLOCK_NONE = "BLOCK_NONE"

# --- Configuration Constants ---
# These are defined here because app.py imports them.
# Make sure these values are appropriate for your needs.
GEMINI_API_KEY = os.getenv('GEMINI_API_KEY', "")
if not GEMINI_API_KEY:
    logging.warning("GEMINI_API_KEY environment variable not set. EB Agent will not function.")

# Model names (as imported by app.py from this module)
LLM_MODEL_NAME = "gemini-1.5-flash-latest"  # Changed from the original 2.0-flash, as 1.5-flash is generally preferred. Adjust if needed.
GEMINI_EMBEDDING_MODEL_NAME = "text-embedding-004"  # Common embedding model; the original used gemini-embedding-exp-03-07. Adjust if needed.

# Default generation config (app.py imports this as EB_AGENT_GEN_CONFIG)
GENERATION_CONFIG_PARAMS = {
    "temperature": 0.7,
    "top_p": 0.95,
    "top_k": 40,
    "max_output_tokens": 8192,
    "candidate_count": 1,  # Important for non-streaming responses
    # "stop_sequences": [...]  # Optional
}
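# These keys map onto fields of genai_types.GenerateContentConfig; process_query
# below builds its config roughly as:
#   genai_types.GenerateContentConfig(**GENERATION_CONFIG_PARAMS, safety_settings=...)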

# Default safety settings (app.py imports this as EB_AGENT_SAFETY_SETTINGS)
DEFAULT_SAFETY_SETTINGS = [
    {"category": HarmCategory.HARM_CATEGORY_HATE_SPEECH, "threshold": HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE},
    {"category": HarmCategory.HARM_CATEGORY_HARASSMENT, "threshold": HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE},
    {"category": HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, "threshold": HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE},
    {"category": HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, "threshold": HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE},
]

# Placeholder RAG documents DataFrame (app.py imports this as eb_agent_default_rag_docs).
# In a real application, this would be loaded from a file or database.
df_rag_documents = pd.DataFrame({
    'text': [
        "Employer branding focuses on how an organization is perceived as an employer by potential and current employees.",
        "Key metrics for employer branding include employee engagement, candidate quality, and retention rates.",
        "LinkedIn is a crucial platform for showcasing company culture and attracting talent.",
        "Analyzing follower demographics and post engagement helps refine employer branding strategies."
    ]
})
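# A minimal sketch of loading the corpus from disk instead (the file name is
# illustrative; the RAG system only requires a 'text' column):
#   df_rag_documents = pd.read_csv("rag_documents.csv")[["text"]].dropna()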

# --- Client Initialization ---
# This client is shared by all agent instances and is created once at module load.
client = None
if GEMINI_API_KEY and genai.Client:  # genai.Client is None if the import above failed
    try:
        client = genai.Client(api_key=GEMINI_API_KEY)
        logging.info("Google GenAI client initialized successfully.")
    except Exception as e:
        logging.error(f"Failed to initialize Google GenAI client: {e}", exc_info=True)
else:
    logging.warning("Google GenAI client could not be initialized (GEMINI_API_KEY missing or library import failed).")

class AdvancedRAGSystem:
    """
    Handles Retrieval Augmented Generation by embedding documents and finding relevant context for queries.
    """
    def __init__(self, documents_df: pd.DataFrame, embedding_model_name: str):
        self.documents_df = documents_df.copy()  # Work on a copy
        self.embedding_model_name = embedding_model_name
        self.embeddings: np.ndarray | None = None  # Populated by the async initialize_embeddings
        logging.info(f"AdvancedRAGSystem initialized with embedding model: {self.embedding_model_name}")

    def _embed_single_document_sync(self, text: str) -> np.ndarray:
        """Synchronous helper to embed a single piece of text."""
        if not client:
            raise ConnectionError("GenAI client not initialized for RAG embedding.")
        if not text or not isinstance(text, str):  # Basic validation
            # Returning a zero vector would require knowing the model's embedding
            # dimension without a live call, so invalid input is rejected instead;
            # the caller logs the error and skips the document.
            logging.warning("Attempted to embed empty or non-string text.")
            raise ValueError("Cannot embed empty or non-string text.")
        response = client.models.embed_content(
            model=self.embedding_model_name,  # e.g., "text-embedding-004"
            contents=text,  # The API takes 'contents' (plural) but accepts a single string
            config=genai_types.EmbedContentConfig(task_type="SEMANTIC_SIMILARITY") if genai_types.EmbedContentConfig else None
        )
        # The SDK returns a list of ContentEmbedding objects; for a single input
        # string, the vector of floats is at embeddings[0].values.
        return np.array(response.embeddings[0].values)

    async def initialize_embeddings(self):
        """Asynchronously embeds all documents in documents_df. Should be called once."""
        if self.documents_df.empty:
            logging.info("RAG documents DataFrame is empty. No embeddings to initialize.")
            self.embeddings = np.array([])
            return
        if not client:
            logging.error("GenAI client not available for RAG embedding initialization.")
            self.embeddings = np.array([])
            return
        logging.info(f"Starting RAG document embedding for {len(self.documents_df)} documents...")
        embedded_docs_list = []
        for index, row in self.documents_df.iterrows():
            text_to_embed = row.get('text')
            if not text_to_embed or not isinstance(text_to_embed, str):
                logging.warning(f"Skipping document at index {index} due to invalid text: {text_to_embed}")
                continue
            try:
                # Wrap the synchronous SDK call in asyncio.to_thread
                embedding_array = await asyncio.to_thread(self._embed_single_document_sync, text_to_embed)
                embedded_docs_list.append(embedding_array)
            except Exception as e:
                logging.error(f"Error embedding document text (index {index}) '{str(text_to_embed)[:50]}...': {e}", exc_info=False)  # exc_info=False for brevity in the loop
        if not embedded_docs_list:
            self.embeddings = np.array([])
            logging.warning("No documents were successfully embedded for RAG.")
        else:
            try:
                self.embeddings = np.vstack(embedded_docs_list)
                logging.info(f"Successfully embedded {len(embedded_docs_list)} documents for RAG. Embedding matrix shape: {self.embeddings.shape}")
            except ValueError as ve:  # e.g., inconsistent embedding shapes that slipped past the per-document error handling
                logging.error(f"Error stacking embeddings: {ve}. Check individual embedding errors.", exc_info=True)
                self.embeddings = np.array([])

    async def retrieve_relevant_info(self, query: str, top_k: int = 3) -> str:
        """Retrieves relevant document snippets for a given query using vector similarity."""
        if self.embeddings is None or self.embeddings.size == 0 or self.documents_df.empty:
            logging.debug("RAG system not initialized or no documents/embeddings available for retrieval.")
            return ""
        if not query or not isinstance(query, str):
            logging.debug("Empty or invalid query for RAG retrieval.")
            return ""
        if not client:
            logging.error("GenAI client not available for RAG query embedding.")
            return ""
        try:
            query_vector = await asyncio.to_thread(self._embed_single_document_sync, query)
        except Exception as e:
            logging.error(f"Error embedding query '{str(query)[:50]}...': {e}", exc_info=False)
            return ""
        if query_vector.ndim == 0 or query_vector.size == 0:
            logging.warning(f"Query vector embedding failed or is empty for query: {str(query)[:50]}")
            return ""
        if query_vector.ndim > 1:  # Should be 1-D
            query_vector = query_vector.flatten()
        try:
            # Cosine similarity is the dot product of normalized vectors; for
            # simplicity this uses the raw dot product. Normalize both sides if
            # true cosine similarity is needed (see the sketch below).
            scores = np.dot(self.embeddings, query_vector)  # self.embeddings (N, D), query_vector (D,) -> scores (N,)
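            # A minimal sketch of the normalized (true cosine) variant, assuming
            # non-degenerate vectors; the 1e-12 floors guard against division by zero:
            #   doc_norms = np.linalg.norm(self.embeddings, axis=1, keepdims=True)
            #   query_norm = max(np.linalg.norm(query_vector), 1e-12)
            #   scores = (self.embeddings / np.clip(doc_norms, 1e-12, None)) @ (query_vector / query_norm)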
            if scores.size == 0:
                return ""
            actual_top_k = min(top_k, len(self.documents_df), len(scores))
            if actual_top_k <= 0: return ""  # Ensure top_k is positive
            # Get the indices of the top_k scores in descending order
            top_indices = np.argsort(scores)[-actual_top_k:][::-1]
            valid_top_indices = [idx for idx in top_indices if 0 <= idx < len(self.documents_df)]
            if not valid_top_indices: return ""
            # Retrieve the 'text' field from the original DataFrame
            context_parts = [self.documents_df.iloc[i]['text'] for i in valid_top_indices if 'text' in self.documents_df.columns]
            context = "\n\n---\n\n".join(context_parts)
            logging.debug(f"Retrieved RAG context for query '{str(query)[:50]}...':\n{context[:200]}...")
            return context
        except Exception as e:
            logging.error(f"Error during RAG retrieval (dot product/sorting): {e}", exc_info=True)
            return ""

class EmployerBrandingAgent:
    """
    An agent that uses generative AI to provide insights on employer branding,
    based on the provided DataFrames and RAG context.
    """
    def __init__(self,
                 all_dataframes: dict,
                 rag_documents_df: pd.DataFrame,  # For the RAG system
                 llm_model_name: str,
                 embedding_model_name: str,  # For the RAG system
                 generation_config_dict: dict,
                 safety_settings_list_of_dicts: list,
                 # client_instance,  # The module-level client is used for simplicity
                 force_sandbox: bool = False  # Parameter from app.py, currently unused here
                 ):
        # self.client = client_instance  # If a client were passed in
        self.all_dataframes = {k: df.copy() for k, df in all_dataframes.items()}  # Work with copies
        self.schemas_representation = self._get_all_schemas_representation()  # Sync method
        # Chat history in API format: [{"role": "user"|"model", "parts": [{"text": "..."}]}].
        # app.py sets this before calling process_query.
        self.chat_history = []
        self.llm_model_name = llm_model_name
        self.generation_config_dict = generation_config_dict
        self.safety_settings_list_of_dicts = safety_settings_list_of_dicts
        self.embedding_model_name = embedding_model_name
        self.rag_system = AdvancedRAGSystem(rag_documents_df, self.embedding_model_name)
        # Note: self.rag_system.initialize_embeddings() must be called externally (e.g., in app.py)
        self.force_sandbox = force_sandbox  # Stored in case it is needed for tool use later
        logging.info(f"EmployerBrandingAgent initialized. LLM: {self.llm_model_name}, Embedding: {self.embedding_model_name}. RAG system created.")

    def _get_all_schemas_representation(self) -> str:
        """Generates a string representation of the schemas of all DataFrames."""
        schema_descriptions = ["DataFrames available for analysis:"]
        for key, df in self.all_dataframes.items():
            df_name = f"df_{key}"  # Consistent naming for the agent to refer to
            columns = ", ".join(df.columns)
            shape = df.shape
            if df.empty:
                schema = f"\n--- DataFrame: {df_name} ---\nStatus: Empty\nShape: {shape}\nColumns: {columns}"
            else:
                # Basic stats for numeric columns; first few unique values for object columns
                sample_info_parts = []
                for col in df.columns:
                    if pd.api.types.is_numeric_dtype(df[col]) and not df[col].empty:
                        sample_info_parts.append(f"{col} (numeric, e.g., mean: {df[col].mean():.2f})")
                    elif pd.api.types.is_datetime64_any_dtype(df[col]) and not df[col].empty:
                        sample_info_parts.append(f"{col} (datetime, e.g., min: {df[col].min()}, max: {df[col].max()})")
                    elif not df[col].empty:
                        unique_vals = df[col].unique()
                        display_unique = ', '.join(map(str, unique_vals[:3]))
                        if len(unique_vals) > 3: display_unique += ", ..."
                        sample_info_parts.append(f"{col} (object, e.g., {display_unique})")
                    else:
                        sample_info_parts.append(f"{col} (empty)")
                schema = (f"\n--- DataFrame: {df_name} ---\nShape: {shape}\nColumns & Sample Info:\n " + "\n ".join(sample_info_parts))
            schema_descriptions.append(schema)
        return "\n".join(schema_descriptions)

    async def _build_prompt_for_current_turn(self, raw_user_query: str) -> str:
        """
        Constructs the full prompt for the current turn, including system instructions,
        DataFrame schemas, RAG context, and the user's query.
        """
        # System instruction part
        prompt_parts = [
            "You are an expert Employer Branding Analyst and a helpful AI assistant. "
            "Your goal is to provide insightful analysis based on the provided LinkedIn data. "
            "When asked to generate Pandas code, ensure it is correct, runnable, and clearly explained. "
            "When providing insights, be specific and refer to the data where possible."
        ]
        # Schema information
        prompt_parts.append("\n\n--- AVAILABLE DATA ---")
        prompt_parts.append(self.schemas_representation)
        # RAG context
        if self.rag_system.embeddings is not None and self.rag_system.embeddings.size > 0:  # Check whether RAG is initialized
            logging.debug(f"Retrieving RAG context for query: {raw_user_query[:50]}...")
            rag_context = await self.rag_system.retrieve_relevant_info(raw_user_query)
            if rag_context:
                prompt_parts.append("\n\n--- RELEVANT CONTEXTUAL INFORMATION (from documents) ---")
                prompt_parts.append(rag_context)
            else:
                logging.debug("No relevant RAG context found.")
        else:
            logging.debug("RAG system not initialized or embeddings not available; skipping RAG context retrieval.")
        # The user's current query
        prompt_parts.append("\n\n--- USER REQUEST ---")
        prompt_parts.append(f"Based on all the information above, please respond to the following user query:\n{raw_user_query}")
        final_prompt = "\n".join(prompt_parts)
        logging.debug(f"Built prompt for current turn (first 300 chars): {final_prompt[:300]}")
        return final_prompt

    async def process_query(self, raw_user_query_this_turn: str) -> str:
        """
        Processes the user's query, incorporating chat history, DataFrame schemas, and RAG.
        The agent's self.chat_history is expected to be set by the calling application (app.py)
        and should contain the history *before* the current raw_user_query_this_turn.
        This method returns the AI's response string; app.py then updates the agent's
        chat history with raw_user_query_this_turn and this response.
        """
        if not client:
            logging.error("GenAI client not initialized. Cannot process query.")
            return "Error: The AI Agent is not available due to a configuration issue with the AI service."
        if not raw_user_query_this_turn.strip():
            return "Please provide a query."
        # 1. Prepare the augmented prompt for the *current* user query.
        #    This prompt includes system instructions, schemas, RAG context, and the raw query.
        augmented_current_user_prompt_text = await self._build_prompt_for_current_turn(raw_user_query_this_turn)
        # 2. Construct the full list of contents for the API call.
        #    self.chat_history is in API format ({"role": ..., "parts": [...]}) and
        #    contains the history *before* the current raw_user_query_this_turn.
        api_call_contents = []
        if self.chat_history:  # Add previous turns, if any
            api_call_contents.extend(self.chat_history)
        # Add the current user turn, using the fully augmented prompt as its content
        api_call_contents.append({"role": "user", "parts": [{"text": augmented_current_user_prompt_text}]})
        logging.debug(f"Sending to GenAI. Total turns in content: {len(api_call_contents)}")
        if api_call_contents:
            logging.debug(f"Last turn role: {api_call_contents[-1]['role']}, text start: {api_call_contents[-1]['parts'][0]['text'][:100]}")
        # 3. Prepare the API configuration.
        #    Convert safety settings from dicts to SafetySetting objects when genai_types is available.
        api_safety_settings = []
        if genai_types.SafetySetting:
            for ss_dict in self.safety_settings_list_of_dicts:
                try:
                    api_safety_settings.append(genai_types.SafetySetting(**ss_dict))
                except TypeError:  # Raised if HarmCategory/HarmBlockThreshold are plain strings due to an import error
                    logging.warning(f"Could not create SafetySetting object from dict: {ss_dict}. Using the dict directly.")
                    api_safety_settings.append(ss_dict)  # Fall back to the dict
        else:  # genai_types not available
            api_safety_settings = self.safety_settings_list_of_dicts
        api_generation_config = None
        if genai_types.GenerateContentConfig:
            try:
                api_generation_config = genai_types.GenerateContentConfig(
                    **self.generation_config_dict,
                    safety_settings=api_safety_settings  # A list of SafetySetting objects (or dicts as a fallback)
                )
            except TypeError:
                logging.warning("Could not create GenerateContentConfig object. Using the raw dict for config.")
                # Fallback: pass the raw dict. The 'config' parameter of
                # client.models.generate_content expects a GenerateContentConfig
                # object, so the call below may still fail; a more robust fallback
                # would construct the config fields individually.
                api_generation_config = self.generation_config_dict
        else:  # genai_types.GenerateContentConfig is None (likely an import error)
            logging.error("genai_types.GenerateContentConfig not available. Cannot form API config.")
            return "Error: AI Agent configuration problem (GenerateContentConfig type missing)."
        # 4. Make the API call (a synchronous SDK call wrapped in asyncio.to_thread).
        try:
            response = await asyncio.to_thread(
                client.models.generate_content,
                model=self.llm_model_name,
                contents=api_call_contents,
                config=api_generation_config  # Pass the GenerateContentConfig object
            )
            # Check for blocked content or other issues before reading response.text
            if not response.candidates:
                block_reason = response.prompt_feedback.block_reason if response.prompt_feedback else "Unknown"
                logging.warning(f"AI response blocked or empty. Reason: {block_reason}")
                error_message = f"The AI's response was blocked. Reason: {block_reason}."
                if response.prompt_feedback and response.prompt_feedback.block_reason_message:
                    error_message += f" Details: {response.prompt_feedback.block_reason_message}"
                return error_message
            answer = response.text.strip()
            logging.info(f"Successfully received AI response (first 100 chars): {answer[:100]}")
        except Exception as e:
            logging.error(f"Error during GenAI call: {e}", exc_info=True)
            # For more detail, Google-specific API errors could be handled separately:
            # from google.api_core import exceptions as google_exceptions
            # if isinstance(e, google_exceptions.GoogleAPIError):
            #     answer = f"API Error: {e.message}"
            answer = f"# Error during AI processing:\n{type(e).__name__}: {str(e)}"
        return answer

    def clear_chat_history(self):  # Called by app.py
        """Clears the agent's internal chat history."""
        self.chat_history = []
        logging.info("EmployerBrandingAgent chat history cleared by request.")

# --- Module-level function for schema display in app.py ---
def get_all_schemas_representation(all_dataframes: dict) -> str:
    """
    Generates a string representation of the schemas of all DataFrames,
    intended for display in the Gradio UI.
    This is a standalone function as it's imported directly by app.py.
    """
    if not all_dataframes:
        return "No DataFrames are currently loaded."
    schema_descriptions = ["DataFrames currently available in the application state:"]
    for key, df in all_dataframes.items():
        df_name = f"df_{key}"
        columns = ", ".join(df.columns)
        shape = df.shape
        if df.empty:
            schema = f"\n--- DataFrame: {df_name} ---\nStatus: Empty\nShape: {shape}\nColumns: {columns}"
        else:
            # Provide a bit more detail for UI display
            sample_data_str = df.head(2).to_markdown(index=False)  # Markdown renders better in the UI
            schema = (f"\n--- DataFrame: {df_name} ---\nShape: {shape}\nColumns: {columns}\n\n<details><summary>Sample Data (first 2 rows of {df_name}):</summary>\n\n{sample_data_str}\n\n</details>")
        schema_descriptions.append(schema)
    return "\n".join(schema_descriptions)