# LinkedinMonitor / eb_agent_module.py
import pandas as pd
import json
import os
import asyncio
import logging
import numpy as np
import textwrap  # Used to dedent the system prompt in _build_system_prompt
from datetime import datetime # Not used, but kept from original
from typing import Dict, List, Optional, Union, Any
import traceback
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(module)s - %(message)s')
try:
    from google import genai
    from google.genai import types  # Assuming this provides necessary types like SafetySetting, HarmCategory etc.
    # If GenerationConfig or EmbedContentConfig are from a different submodule, adjust imports.
    # For google-generativeai, GenerationConfig is often passed as a dict or genai.types.GenerationConfig
    # and EmbedContentConfig might be implicit or part of task_type.
    GENAI_AVAILABLE = True
    logging.info("Google Generative AI library imported successfully.")
except ImportError:
    logging.warning("Google Generative AI library not found. Please install it: pip install google-generativeai")
    GENAI_AVAILABLE = False

    # Dummy classes for graceful degradation (simplified)
    class genai:
        Client = None
        # If using google-generativeai, these would be different:
        # GenerativeModel = None
        # def configure(*args, **kwargs): pass
        # def embed_content(*args, **kwargs): return {}

    class types:  # Placeholder for types used in the rest of this module
        EmbedContentConfig = None  # Placeholder
        GenerationConfig = None  # Placeholder
        SafetySetting = None
        Candidate = type('Candidate', (), {'FinishReason': type('FinishReason', (), {'STOP': 'STOP'})})  # Dummy for FinishReason

        class HarmCategory:
            HARM_CATEGORY_UNSPECIFIED = "HARM_CATEGORY_UNSPECIFIED"
            HARM_CATEGORY_HATE_SPEECH = "HARM_CATEGORY_HATE_SPEECH"
            HARM_CATEGORY_HARASSMENT = "HARM_CATEGORY_HARASSMENT"
            HARM_CATEGORY_SEXUALLY_EXPLICIT = "HARM_CATEGORY_SEXUALLY_EXPLICIT"
            HARM_CATEGORY_DANGEROUS_CONTENT = "HARM_CATEGORY_DANGEROUS_CONTENT"

        class HarmBlockThreshold:
            BLOCK_NONE = "BLOCK_NONE"
            BLOCK_LOW_AND_ABOVE = "BLOCK_LOW_AND_ABOVE"
            BLOCK_MEDIUM_AND_ABOVE = "BLOCK_MEDIUM_AND_ABOVE"
            BLOCK_ONLY_HIGH = "BLOCK_ONLY_HIGH"

        class generation_types:  # Dummy for BlockedPromptException
            BlockedPromptException = type('BlockedPromptException', (Exception,), {})

# --- Custom Exceptions ---
class ValidationError(Exception):
    """Custom validation error for agent inputs"""
    pass


class RateLimitError(Exception):  # Not used, but kept
    """Placeholder for rate limit errors."""
    pass


class AgentNotReadyError(Exception):
    """Agent is not properly initialized"""
    pass

# --- Configuration Constants ---
GEMINI_API_KEY = os.getenv('GEMINI_API_KEY', "")
LLM_MODEL_NAME = "gemini-1.5-flash-latest" # For google-generativeai, model name is directly used.
# For client.models.generate_content, it might need "models/gemini-1.5-flash-latest"
GEMINI_EMBEDDING_MODEL_NAME = "text-embedding-004" # Similarly, might need "models/text-embedding-004"
GENERATION_CONFIG_PARAMS = {
    "temperature": 0.7,
    "top_p": 0.95,
    "top_k": 40,
    "max_output_tokens": 8192,  # Ensure this is supported by the chosen model
    "candidate_count": 1,
}
DEFAULT_SAFETY_SETTINGS = [] # User can populate this with {'category': HarmCategory.HARM_CATEGORY_X, 'threshold': HarmBlockThreshold.BLOCK_Y}
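# Illustrative only: a populated safety-settings list might look like the following
# (the category/threshold enums are assumed to come from the imported `types` module):
# DEFAULT_SAFETY_SETTINGS = [
#     {'category': types.HarmCategory.HARM_CATEGORY_HARASSMENT,
#      'threshold': types.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE},
#     {'category': types.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
#      'threshold': types.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE},
# ]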
# Default RAG documents
DEFAULT_RAG_DOCUMENTS = pd.DataFrame({
    'text': [
        "Employer branding focuses on how an organization is perceived as an employer by potential and current employees.",
        "Key metrics for employer branding include employee engagement, candidate quality, and retention rates.",
        "LinkedIn is a crucial platform for showcasing company culture and attracting talent.",
        "Analyzing follower demographics and post engagement helps refine employer branding strategies.",
        "Content strategy should align with company values to attract the right talent.",
        "Employee advocacy programs can significantly boost employer brand reach and authenticity."
    ]
})
# --- Client Initialization ---
client = None
if GEMINI_API_KEY and GENAI_AVAILABLE:
    try:
        # This call is specific to genai.Client. If using google-generativeai, this would be genai.configure(api_key=...)
        client = genai.Client(api_key=GEMINI_API_KEY)
        logging.info("Google GenAI client initialized successfully (using genai.Client).")
    except Exception as e:
        logging.error(f"Failed to initialize Google GenAI client (using genai.Client): {e}")
        client = None
else:
    if not GEMINI_API_KEY:
        logging.warning("GEMINI_API_KEY environment variable not set.")
    if not GENAI_AVAILABLE:
        logging.warning("Google GenAI library not available.")

# --- Utility function to get DataFrame schema representation ---
def get_df_schema_representation(df: pd.DataFrame, df_name: str) -> str:
    """Generates a string representation of a DataFrame's schema and a small sample."""
    if not isinstance(df, pd.DataFrame):
        return f"Item '{df_name}' is not a DataFrame.\n"
    if df.empty:
        return f"DataFrame '{df_name}': Empty\n"
    schema_parts = [f"DataFrame '{df_name}':"]
    schema_parts.append(f" Shape: {df.shape}")
    schema_parts.append(" Columns:")
    for col in df.columns:
        col_type = str(df[col].dtype)
        null_count = df[col].isnull().sum()
        unique_count = df[col].nunique()
        schema_parts.append(f" - {col} (Type: {col_type}, Nulls: {null_count}/{len(df)}, Uniques: {unique_count})")
    if not df.empty:
        schema_parts.append(" Sample Data (first 2 rows):")
        try:
            sample_df_str = df.head(2).to_string(index=True, max_colwidth=50)  # Show index for context
            indented_sample_df = "\n".join([" " + line for line in sample_df_str.split('\n')])
            schema_parts.append(indented_sample_df)
        except Exception as e:
            schema_parts.append(f" Could not generate sample data: {e}")
    return "\n".join(schema_parts) + "\n"

def get_all_schemas_representation(dataframes: Dict[str, pd.DataFrame]) -> str:
    """Generates a string representation of all DataFrame schemas."""
    if not dataframes:
        return "No DataFrames available to the agent."
    full_representation = ["=== Available DataFrame Schemas for Analysis ==="]
    for name, df_instance in dataframes.items():
        full_representation.append(get_df_schema_representation(df_instance, name))
    return "\n".join(full_representation)

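# Illustrative only: this is the summary string later handed to the LLM, e.g.
#   get_all_schemas_representation({"posts": pd.DataFrame({"likes": [10, 25]})})
# returns the "=== Available DataFrame Schemas for Analysis ===" header followed by one
# shape/columns/sample block per DataFrame.
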
class AdvancedRAGSystem:
    def __init__(self, documents_df: pd.DataFrame, embedding_model_name: str):
        self.documents_df = documents_df.copy() if not documents_df.empty else DEFAULT_RAG_DOCUMENTS.copy()
        # Ensure 'text' column exists
        if 'text' not in self.documents_df.columns and not self.documents_df.empty:
            logging.warning("'text' column not found in RAG documents. RAG might not work.")
            # Create an empty text column if df is not empty but lacks it, to prevent errors later
            self.documents_df['text'] = ""
        self.embedding_model_name = embedding_model_name  # e.g., "models/text-embedding-004" or just "text-embedding-004"
        self.embeddings: Optional[np.ndarray] = None
        self.is_initialized = False
        logging.info(f"AdvancedRAGSystem initialized with {len(self.documents_df)} documents. Model: {self.embedding_model_name}")
    def _embed_single_document_sync(self, text: str) -> Optional[np.ndarray]:
        if not client:
            raise ConnectionError("GenAI client not initialized for RAG embedding.")
        if not text or not isinstance(text, str):
            logging.warning("Cannot embed empty or non-string text for RAG.")
            return None
        try:
            # Standard google-generativeai call would be:
            # embedding_response = genai.embed_content(
            #     model=self.embedding_model_name,  # e.g., "models/text-embedding-004"
            #     content=text,
            #     task_type="RETRIEVAL_DOCUMENT"  # or "SEMANTIC_SIMILARITY"
            # )
            # return np.array(embedding_response['embedding'])
            # Using the provided client.models.embed_content structure instead.
            # This might require specific types for config.
            embed_config_payload = None
            if GENAI_AVAILABLE and hasattr(types, 'EmbedContentConfig'):  # Assuming types.EmbedContentConfig is relevant
                # The task_type for EmbedContentConfig might differ, e.g., "SEMANTIC_SIMILARITY" or "RETRIEVAL_DOCUMENT"
                embed_config_payload = types.EmbedContentConfig(task_type="RETRIEVAL_DOCUMENT")
            response = client.models.embed_content(  # Original call structure
                model=f"models/{self.embedding_model_name}" if not self.embedding_model_name.startswith("models/") else self.embedding_model_name,
                contents=text,  # client.models uses 'contents'; genai.embed_content uses 'content'
                config=embed_config_payload
            )
            # Adapt response parsing based on actual client.models.embed_content behavior
            if hasattr(response, 'embeddings') and isinstance(response.embeddings, list) and len(response.embeddings) > 0:
                # This structure `response.embeddings[0]` is client-specific.
                # Standard genai.embed_content returns a dict `{'embedding': [values]}`
                return np.array(response.embeddings[0])
            elif hasattr(response, 'embedding'):  # Common for genai.embed_content
                return np.array(response.embedding)
            else:
                logging.error(f"Unexpected embedding response format: {response}")
                return None
        except Exception as e:
            logging.error(f"Error in _embed_single_document_sync for text '{text[:50]}...': {e}", exc_info=True)
            raise
    async def initialize_embeddings(self):
        if self.documents_df.empty or 'text' not in self.documents_df.columns:
            logging.warning("RAG documents DataFrame is empty or lacks a 'text' column. Skipping embedding.")
            self.embeddings = np.array([])
            self.is_initialized = True  # Initialized, but with no embeddings
            return
        if not client and not (GENAI_AVAILABLE and os.getenv('GEMINI_API_KEY')):  # Check if standard genai can be used
            logging.error("GenAI client not available for RAG embedding initialization.")
            self.embeddings = np.array([])
            return
        logging.info(f"Starting RAG document embedding for {len(self.documents_df)} documents...")
        embedded_docs_list = []
        for index, row in self.documents_df.iterrows():
            text_to_embed = row.get('text', '')
            if not text_to_embed or not isinstance(text_to_embed, str):
                logging.warning(f"Skipping RAG document at index {index} due to invalid/empty text.")
                continue
            try:
                # Use asyncio.to_thread for the synchronous embedding call
                embedding_array = await asyncio.to_thread(self._embed_single_document_sync, text_to_embed)
                if embedding_array is not None and embedding_array.size > 0:
                    embedded_docs_list.append(embedding_array)
                else:
                    logging.warning(f"Empty or failed embedding for RAG document at index {index}.")
            except Exception as e:
                logging.error(f"Error embedding RAG document at index {index}: {e}")
                continue  # Continue with other documents
        if not embedded_docs_list:
            self.embeddings = np.array([])
            logging.warning("No RAG documents were successfully embedded.")
        else:
            try:
                # Ensure all embeddings have the same shape before vstack
                first_shape = embedded_docs_list[0].shape
                if not all(emb.shape == first_shape for emb in embedded_docs_list):
                    logging.error("Inconsistent embedding shapes found. Cannot stack for RAG.")
                    # Fail stacking if shapes are inconsistent.
                    self.embeddings = np.array([])
                    return  # Exit if shapes are inconsistent
                self.embeddings = np.vstack(embedded_docs_list)
                logging.info(f"Successfully embedded {len(embedded_docs_list)} RAG documents. Embeddings shape: {self.embeddings.shape}")
            except ValueError as ve:
                logging.error(f"Error stacking embeddings (likely due to inconsistent shapes): {ve}")
                self.embeddings = np.array([])
        self.is_initialized = True
    def _calculate_cosine_similarity(self, embeddings_matrix: np.ndarray, query_vector: np.ndarray) -> np.ndarray:
        if embeddings_matrix.ndim == 1:  # Handle the case of a single document embedding
            embeddings_matrix = embeddings_matrix.reshape(1, -1)
        if query_vector.ndim == 1:
            query_vector = query_vector.reshape(1, -1)
        if embeddings_matrix.size == 0 or query_vector.size == 0:
            return np.array([])
        # Normalize the rows of embeddings_matrix
        norm_matrix = np.linalg.norm(embeddings_matrix, axis=1, keepdims=True)
        # Add a small epsilon to avoid division by zero for zero vectors
        normalized_embeddings_matrix = np.divide(embeddings_matrix, norm_matrix + 1e-8, where=norm_matrix != 0)
        # Normalize query_vector
        norm_query = np.linalg.norm(query_vector, axis=1, keepdims=True)
        normalized_query_vector = np.divide(query_vector, norm_query + 1e-8, where=norm_query != 0)
        # Dot product of the normalized vectors gives the cosine similarities
        return np.dot(normalized_embeddings_matrix, normalized_query_vector.T).flatten()
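    # Worked example (illustrative): for unit-length rows the scores reduce to plain dot products.
    #   embeddings = np.array([[1.0, 0.0], [0.0, 1.0]])
    #   query      = np.array([1.0, 0.0])
    #   _calculate_cosine_similarity(embeddings, query)  ->  approximately array([1.0, 0.0])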
    async def retrieve_relevant_info(self, query: str, top_k: int = 3, min_similarity: float = 0.3) -> str:
        if not self.is_initialized:
            logging.debug("RAG system not initialized. Cannot retrieve info.")
            return ""
        if self.embeddings is None or self.embeddings.size == 0:
            logging.debug("RAG embeddings not available. Cannot retrieve info.")
            return ""
        if not query or not isinstance(query, str):
            logging.debug("Empty or invalid query for RAG retrieval.")
            return ""
        if not client and not (GENAI_AVAILABLE and os.getenv('GEMINI_API_KEY')):
            logging.error("GenAI client not available for RAG query embedding.")
            return ""
        try:
            query_vector = await asyncio.to_thread(self._embed_single_document_sync, query)  # Embed the query
            if query_vector is None or query_vector.size == 0:
                logging.warning("Query vector embedding failed or is empty for RAG.")
                return ""
            similarity_scores = self._calculate_cosine_similarity(self.embeddings, query_vector)
            if similarity_scores.size == 0:
                return ""
            relevant_indices = np.where(similarity_scores >= min_similarity)[0]
            if len(relevant_indices) == 0:
                logging.debug(f"No RAG documents met the minimum similarity threshold of {min_similarity} for query: '{query[:50]}...'")
                return ""
            # Get scores for the relevant documents and sort them (descending)
            relevant_scores = similarity_scores[relevant_indices]
            # argsort returns indices that sort relevant_scores; apply them to relevant_indices
            sorted_relevant_indices_of_original = relevant_indices[np.argsort(relevant_scores)[::-1]]
            top_indices = sorted_relevant_indices_of_original[:top_k]
            context_parts = []
            if 'text' in self.documents_df.columns:
                for i in top_indices:
                    if 0 <= i < len(self.documents_df):
                        context_parts.append(self.documents_df.iloc[i]['text'])
            context = "\n\n---\n\n".join(context_parts)
            logging.debug(f"Retrieved RAG context with {len(context_parts)} documents for query: '{query[:50]}...'")
            return context
        except Exception as e:
            logging.error(f"Error during RAG retrieval for query '{query[:50]}...': {e}", exc_info=True)
            return ""

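# Illustrative, standalone use of the RAG system inside an async context
# (assumes a configured GenAI client / GEMINI_API_KEY):
#   rag = AdvancedRAGSystem(DEFAULT_RAG_DOCUMENTS, GEMINI_EMBEDDING_MODEL_NAME)
#   await rag.initialize_embeddings()
#   context = await rag.retrieve_relevant_info("How do I measure employer brand reach?", top_k=2)
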
class EmployerBrandingAgent:
    def __init__(self,
                 all_dataframes: Optional[Dict[str, pd.DataFrame]] = None,
                 rag_documents_df: Optional[pd.DataFrame] = None,
                 llm_model_name: str = LLM_MODEL_NAME,
                 embedding_model_name: str = GEMINI_EMBEDDING_MODEL_NAME,
                 generation_config_dict: Optional[Dict] = None,
                 safety_settings_list: Optional[List] = None):  # safety_settings_list expects a list of dicts or SafetySetting objects
        self.all_dataframes = {k: v.copy() for k, v in (all_dataframes or {}).items()}  # Deep copy
        _rag_docs_df = rag_documents_df if rag_documents_df is not None else DEFAULT_RAG_DOCUMENTS.copy()
        self.rag_system = AdvancedRAGSystem(_rag_docs_df, embedding_model_name)
        self.llm_model_name = llm_model_name
        self.embedding_model_name = embedding_model_name  # Stored on the agent as well so get_status() can report it
        self.generation_config_dict = generation_config_dict or GENERATION_CONFIG_PARAMS
        # Ensure safety settings are in the correct format if using google-generativeai directly
        self.safety_settings_list = []
        if safety_settings_list and GENAI_AVAILABLE and hasattr(types, 'SafetySetting'):
            for ss_dict in safety_settings_list:
                try:
                    # Assuming ss_dict is like {'category': HarmCategory.XYZ, 'threshold': HarmBlockThreshold.ABC}
                    self.safety_settings_list.append(types.SafetySetting(category=ss_dict['category'], threshold=ss_dict['threshold']))
                except Exception as e:
                    logging.warning(f"Could not convert safety setting dict to SafetySetting object: {ss_dict} - {e}")
        elif safety_settings_list:  # If not using types.SafetySetting, pass as-is (e.g. for client.models)
            self.safety_settings_list = safety_settings_list
        self.chat_history: List[Dict[str, str]] = []  # Stores {"role": "user"/"model", "content": "..."}
        self.is_ready = False
        self.llm_model_instance = None  # For google-generativeai
        if GENAI_AVAILABLE and client is None and GEMINI_API_KEY:  # If genai.Client failed but standard genai can be used
            try:
                genai.configure(api_key=GEMINI_API_KEY)
                self.llm_model_instance = genai.GenerativeModel(self.llm_model_name)
                logging.info(f"Initialized GenerativeModel '{self.llm_model_name}' via google-generativeai.")
            except Exception as e:
                logging.error(f"Failed to initialize google-generativeai.GenerativeModel: {e}")
        logging.info(f"EmployerBrandingAgent initialized. LLM: {self.llm_model_name}. RAG docs: {len(self.rag_system.documents_df)}. DataFrames: {list(self.all_dataframes.keys())}")
    async def initialize(self) -> bool:
        """Initializes asynchronous components of the agent, primarily RAG embeddings."""
        try:
            if not client and not self.llm_model_instance:  # Check if any LLM access is configured
                logging.error("Cannot initialize agent: GenAI client (genai.Client or google.generativeai) not available/configured.")
                return False
            await self.rag_system.initialize_embeddings()  # This sets rag_system.is_initialized
            self.is_ready = self.rag_system.is_initialized  # Agent is ready if RAG is (even if RAG has no docs)
            logging.info(f"EmployerBrandingAgent.initialize completed. RAG initialized: {self.rag_system.is_initialized}. Agent ready: {self.is_ready}")
            return True
        except Exception as e:
            logging.error(f"Error during EmployerBrandingAgent.initialize: {e}", exc_info=True)
            self.is_ready = False
            return False
    def _get_dataframes_summary(self) -> str:
        return get_all_schemas_representation(self.all_dataframes)
    def _build_system_prompt(self) -> str:
        # This prompt provides overall guidance to the LLM.
        return textwrap.dedent("""
            You are an expert Employer Branding Analyst AI. Your primary function is to analyze the LinkedIn data provided (follower statistics, post performance, mentions) and offer actionable insights, data-driven recommendations, and, if requested, Python Pandas code snippets for further analysis.
            When providing insights or recommendations:
            - Be specific and base your conclusions on the data summaries and context provided.
            - Structure responses clearly, for example using bullet points for key findings or actions.
            - Focus on practical advice that can help improve employer branding efforts.
            When asked to generate Pandas code:
            - Assume the data is available in pandas DataFrames named exactly as in the 'Available DataFrame Schemas' section (e.g., `df_follower_stats`, `df_posts`).
            - Generate executable Python code using pandas.
            - Ensure the code is directly relevant to the user's query and the available data.
            - Briefly explain what the code does.
            - If a query implies data not present in the schemas, state that and do not attempt to fabricate code for it.
            - Do not generate code that modifies DataFrames in place unless explicitly asked. Prefer returning new DataFrames or Series.
            - Handle potential errors in the data (e.g., missing values, if relevant to the operation) gracefully when simple to do so.
            - Output the code in a single, copy-pasteable block.
            Always refer to the provided DataFrame schemas to understand the available columns and data types. Do not hallucinate columns or data.
            If a query is ambiguous or requires data not present, ask for clarification or state the limitation.
        """).strip()
    async def _generate_response(self, current_user_query: str) -> str:
        """
        Generates a response from the LLM based on the current query, system prompt,
        data summaries, RAG context, and the agent's chat history.
        Assumes self.chat_history has already been populated by app.py and includes
        the current_user_query as the last entry.
        """
        if not self.is_ready:
            return "Agent is not ready. Please initialize."
        if not client and not self.llm_model_instance:
            return "Error: AI service is not available. Check API configuration."
        try:
            system_prompt_text = self._build_system_prompt()
            data_summary_text = self._get_dataframes_summary()
            rag_context_text = await self.rag_system.retrieve_relevant_info(current_user_query, top_k=2, min_similarity=0.25)  # Fine-tuned RAG params
            # Construct the messages for the LLM API call.
            # The history (self.chat_history) is set by app.py and includes the current user query.
            llm_messages = []
            # 1. System-level instructions and context (as a first "user" turn)
            initial_context_prompt = (
                f"{system_prompt_text}\n\n"
                f"## Available Data Overview:\n{data_summary_text}\n\n"
                f"## Relevant Background Information (if any):\n{rag_context_text if rag_context_text else 'No specific background information retrieved for this query.'}\n\n"
                f"Given this context, please respond to the user queries that follow in the chat history."
            )
            llm_messages.append({"role": "user", "parts": [{"text": initial_context_prompt}]})
            # 2. Priming assistant message
            llm_messages.append({"role": "model", "parts": [{"text": "Understood. I have reviewed the context and data overview. I am ready to assist with your Employer Branding analysis based on our conversation."}]})
            # 3. Append the actual conversation history (already includes the current user query)
            for entry in self.chat_history:  # self.chat_history is set by app.py
                llm_messages.append({"role": entry["role"], "parts": [{"text": entry["content"]}]})
            # Prepare generation config and safety settings for the API
            gen_config_payload = self.generation_config_dict
            safety_settings_payload = self.safety_settings_list  # Already formatted if types.SafetySetting was used
            if GENAI_AVAILABLE and hasattr(types, 'GenerationConfig') and not isinstance(self.generation_config_dict, types.GenerationConfig):
                try:
                    gen_config_payload = types.GenerationConfig(**self.generation_config_dict)
                except Exception as e:
                    logging.warning(f"Could not convert generation_config_dict to types.GenerationConfig: {e}")
            # --- Make the API call ---
            response_text = ""
            if self.llm_model_instance:  # Standard google-generativeai usage
                logging.debug(f"Using google-generativeai GenerativeModel.generate_content_async for the LLM call. History length: {len(llm_messages)}")
                api_response = await self.llm_model_instance.generate_content_async(
                    contents=llm_messages,
                    generation_config=gen_config_payload,
                    safety_settings=safety_settings_payload
                )
                response_text = api_response.text  # Simplification; assumes a single-part text response
            elif client:  # Original client.models.generate_content structure
                logging.debug(f"Using client.models.generate_content for the LLM call. History length: {len(llm_messages)}")
                # This call is synchronous, so it is wrapped in asyncio.to_thread as in the original
                model_path = f"models/{self.llm_model_name}" if not self.llm_model_name.startswith("models/") else self.llm_model_name
                api_response = await asyncio.to_thread(
                    client.models.generate_content,
                    model=model_path,
                    contents=llm_messages,
                    generation_config=gen_config_payload,  # Ensure this is the correct type for client.models
                    safety_settings=safety_settings_payload  # Ensure this is the correct type
                )
                # Parse the response from the client.models structure
                if api_response.candidates and api_response.candidates[0].content and api_response.candidates[0].content.parts:
                    response_text_parts = [part.text for part in api_response.candidates[0].content.parts if hasattr(part, 'text')]
                    response_text = "".join(response_text_parts).strip()
                else:  # Handle blocked or empty responses from client.models
                    if hasattr(api_response, 'prompt_feedback') and api_response.prompt_feedback and api_response.prompt_feedback.block_reason:
                        logging.warning(f"Prompt blocked by client.models: {api_response.prompt_feedback.block_reason}")
                        return f"I'm sorry, your request was blocked. Reason: {api_response.prompt_feedback.block_reason_message or api_response.prompt_feedback.block_reason}"
                    if api_response.candidates and api_response.candidates[0].finish_reason != types.Candidate.FinishReason.STOP:  # Assumes types.Candidate.FinishReason.STOP is valid
                        logging.warning(f"Content generation stopped by client.models due to: {api_response.candidates[0].finish_reason}. Safety: {api_response.candidates[0].safety_ratings if hasattr(api_response.candidates[0], 'safety_ratings') else 'N/A'}")
                        return f"I couldn't complete the response. Reason: {api_response.candidates[0].finish_reason}. Please try rephrasing."
                    return "I apologize, but I couldn't generate a response from client.models."
            else:
                raise ConnectionError("No valid LLM client or model instance available.")
            return response_text.strip()
        # BlockedPromptException is specific to google-generativeai; guard the lookup so that a missing
        # attribute (e.g., when only google.genai's types module is available) never breaks error handling.
        except getattr(getattr(types, 'generation_types', None), 'BlockedPromptException', ()) as bpe:
            logging.error(f"BlockedPromptException from LLM: {bpe}", exc_info=True)
            return f"I'm sorry, your request was blocked by the safety filter. Please rephrase your query. Details: {bpe}"
        except Exception as e:
            logging.error(f"Error in _generate_response: {e}", exc_info=True)
            return f"I encountered an error while processing your request: {type(e).__name__} - {str(e)}"
    def _validate_query(self, query: str) -> bool:
        if not query or not isinstance(query, str) or len(query.strip()) < 3:
            logging.warning(f"Invalid query: too short or not a string. Query: '{query}'")
            return False
        if len(query) > 3000:  # Increased limit slightly
            logging.warning(f"Invalid query: too long. Length: {len(query)}")
            return False
        return True
    async def process_query(self, user_query: str) -> str:
        """
        Processes the user's query.
        It relies on self.chat_history being set externally (by app.py) to include the full
        conversation context, with the current user_query as the last "user" message.
        This method then calls _generate_response to get the AI's reply.
        It does NOT modify self.chat_history itself; app.py is responsible for that based on Gradio state.
        """
        if not self._validate_query(user_query):
            # This user_query is the one from the Gradio input, also the last one in self.chat_history
            return "Please provide a valid query (3 to 3000 characters)."
        if not self.is_ready:
            logging.warning("process_query called but agent is not ready. Attempting re-initialization.")
            # This is a fallback. Ideally, initialize() is called once and confirmed.
            init_success = await self.initialize()
            if not init_success:
                return "The agent is not properly initialized and could not be started. Please check configuration and logs."
        # user_query is the current text from the input box.
        # self.chat_history (set by app.py) should already contain this user_query as the last message.
        # We pass user_query to _generate_response primarily for RAG context retrieval for the current turn.
        response_text = await self._generate_response(user_query)
        return response_text
    def update_dataframes(self, new_dataframes: Dict[str, pd.DataFrame]):
        """Updates the agent's DataFrames. Does not automatically re-initialize RAG or the LLM."""
        self.all_dataframes = {k: v.copy() for k, v in new_dataframes.items()}  # Deep copy
        logging.info(f"Agent DataFrames updated. Keys: {list(self.all_dataframes.keys())}")
        # Note: if the RAG documents depended on these DataFrames, RAG would need re-initialization.
        # For now, RAG uses a static document set.

    def clear_chat_history(self):
        """Clears the agent's internal chat history. app.py should also clear the Gradio state."""
        self.chat_history = []
        logging.info("EmployerBrandingAgent internal chat history cleared.")
    def get_status(self) -> Dict[str, Any]:
        return {
            "is_ready": self.is_ready,
            "has_api_key": bool(GEMINI_API_KEY),
            "genai_available": GENAI_AVAILABLE,
            "client_type": "genai.Client" if client else ("google-generativeai" if self.llm_model_instance else "None"),
            "rag_initialized": self.rag_system.is_initialized,
            "num_dataframes": len(self.all_dataframes),
            "dataframe_keys": list(self.all_dataframes.keys()),
            "num_rag_documents": len(self.rag_system.documents_df) if self.rag_system.documents_df is not None else 0,
            "llm_model_name": self.llm_model_name,
            "embedding_model_name": self.embedding_model_name
        }

# --- Functions for Gradio integration (if needed directly, but app.py handles instantiation) ---
def create_agent_instance(dataframes: Optional[Dict[str, pd.DataFrame]] = None,
                          rag_docs: Optional[pd.DataFrame] = None) -> EmployerBrandingAgent:
    logging.info("Creating new EmployerBrandingAgent instance via helper function.")
    return EmployerBrandingAgent(all_dataframes=dataframes, rag_documents_df=rag_docs)


async def initialize_agent_async(agent: EmployerBrandingAgent) -> bool:
    logging.info("Initializing agent via async helper function.")
    return await agent.initialize()

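# Illustrative wiring from a host application (e.g., the Gradio app.py this module is built for);
# names like `user_dataframes`, `gradio_history`, and `user_message` are placeholders, not part of this module:
#   agent = create_agent_instance(dataframes=user_dataframes)
#   await initialize_agent_async(agent)
#   agent.chat_history = gradio_history + [{"role": "user", "content": user_message}]
#   reply = await agent.process_query(user_message)
#   gradio_history += [{"role": "user", "content": user_message}, {"role": "model", "content": reply}]
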
if __name__ == "__main__":
    async def test_agent_logic():
        print("--- Testing Employer Branding Agent ---")
        if not GEMINI_API_KEY:
            print("GEMINI_API_KEY not set. Skipping live API tests.")
            return
        sample_dfs = {
            "followers": pd.DataFrame({'date': pd.to_datetime(['2023-01-01']), 'count': [100]}),
            "posts": pd.DataFrame({'title': ['My first post'], 'likes': [10]})
        }
        # Test RAG document loading
        custom_rag = pd.DataFrame({'text': ["Custom RAG context about LinkedIn engagement."]})
        agent = EmployerBrandingAgent(
            all_dataframes=sample_dfs,
            rag_documents_df=custom_rag,
            llm_model_name=LLM_MODEL_NAME,
            embedding_model_name=GEMINI_EMBEDDING_MODEL_NAME
        )
        print("Agent Status (pre-init):", agent.get_status())
        init_success = await agent.initialize()
        print(f"Agent Initialization Success: {init_success}")
        print("Agent Status (post-init):", agent.get_status())
        if not init_success:
            print("Agent initialization failed. Cannot proceed with query test.")
            return
        # Simulate app.py setting the history
        test_query1 = "What are the key columns in my followers data?"
        agent.chat_history = [{"role": "user", "content": test_query1}]  # app.py would do this
        print(f"\nProcessing Query 1: '{test_query1}'")
        response1 = await agent.process_query(user_query=test_query1)  # Pass the current query for RAG etc.
        print(f"Agent Response 1:\n{response1}")
        # Simulate app.py updating the history for the next turn
        agent.chat_history.append({"role": "model", "content": response1})
        test_query2 = "Generate pandas code to get the total follower count."
        agent.chat_history.append({"role": "user", "content": test_query2})
        print(f"\nProcessing Query 2: '{test_query2}'")
        response2 = await agent.process_query(user_query=test_query2)
        print(f"Agent Response 2:\n{response2}")
        agent.chat_history.append({"role": "model", "content": response2})
        print("\nFinal Agent Chat History (internal):")
        for item in agent.chat_history:
            print(f"- {item['role']}: {item['content'][:100]}...")
        print("\n--- Test Complete ---")

    asyncio.run(test_agent_logic())