import pandas as pd
import json
import os
import asyncio
import logging
import numpy as np
import textwrap  # Used by _build_system_prompt
from datetime import datetime  # Not used, but kept from original
from typing import Dict, List, Optional, Union, Any
import traceback

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(module)s - %(message)s')
try:
    from google import genai
    from google.genai import types  # Provides SafetySetting, HarmCategory, GenerateContentConfig, etc.
    from google.genai import errors
    # If GenerationConfig or EmbedContentConfig are from a different submodule, adjust imports.
    # For the legacy google-generativeai package, GenerationConfig is often passed as a dict or
    # genai.types.GenerationConfig, and EmbedContentConfig may be implicit or part of task_type.
    GENAI_AVAILABLE = True
    logging.info("Google GenAI library imported successfully.")
except ImportError:
    logging.warning("Google GenAI library not found. Please install it: pip install google-genai")
    GENAI_AVAILABLE = False
    # Dummy classes for graceful degradation (simplified)
    class genai:
        Client = None
        # If using the legacy google-generativeai package, these would be different:
        # GenerativeModel = None
        # def configure(*args, **kwargs): pass
        # def embed_content(*args, **kwargs): return {}

    class types:  # Placeholder for types used in the original code
        EmbedContentConfig = None  # Placeholder
        GenerationConfig = None  # Placeholder
        SafetySetting = None
        Candidate = type('Candidate', (), {'FinishReason': type('FinishReason', (), {'STOP': 'STOP'})})  # Dummy for FinishReason

        class HarmCategory:
            HARM_CATEGORY_UNSPECIFIED = "HARM_CATEGORY_UNSPECIFIED"
            HARM_CATEGORY_HATE_SPEECH = "HARM_CATEGORY_HATE_SPEECH"
            HARM_CATEGORY_HARASSMENT = "HARM_CATEGORY_HARASSMENT"
            HARM_CATEGORY_SEXUALLY_EXPLICIT = "HARM_CATEGORY_SEXUALLY_EXPLICIT"
            HARM_CATEGORY_DANGEROUS_CONTENT = "HARM_CATEGORY_DANGEROUS_CONTENT"

        class HarmBlockThreshold:
            BLOCK_NONE = "BLOCK_NONE"
            BLOCK_LOW_AND_ABOVE = "BLOCK_LOW_AND_ABOVE"
            BLOCK_MEDIUM_AND_ABOVE = "BLOCK_MEDIUM_AND_ABOVE"
            BLOCK_ONLY_HIGH = "BLOCK_ONLY_HIGH"

    class generation_types:  # Dummy for BlockedPromptException
        BlockedPromptException = type('BlockedPromptException', (Exception,), {})

# --- Custom Exceptions ---
class ValidationError(Exception):
    """Custom validation error for agent inputs"""
    pass

class RateLimitError(Exception):  # Not used, but kept
    """Placeholder for rate limit errors."""
    pass

class AgentNotReadyError(Exception):
    """Agent is not properly initialized"""
    pass

# --- Configuration Constants ---
GEMINI_API_KEY = os.getenv('GEMINI_API_KEY', "")
LLM_MODEL_NAME = "gemini-1.5-flash-latest"  # For google-generativeai, the model name is used directly.
# For client.models.generate_content, it may need the "models/gemini-1.5-flash-latest" prefix.
GEMINI_EMBEDDING_MODEL_NAME = "text-embedding-004"  # Similarly, may need "models/text-embedding-004".

GENERATION_CONFIG_PARAMS = {
    "temperature": 0.7,
    "top_p": 0.95,
    "top_k": 40,
    "max_output_tokens": 8192,  # Ensure this is supported by the chosen model
    "candidate_count": 1,
}

DEFAULT_SAFETY_SETTINGS = []  # Populate with {'category': HarmCategory.HARM_CATEGORY_X, 'threshold': HarmBlockThreshold.BLOCK_Y}
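# Illustrative only -- a populated safety-settings list might look like this
# (category/threshold names assume google.genai types; adjust to your SDK version):
# DEFAULT_SAFETY_SETTINGS = [
#     {'category': types.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
#      'threshold': types.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE},
#     {'category': types.HarmCategory.HARM_CATEGORY_HARASSMENT,
#      'threshold': types.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE},
# ]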

# Default RAG documents
DEFAULT_RAG_DOCUMENTS = pd.DataFrame({
    'text': [
        "Employer branding focuses on how an organization is perceived as an employer by potential and current employees.",
        "Key metrics for employer branding include employee engagement, candidate quality, and retention rates.",
        "LinkedIn is a crucial platform for showcasing company culture and attracting talent.",
        "Analyzing follower demographics and post engagement helps refine employer branding strategies.",
        "Content strategy should align with company values to attract the right talent.",
        "Employee advocacy programs can significantly boost employer brand reach and authenticity."
    ]
})

# --- Client Initialization ---
client = None
if GEMINI_API_KEY and GENAI_AVAILABLE:
    try:
        # This uses the google-genai client. With the legacy google-generativeai package,
        # this would instead be genai.configure(api_key=...).
        client = genai.Client(api_key=GEMINI_API_KEY)
        logging.info("Google GenAI client initialized successfully (using genai.Client).")
    except Exception as e:
        logging.error(f"Failed to initialize Google GenAI client (using genai.Client): {e}")
        client = None
else:
    if not GEMINI_API_KEY:
        logging.warning("GEMINI_API_KEY environment variable not set.")
    if not GENAI_AVAILABLE:
        logging.warning("Google GenAI library not available.")

# --- Utility function to get DataFrame schema representation ---
def get_df_schema_representation(df: pd.DataFrame, df_name: str) -> str:
    """Generates a string representation of a DataFrame's schema and a small sample."""
    if not isinstance(df, pd.DataFrame):
        return f"Item '{df_name}' is not a DataFrame.\n"
    if df.empty:
        return f"DataFrame '{df_name}': Empty\n"

    schema_parts = [f"DataFrame '{df_name}':"]
    schema_parts.append(f"  Shape: {df.shape}")
    schema_parts.append("  Columns:")
    for col in df.columns:
        col_type = str(df[col].dtype)
        null_count = df[col].isnull().sum()
        unique_count = df[col].nunique()
        schema_parts.append(f"    - {col} (Type: {col_type}, Nulls: {null_count}/{len(df)}, Uniques: {unique_count})")

    if not df.empty:
        schema_parts.append("  Sample Data (first 2 rows):")
        try:
            sample_df_str = df.head(2).to_string(index=True, max_colwidth=50)  # Show index for context
            indented_sample_df = "\n".join(["    " + line for line in sample_df_str.split('\n')])
            schema_parts.append(indented_sample_df)
        except Exception as e:
            schema_parts.append(f"    Could not generate sample data: {e}")

    return "\n".join(schema_parts) + "\n"

def get_all_schemas_representation(dataframes: Dict[str, pd.DataFrame]) -> str:
    """Generates a string representation of all DataFrame schemas."""
    if not dataframes:
        return "No DataFrames available to the agent."
    full_representation = ["=== Available DataFrame Schemas for Analysis ==="]
    for name, df_instance in dataframes.items():
        full_representation.append(get_df_schema_representation(df_instance, name))
    return "\n".join(full_representation)
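
# Illustrative usage sketch (not executed at import time); DataFrame names are hypothetical:
# sample = {"df_follower_stats": pd.DataFrame({"date": ["2023-01-01"], "count": [100]})}
# print(get_all_schemas_representation(sample))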

class AdvancedRAGSystem:
    def __init__(self, documents_df: pd.DataFrame, embedding_model_name: str):
        self.documents_df = documents_df.copy() if not documents_df.empty else DEFAULT_RAG_DOCUMENTS.copy()
        # Ensure 'text' column exists
        if 'text' not in self.documents_df.columns and not self.documents_df.empty:
            logging.warning("'text' column not found in RAG documents. RAG might not work.")
            # Create an empty text column if the DataFrame is not empty but lacks it, to prevent errors later
            self.documents_df['text'] = ""
        self.embedding_model_name = embedding_model_name  # e.g., "models/text-embedding-004" or just "text-embedding-004"
        self.embeddings: Optional[np.ndarray] = None
        self.is_initialized = False
        logging.info(f"AdvancedRAGSystem initialized with {len(self.documents_df)} documents. Model: {self.embedding_model_name}")

    def _embed_single_document_sync(self, text: str) -> Optional[np.ndarray]:
        if not client:
            raise ConnectionError("GenAI client not initialized for RAG embedding.")
        if not text or not isinstance(text, str):
            logging.warning("Cannot embed empty or non-string text for RAG.")
            return None
        try:
            # Equivalent legacy google-generativeai call:
            # embedding_response = genai.embed_content(
            #     model=self.embedding_model_name,  # e.g., "models/text-embedding-004"
            #     content=text,
            #     task_type="RETRIEVAL_DOCUMENT"  # or "SEMANTIC_SIMILARITY"
            # )
            # return np.array(embedding_response['embedding'])

            # Using the client.models.embed_content structure; this may require specific types for config.
            embed_config_payload = None
            if GENAI_AVAILABLE and hasattr(types, 'EmbedContentConfig'):
                # The task_type may differ, e.g., "SEMANTIC_SIMILARITY" or "RETRIEVAL_DOCUMENT"
                embed_config_payload = types.EmbedContentConfig(task_type="RETRIEVAL_DOCUMENT")

            response = client.models.embed_content(
                model=f"models/{self.embedding_model_name}" if not self.embedding_model_name.startswith("models/") else self.embedding_model_name,
                contents=text,  # The client API uses 'contents'; legacy genai.embed_content uses 'content'
                config=embed_config_payload
            )
            # Adapt response parsing to the actual embed_content response shape.
            if hasattr(response, 'embeddings') and isinstance(response.embeddings, list) and len(response.embeddings) > 0:
                first_embedding = response.embeddings[0]
                # google-genai returns ContentEmbedding objects exposing .values; fall back to the raw item otherwise.
                return np.array(getattr(first_embedding, 'values', first_embedding))
            elif hasattr(response, 'embedding'):  # Common for legacy genai.embed_content
                return np.array(response.embedding)
            else:
                logging.error(f"Unexpected embedding response format: {response}")
                return None
        except Exception as e:
            logging.error(f"Error in _embed_single_document_sync for text '{text[:50]}...': {e}", exc_info=True)
            raise

    async def initialize_embeddings(self):
        if self.documents_df.empty or 'text' not in self.documents_df.columns:
            logging.warning("RAG documents DataFrame is empty or lacks 'text' column. Skipping embedding.")
            self.embeddings = np.array([])
            self.is_initialized = True  # Initialized, but with no embeddings
            return

        if not client and not (GENAI_AVAILABLE and os.getenv('GEMINI_API_KEY')):  # Check if standard genai can be used
            logging.error("GenAI client not available for RAG embedding initialization.")
            self.embeddings = np.array([])
            return

        logging.info(f"Starting RAG document embedding for {len(self.documents_df)} documents...")
        embedded_docs_list = []
        for index, row in self.documents_df.iterrows():
            text_to_embed = row.get('text', '')
            if not text_to_embed or not isinstance(text_to_embed, str):
                logging.warning(f"Skipping RAG document at index {index} due to invalid/empty text.")
                continue
            try:
                # Use asyncio.to_thread for the synchronous embedding call
                embedding_array = await asyncio.to_thread(self._embed_single_document_sync, text_to_embed)
                if embedding_array is not None and embedding_array.size > 0:
                    embedded_docs_list.append(embedding_array)
                else:
                    logging.warning(f"Empty or failed embedding for RAG document at index {index}.")
            except Exception as e:
                logging.error(f"Error embedding RAG document at index {index}: {e}")
                continue  # Continue with other documents

        if not embedded_docs_list:
            self.embeddings = np.array([])
            logging.warning("No RAG documents were successfully embedded.")
        else:
            try:
                # Ensure all embeddings have the same shape before vstack
                first_shape = embedded_docs_list[0].shape
                if not all(emb.shape == first_shape for emb in embedded_docs_list):
                    logging.error("Inconsistent embedding shapes found. Cannot stack for RAG.")
                    # For now, fail stacking if shapes are inconsistent.
                    self.embeddings = np.array([])
                    return  # Exit if shapes are inconsistent
                self.embeddings = np.vstack(embedded_docs_list)
                logging.info(f"Successfully embedded {len(embedded_docs_list)} RAG documents. Embeddings shape: {self.embeddings.shape}")
            except ValueError as ve:
                logging.error(f"Error stacking embeddings (likely due to inconsistent shapes): {ve}")
                self.embeddings = np.array([])

        self.is_initialized = True

    def _calculate_cosine_similarity(self, embeddings_matrix: np.ndarray, query_vector: np.ndarray) -> np.ndarray:
        if embeddings_matrix.ndim == 1:  # Handle case of a single document embedding
            embeddings_matrix = embeddings_matrix.reshape(1, -1)
        if query_vector.ndim == 1:
            query_vector = query_vector.reshape(1, -1)
        if embeddings_matrix.size == 0 or query_vector.size == 0:
            return np.array([])

        # Normalize embeddings_matrix rows; a small epsilon avoids division by zero for zero vectors.
        norm_matrix = np.linalg.norm(embeddings_matrix, axis=1, keepdims=True)
        normalized_embeddings_matrix = embeddings_matrix / (norm_matrix + 1e-8)

        # Normalize query_vector
        norm_query = np.linalg.norm(query_vector, axis=1, keepdims=True)
        normalized_query_vector = query_vector / (norm_query + 1e-8)

        # Cosine similarity is the dot product of the normalized vectors
        return np.dot(normalized_embeddings_matrix, normalized_query_vector.T).flatten()
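
    # Illustrative check (hypothetical values): for document vectors [[1, 0], [0, 1]] and query [1, 0],
    # this returns approximately [1.0, 0.0] -- identical direction scores ~1, orthogonal scores ~0.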

    async def retrieve_relevant_info(self, query: str, top_k: int = 3, min_similarity: float = 0.3) -> str:
        if not self.is_initialized:
            logging.debug("RAG system not initialized. Cannot retrieve info.")
            return ""
        if self.embeddings is None or self.embeddings.size == 0:
            logging.debug("RAG embeddings not available. Cannot retrieve info.")
            return ""
        if not query or not isinstance(query, str):
            logging.debug("Empty or invalid query for RAG retrieval.")
            return ""
        if not client and not (GENAI_AVAILABLE and os.getenv('GEMINI_API_KEY')):
            logging.error("GenAI client not available for RAG query embedding.")
            return ""

        try:
            query_vector = await asyncio.to_thread(self._embed_single_document_sync, query)  # Embed query
            if query_vector is None or query_vector.size == 0:
                logging.warning("Query vector embedding failed or is empty for RAG.")
                return ""

            similarity_scores = self._calculate_cosine_similarity(self.embeddings, query_vector)
            if similarity_scores.size == 0:
                return ""

            relevant_indices = np.where(similarity_scores >= min_similarity)[0]
            if len(relevant_indices) == 0:
                logging.debug(f"No RAG documents met minimum similarity threshold of {min_similarity} for query: '{query[:50]}...'")
                return ""

            # Get scores for the relevant documents and sort them in descending order.
            relevant_scores = similarity_scores[relevant_indices]
            # argsort gives indices that sort relevant_scores; apply them to relevant_indices.
            sorted_relevant_indices_of_original = relevant_indices[np.argsort(relevant_scores)[::-1]]
            top_indices = sorted_relevant_indices_of_original[:top_k]

            context_parts = []
            if 'text' in self.documents_df.columns:
                for i in top_indices:
                    if 0 <= i < len(self.documents_df):
                        context_parts.append(self.documents_df.iloc[i]['text'])

            context = "\n\n---\n\n".join(context_parts)
            logging.debug(f"Retrieved RAG context with {len(context_parts)} documents for query: '{query[:50]}...'")
            return context
        except Exception as e:
            logging.error(f"Error during RAG retrieval for query '{query[:50]}...': {e}", exc_info=True)
            return ""
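
    # Illustrative usage sketch (inside an async context):
    # rag = AdvancedRAGSystem(DEFAULT_RAG_DOCUMENTS, GEMINI_EMBEDDING_MODEL_NAME)
    # await rag.initialize_embeddings()
    # context = await rag.retrieve_relevant_info("improving LinkedIn engagement", top_k=2)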

class EmployerBrandingAgent:
    def __init__(self,
                 all_dataframes: Optional[Dict[str, pd.DataFrame]] = None,
                 rag_documents_df: Optional[pd.DataFrame] = None,
                 llm_model_name: str = LLM_MODEL_NAME,
                 embedding_model_name: str = GEMINI_EMBEDDING_MODEL_NAME,
                 generation_config_dict: Optional[Dict] = None,
                 safety_settings_list: Optional[List] = None):  # safety_settings_list expects a list of dicts or SafetySetting objects

        self.all_dataframes = {k: v.copy() for k, v in (all_dataframes or {}).items()}  # Deep copy
        _rag_docs_df = rag_documents_df if rag_documents_df is not None else DEFAULT_RAG_DOCUMENTS.copy()
        self.rag_system = AdvancedRAGSystem(_rag_docs_df, embedding_model_name)
        self.llm_model_name = llm_model_name
        self.embedding_model_name = embedding_model_name  # Stored so get_status() can report it
        self.generation_config_dict = generation_config_dict or GENERATION_CONFIG_PARAMS

        # Ensure safety settings are in the correct format if SafetySetting objects are expected
        self.safety_settings_list = []
        if safety_settings_list and GENAI_AVAILABLE and hasattr(types, 'SafetySetting'):
            for ss_dict in safety_settings_list:
                try:
                    # Assuming ss_dict is like {'category': HarmCategory.XYZ, 'threshold': HarmBlockThreshold.ABC}
                    self.safety_settings_list.append(types.SafetySetting(category=ss_dict['category'], threshold=ss_dict['threshold']))
                except Exception as e:
                    logging.warning(f"Could not convert safety setting dict to SafetySetting object: {ss_dict} - {e}")
        elif safety_settings_list:  # If not using types.SafetySetting, pass as is (e.g. for client.models)
            self.safety_settings_list = safety_settings_list

        self.chat_history: List[Dict[str, str]] = []  # Stores {"role": "user"/"model", "content": "..."}
        self.is_ready = False
        self.llm_model_instance = None  # For the legacy google-generativeai fallback

        if GENAI_AVAILABLE and client is None and GEMINI_API_KEY:  # If genai.Client failed, try the legacy SDK
            try:
                import google.generativeai as legacy_genai  # Requires the google-generativeai package
                legacy_genai.configure(api_key=GEMINI_API_KEY)
                self.llm_model_instance = legacy_genai.GenerativeModel(self.llm_model_name)
                logging.info(f"Initialized GenerativeModel '{self.llm_model_name}' via google-generativeai.")
            except Exception as e:
                logging.error(f"Failed to initialize google-generativeai GenerativeModel: {e}")

        logging.info(f"EmployerBrandingAgent initialized. LLM: {self.llm_model_name}. RAG docs: {len(self.rag_system.documents_df)}. DataFrames: {list(self.all_dataframes.keys())}")

    async def initialize(self) -> bool:
        """Initializes asynchronous components of the agent, primarily RAG embeddings."""
        try:
            if not client and not self.llm_model_instance:  # Check if any LLM access is configured
                logging.error("Cannot initialize agent: no GenAI client (genai.Client or google-generativeai) available/configured.")
                return False

            await self.rag_system.initialize_embeddings()  # This sets rag_system.is_initialized
            self.is_ready = self.rag_system.is_initialized  # Agent is ready if RAG is (even if RAG has no docs)
            logging.info(f"EmployerBrandingAgent.initialize completed. RAG initialized: {self.rag_system.is_initialized}. Agent ready: {self.is_ready}")
            return True
        except Exception as e:
            logging.error(f"Error during EmployerBrandingAgent.initialize: {e}", exc_info=True)
            self.is_ready = False
            return False

    def _get_dataframes_summary(self) -> str:
        return get_all_schemas_representation(self.all_dataframes)

    def _build_system_prompt(self) -> str:
        # This prompt provides overall guidance to the LLM.
        return textwrap.dedent("""
            You are an expert Employer Branding Analyst AI. Your primary function is to analyze the LinkedIn data provided (follower statistics, post performance, mentions) and offer actionable insights, data-driven recommendations, and, if requested, Python Pandas code snippets for further analysis.

            When providing insights or recommendations:
            - Be specific and base your conclusions on the data summaries and context provided.
            - Structure responses clearly, perhaps using bullet points for key findings or actions.
            - Focus on practical advice that can help improve employer branding efforts.

            When asked to generate Pandas code:
            - Assume the data is available in pandas DataFrames named exactly as in the 'Available DataFrame Schemas' section (e.g., `df_follower_stats`, `df_posts`).
            - Generate executable Python code using pandas.
            - Ensure the code is directly relevant to the user's query and the available data.
            - Briefly explain what the code does.
            - If a query implies data not present in the schemas, state that and do not attempt to fabricate code for it.
            - Do not generate code that modifies DataFrames in place unless explicitly asked. Prefer returning new DataFrames or Series.
            - Handle potential errors in data (e.g., missing values if relevant to the operation) gracefully if simple to do so.
            - Output the code in a single, copy-pasteable block.

            Always refer to the provided DataFrame schemas to understand available columns and data types. Do not hallucinate columns or data.
            If a query is ambiguous or requires data not present, ask for clarification or state the limitation.
        """).strip()

    async def _generate_response(self, current_user_query: str) -> str:
        """
        Generates a response from the LLM based on the current query, system prompts,
        data summaries, RAG context, and the agent's chat history.
        Assumes self.chat_history is already populated by app.py and includes the current_user_query as the last entry.
        """
        if not self.is_ready:
            return "Agent is not ready. Please initialize."
        if not client and not self.llm_model_instance:
            return "Error: AI service is not available. Check API configuration."

        try:
            system_prompt_text = self._build_system_prompt()
            data_summary_text = self._get_dataframes_summary()
            rag_context_text = await self.rag_system.retrieve_relevant_info(current_user_query, top_k=2, min_similarity=0.25)

            # Construct the messages for the LLM API call
            llm_messages = []

            # 1. System-level instructions and context (as a first "user" turn)
            initial_context_prompt = (
                f"{system_prompt_text}\n\n"
                f"## Available Data Overview:\n{data_summary_text}\n\n"
                f"## Relevant Background Information (if any):\n{rag_context_text if rag_context_text else 'No specific background information retrieved for this query.'}\n\n"
                f"Given this context, please respond to the user queries that follow in the chat history."
            )
            llm_messages.append({"role": "user", "parts": [{"text": initial_context_prompt}]})

            # 2. Priming assistant message
            llm_messages.append({"role": "model", "parts": [{"text": "Understood. I have reviewed the context and data overview. I am ready to assist with your Employer Branding analysis based on our conversation."}]})

            # 3. Append the actual conversation history (already includes the current user query)
            for entry in self.chat_history:
                llm_messages.append({"role": entry["role"], "parts": [{"text": entry["content"]}]})

            # --- Make the API call ---
            response_text = ""

            if self.llm_model_instance:  # Standard google-generativeai usage
                logging.debug(f"Using GenerativeModel.generate_content_async for LLM call. History length: {len(llm_messages)}")
                # Prepare generation config and safety settings for google-generativeai
                gen_config_payload = self.generation_config_dict
                safety_settings_payload = self.safety_settings_list
                if GENAI_AVAILABLE and hasattr(types, 'GenerationConfig') and not isinstance(self.generation_config_dict, types.GenerationConfig):
                    try:
                        gen_config_payload = types.GenerationConfig(**self.generation_config_dict)
                    except Exception as e:
                        logging.warning(f"Could not convert gen_config_dict to types.GenerationConfig: {e}")

                api_response = await self.llm_model_instance.generate_content_async(
                    contents=llm_messages,
                    generation_config=gen_config_payload,
                    safety_settings=safety_settings_payload
                )
                response_text = api_response.text
            elif client:  # google.genai client usage
                logging.debug(f"Using client.models.generate_content for LLM call. History length: {len(llm_messages)}")
                # Convert messages to the simpler contents format expected by the google.genai client.
                contents = []
                for msg in llm_messages:
                    if msg["role"] == "user":
                        contents.append(msg["parts"][0]["text"])
                    elif msg["role"] == "model":
                        # Model turns are flattened into the prompt as labelled context.
                        contents.append(f"Assistant: {msg['parts'][0]['text']}")

                # Create the config object with both generation config and safety settings
                config_dict = {}
                # Add generation config parameters
                if self.generation_config_dict:
                    for key, value in self.generation_config_dict.items():
                        config_dict[key] = value
                # Add safety settings
                if self.safety_settings_list:
                    # Convert safety settings to the correct format if needed
                    safety_settings = []
                    for ss in self.safety_settings_list:
                        if isinstance(ss, dict):
                            # Convert dict to types.SafetySetting
                            safety_settings.append(types.SafetySetting(
                                category=ss.get('category'),
                                threshold=ss.get('threshold')
                            ))
                        else:
                            safety_settings.append(ss)
                    config_dict['safety_settings'] = safety_settings

                # Create the config object
                config = types.GenerateContentConfig(**config_dict)

                model_path = f"models/{self.llm_model_name}" if not self.llm_model_name.startswith("models/") else self.llm_model_name
                api_response = await asyncio.to_thread(
                    client.models.generate_content,
                    model=model_path,
                    contents=contents,  # Simplified contents format
                    config=config  # Single config object instead of separate generation_config and safety_settings
                )

                # Parse the response from the client.models structure
                if api_response.candidates and api_response.candidates[0].content and api_response.candidates[0].content.parts:
                    response_text_parts = [part.text for part in api_response.candidates[0].content.parts if hasattr(part, 'text')]
                    response_text = "".join(response_text_parts).strip()
                else:
                    # Handle blocked or empty responses
                    if hasattr(api_response, 'prompt_feedback') and api_response.prompt_feedback and api_response.prompt_feedback.block_reason:
                        logging.warning(f"Prompt blocked by client.models: {api_response.prompt_feedback.block_reason}")
                        return f"I'm sorry, your request was blocked. Reason: {api_response.prompt_feedback.block_reason_message or api_response.prompt_feedback.block_reason}"
                    if api_response.candidates and hasattr(api_response.candidates[0], 'finish_reason'):
                        finish_reason = api_response.candidates[0].finish_reason
                        # Prefer types.FinishReason (google-genai); fall back to a nested Candidate.FinishReason if present.
                        stop_enum = getattr(types, 'FinishReason', None) or getattr(types.Candidate, 'FinishReason', None)
                        if stop_enum is not None and finish_reason != stop_enum.STOP:
                            logging.warning(f"Content generation stopped by client.models due to: {finish_reason}. Safety: {getattr(api_response.candidates[0], 'safety_ratings', 'N/A')}")
                            return f"I couldn't complete the response. Reason: {finish_reason}. Please try rephrasing."
                    return "I apologize, but I couldn't generate a response from client.models."
            else:
                raise ConnectionError("No valid LLM client or model instance available.")

            return response_text.strip()

        except Exception as e:
            error_message = str(e).lower()
            # Check whether this is a blocked-prompt error by examining the error message
            if any(keyword in error_message for keyword in ['blocked', 'safety', 'filter', 'prohibited']):
                logging.error(f"Blocked prompt from LLM: {e}", exc_info=True)
                return f"I'm sorry, your request was blocked by the safety filter. Please rephrase your query. Details: {e}"
            else:
                logging.error(f"Error in _generate_response: {e}", exc_info=True)
                return f"I encountered an error while processing your request: {type(e).__name__} - {str(e)}"

    def _validate_query(self, query: str) -> bool:
        if not query or not isinstance(query, str) or len(query.strip()) < 3:
            logging.warning(f"Invalid query: too short or not a string. Query: '{query}'")
            return False
        if len(query) > 3000:  # Increased limit slightly
            logging.warning(f"Invalid query: too long. Length: {len(query)}")
            return False
        return True

    async def process_query(self, user_query: str) -> str:
        """
        Processes the user's query.
        It relies on self.chat_history being set externally (by app.py) to include the full
        conversation context, with the current user_query as the last "user" message.
        This method then calls _generate_response to get the AI's reply.
        It does NOT modify self.chat_history itself; app.py is responsible for that based on Gradio state.
        """
        if not self._validate_query(user_query):
            # This user_query is the one from the Gradio input, also the last entry in self.chat_history
            return "Please provide a valid query (3 to 3000 characters)."

        if not self.is_ready:
            logging.warning("process_query called but agent is not ready. Attempting re-initialization.")
            # This is a fallback. Ideally, initialize is called once and confirmed.
            init_success = await self.initialize()
            if not init_success:
                return "The agent is not properly initialized and could not be started. Please check configuration and logs."

        # user_query is the current text from the input box.
        # self.chat_history (set by app.py) should already contain this user_query as the last message.
        # It is passed to _generate_response primarily for RAG context retrieval for the current turn.
        response_text = await self._generate_response(user_query)
        return response_text

    def update_dataframes(self, new_dataframes: Dict[str, pd.DataFrame]):
        """Updates the agent's DataFrames. Does not automatically re-initialize RAG or the LLM."""
        self.all_dataframes = {k: v.copy() for k, v in new_dataframes.items()}  # Deep copy
        logging.info(f"Agent DataFrames updated. Keys: {list(self.all_dataframes.keys())}")
        # Note: if the RAG documents depend on these DataFrames, RAG may need re-initialization.
        # For now, RAG uses a static document set.

    def clear_chat_history(self):
        """Clears the agent's internal chat history. app.py should also clear the Gradio state."""
        self.chat_history = []
        logging.info("EmployerBrandingAgent internal chat history cleared.")

    def get_status(self) -> Dict[str, Any]:
        return {
            "is_ready": self.is_ready,
            "has_api_key": bool(GEMINI_API_KEY),
            "genai_available": GENAI_AVAILABLE,
            "client_type": "genai.Client" if client else ("google-generativeai" if self.llm_model_instance else "None"),
            "rag_initialized": self.rag_system.is_initialized,
            "num_dataframes": len(self.all_dataframes),
            "dataframe_keys": list(self.all_dataframes.keys()),
            "num_rag_documents": len(self.rag_system.documents_df) if self.rag_system.documents_df is not None else 0,
            "llm_model_name": self.llm_model_name,
            "embedding_model_name": self.embedding_model_name
        }

# --- Functions for Gradio integration (if needed directly, but app.py handles instantiation) ---
def create_agent_instance(dataframes: Optional[Dict[str, pd.DataFrame]] = None,
                          rag_docs: Optional[pd.DataFrame] = None) -> EmployerBrandingAgent:
    logging.info("Creating new EmployerBrandingAgent instance via helper function.")
    return EmployerBrandingAgent(all_dataframes=dataframes, rag_documents_df=rag_docs)

async def initialize_agent_async(agent: EmployerBrandingAgent) -> bool:
    logging.info("Initializing agent via async helper function.")
    return await agent.initialize()

if __name__ == "__main__":
    async def test_agent_logic():
        print("--- Testing Employer Branding Agent ---")
        if not GEMINI_API_KEY:
            print("GEMINI_API_KEY not set. Skipping live API tests.")
            return

        sample_dfs = {
            "followers": pd.DataFrame({'date': pd.to_datetime(['2023-01-01']), 'count': [100]}),
            "posts": pd.DataFrame({'title': ['My first post'], 'likes': [10]})
        }
        # Test RAG document loading
        custom_rag = pd.DataFrame({'text': ["Custom RAG context about LinkedIn engagement."]})

        agent = EmployerBrandingAgent(
            all_dataframes=sample_dfs,
            rag_documents_df=custom_rag,
            llm_model_name=LLM_MODEL_NAME,
            embedding_model_name=GEMINI_EMBEDDING_MODEL_NAME
        )
        print("Agent Status (pre-init):", agent.get_status())

        init_success = await agent.initialize()
        print(f"Agent Initialization Success: {init_success}")
        print("Agent Status (post-init):", agent.get_status())
        if not init_success:
            print("Agent initialization failed. Cannot proceed with query test.")
            return

        # Simulate app.py setting the history
        test_query1 = "What are the key columns in my followers data?"
        agent.chat_history = [{"role": "user", "content": test_query1}]  # app.py would do this
        print(f"\nProcessing Query 1: '{test_query1}'")
        response1 = await agent.process_query(user_query=test_query1)  # Pass current query for RAG etc.
        print(f"Agent Response 1:\n{response1}")

        # Simulate app.py updating the history for the next turn
        agent.chat_history.append({"role": "model", "content": response1})
        test_query2 = "Generate pandas code to get the total follower count."
        agent.chat_history.append({"role": "user", "content": test_query2})
        print(f"\nProcessing Query 2: '{test_query2}'")
        response2 = await agent.process_query(user_query=test_query2)
        print(f"Agent Response 2:\n{response2}")
        agent.chat_history.append({"role": "model", "content": response2})

        print("\nFinal Agent Chat History (internal):")
        for item in agent.chat_history:
            print(f"- {item['role']}: {item['content'][:100]}...")
        print("\n--- Test Complete ---")

    asyncio.run(test_agent_logic())