LinkedinMonitor / eb_agent_module.py
GuglielmoTor's picture
Create eb_agent_module.py
e03d275 verified
raw
history blame
34.2 kB
# eb_agent_module.py
import pandas as pd
import json
import os
import asyncio
import logging
import numpy as np
import textwrap
# Attempt to import the Google Gen AI SDK and its types module.
# NOTE(review): the rest of this file calls genai.configure / genai.GenerativeModel /
# genai.embed_content, which look like the legacy `google.generativeai` API rather than
# the `google.genai` client API imported here -- confirm which SDK is intended.
try:
    from google import genai
    from google.genai import types as genai_types
    # from google.api_core import retry_async # For async retries if needed
except ImportError:
    # `from google import genai` is provided by the `google-genai` distribution,
    # so that is the package the user has to install to get past this branch.
    print("Google Generative AI library not found. Please install it: pip install google-genai")

    # Define dummy classes/functions if the import fails, to allow the rest of the
    # script to be parsed and imported without the SDK.
    class genai:  # type: ignore
        """Inert stand-in exposing the SDK entry points this module calls."""

        @staticmethod
        def configure(api_key):
            pass

        @staticmethod
        def GenerativeModel(model_name, **kwargs):  # type: ignore
            # Accept extra keywords (e.g. safety_settings=...) so callers that pass
            # them get the inert `None` model instead of a TypeError.
            return None

        @staticmethod
        def embed_content(model, content, task_type, title=None):  # type: ignore
            # Fixed-size placeholder vector.
            return {"embedding": [0.1] * 768}

    class genai_types:  # type: ignore
        """Inert stand-in for the SDK `types` module (only the names used here)."""

        @staticmethod
        def GenerateContentConfig(**kwargs):  # type: ignore
            return None

        class BlockReason:  # type: ignore
            SAFETY = "SAFETY"

        class HarmCategory:  # type: ignore
            HARM_CATEGORY_UNSPECIFIED = "HARM_CATEGORY_UNSPECIFIED"
            HARM_CATEGORY_HARASSMENT = "HARM_CATEGORY_HARASSMENT"
            HARM_CATEGORY_HATE_SPEECH = "HARM_CATEGORY_HATE_SPEECH"
            HARM_CATEGORY_SEXUALLY_EXPLICIT = "HARM_CATEGORY_SEXUALLY_EXPLICIT"
            HARM_CATEGORY_DANGEROUS_CONTENT = "HARM_CATEGORY_DANGEROUS_CONTENT"

        class HarmBlockThreshold:  # type: ignore
            BLOCK_NONE = "BLOCK_NONE"
# --- Configuration ---
# The API key comes from the environment; an empty string means "not configured".
GEMINI_API_KEY = os.environ.get('GEMINI_API_KEY', "")
# Model used for text/code generation.
LLM_MODEL_NAME = "gemini-1.5-flash-latest"
# Model used for document and query embeddings.
GEMINI_EMBEDDING_MODEL_NAME = "models/embedding-001"
# Sampling parameters forwarded to the LLM on every generation call.
GENERATION_CONFIG_PARAMS = dict(
    temperature=0.2,
    top_p=1.0,
    top_k=32,
    max_output_tokens=4096,  # Generous limit for longer code/explanations
)
# Safety settings for Gemini: every listed harm category is mapped to BLOCK_NONE.
DEFAULT_SAFETY_SETTINGS = {
    harm_category: genai_types.HarmBlockThreshold.BLOCK_NONE
    for harm_category in (
        genai_types.HarmCategory.HARM_CATEGORY_HARASSMENT,
        genai_types.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
        genai_types.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
        genai_types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
    )
}
# Logging setup (root logger, INFO level, timestamped records)
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(module)s - %(message)s',
    level=logging.INFO,
)
# Configure the Gemini client once at import time, but only when a key is present.
# NOTE(review): `configure` is called on whatever `genai` resolved to above -- confirm
# the installed SDK actually exposes it; a failure here is only logged.
if not GEMINI_API_KEY:
    logging.warning("GEMINI_API_KEY environment variable not set. LLM and Embedding functionalities will be limited.")
else:
    try:
        genai.configure(api_key=GEMINI_API_KEY)
        logging.info(f"Gemini API key configured. Target model for generation: '{LLM_MODEL_NAME}', Embedding model: '{GEMINI_EMBEDDING_MODEL_NAME}'")
    except Exception as config_error:
        logging.error(f"Failed to configure Gemini API: {config_error}", exc_info=True)
# --- RAG Documents Definition ---
# Knowledge-base documents (title + text) that are indexed by the AdvancedRAGSystem.
# Replace them with documents more relevant to your LinkedIn dashboard context if needed.
# Column-oriented knowledge base: entry i of 'Title' belongs with entry i of 'Text'.
rag_documents_data = dict(
    Title=[
        "Employer Branding Best Practices 2024",
        "Attracting Tech Talent in Competitive Markets",
        "Understanding Company Culture for Talent Acquisition",
        "Diversity and Inclusion in Modern Hiring Processes",
        "Leveraging LinkedIn Data for Recruitment Insights",
        "Analyzing Employee Engagement Metrics",
        "Content Strategies for LinkedIn Company Pages",
    ],
    Text=[
        "Focus on authentic employee stories and showcase your company's mission. Transparency in compensation and benefits is key. Leverage social media, especially LinkedIn, to highlight your work environment and values. Regularly share updates about company achievements and employee successes.",
        "Tech candidates value challenging projects, opportunities for learning new technologies, and a flexible work culture. Highlight your tech stack, innovation efforts, and career development paths. Competitive salaries and benefits are standard expectations.",
        "Company culture is defined by shared values, beliefs, and behaviors. It's crucial for attracting and retaining talent that aligns with the organization. Assess culture through employee surveys, feedback sessions, and by observing daily interactions. Promote a positive culture actively.",
        "Promote diversity and inclusion by using inclusive language in job descriptions, ensuring diverse interview panels, and highlighting D&I initiatives. Track diversity metrics and be transparent about your goals and progress. An inclusive culture boosts innovation.",
        "LinkedIn data provides rich insights into talent pools, competitor strategies, and industry trends. Analyze follower demographics, content engagement, and employee advocacy to refine your employer branding and recruitment efforts. Use LinkedIn Analytics effectively.",
        "High employee engagement correlates with better retention and productivity. Key metrics include employee Net Promoter Score (eNPS), satisfaction surveys, and participation in company initiatives. Address feedback promptly to foster a positive work environment.",
        "Develop a content calendar for your LinkedIn Company Page that includes a mix of thought leadership, company news, employee spotlights, job postings, and industry insights. Use visuals and videos to increase engagement. Analyze post performance to optimize your strategy.",
    ],
)
# Tabular view of the knowledge base (columns: 'Title', 'Text').
df_rag_documents = pd.DataFrame.from_dict(rag_documents_data)
# --- Schema Representation ---
def get_schema_representation(df_name: str, df: pd.DataFrame) -> str:
    """Describe a DataFrame's schema, with usage hints and a two-row sample.

    The text is meant to be embedded in an LLM prompt, so the frame is always
    referred to as ``df_<df_name>`` -- the name under which it is exposed to
    generated code.

    Args:
        df_name: Base name of the DataFrame (without the ``df_`` prefix).
        df: The DataFrame to describe.

    Returns:
        A multi-line string listing each column with its dtype, optional notes
        (date/time-like names, list/dict cells, probable categoricals) and the
        first two rows. An empty frame yields a short "is empty" message.
    """
    if df.empty:
        # Same 'df_' prefix as the non-empty case so the LLM sees one naming scheme.
        return f"Schema for DataFrame 'df_{df_name}':\n - DataFrame is empty.\n"

    cols = df.columns.tolist()
    dtypes = df.dtypes.to_dict()

    pieces = [f"Schema for DataFrame 'df_{df_name}':\n"]  # Note: using df_ prefix for LLM
    pieces.extend(f" - Column '{col}': {dtypes[col]}\n" for col in cols)

    # Add notes for complex data types or common pitfalls
    for col in cols:
        series = df[col]
        # Column labels are not guaranteed to be strings (e.g. integer labels).
        label = str(col).lower()
        if 'date' in label or 'time' in label:
            pieces.append(f" - Note: Column '{col}' seems to be date/time related. Ensure it is in datetime format for time-series analysis (e.g., using pd.to_datetime(df_{df_name}['{col}'])).\n")

        holds_containers = any(isinstance(value, (list, dict)) for value in series)
        if holds_containers:
            pieces.append(f" - Note: Column '{col}' may contain list-like or dict-like data (e.g., skills: ['Python', 'SQL']). Use operations like .explode() or .apply(pd.Series) or .apply(lambda x: ...) for lists/dicts.\n")
            # list/dict cells are unhashable, so nunique() would raise; skip the
            # categorical heuristic for such columns.
            continue

        if series.dtype == 'object' and df.shape[0] > 20:  # Potential categorical
            try:
                few_distinct = series.nunique() < 20
            except TypeError:
                # Other unhashable cell types (sets, arrays, ...): no hint.
                few_distinct = False
            if few_distinct:
                pieces.append(f" - Note: Column '{col}' is of type object and has few unique values; it might be categorical. Use .value_counts() for distribution.\n")

    pieces.append(f"Sample of first 2 rows of 'df_{df_name}':\n{df.head(2).to_string()}\n")
    return "".join(pieces)
def get_all_schemas_representation(dataframes_dict: dict) -> str:
    """Build one prompt section describing every DataFrame in ``dataframes_dict``.

    The result starts with a fixed introduction explaining the ``df_`` naming
    convention, followed by each frame's schema description (in dictionary
    order), each terminated by an extra newline.
    """
    intro = "You have access to the following Pandas DataFrames. In your Python code, refer to them with the 'df_' prefix (e.g., df_follower_stats, df_posts).\n\n"
    sections = [
        get_schema_representation(base_name, frame) + "\n"
        for base_name, frame in dataframes_dict.items()
    ]
    return intro + "".join(sections)
# --- Advanced RAG System ---
class AdvancedRAGSystem:
    """In-memory retrieval helper over a DataFrame of documents.

    Documents are read from the 'Title' and 'Text' columns of the frame passed
    to the constructor. Each document gets an embedding stored in an
    'Embeddings' column (of a private copy of that frame); queries are embedded
    with the same model and matched by dot product.

    Attributes:
        documents_df: Copy of the input documents, plus the 'Embeddings' column.
        embedding_model_name: Model name forwarded to ``genai.embed_content``.
        embeddings_generated: True only if at least one usable (non-zero)
            document embedding exists; retrieval is disabled otherwise.
    """

    # Length of the all-zero placeholder returned when a document cannot be embedded.
    FALLBACK_EMBEDDING_DIM = 768

    def __init__(self, documents_df: pd.DataFrame, embedding_model_name: str):
        # Work on a copy so the caller's frame never gains an 'Embeddings' column.
        self.documents_df = documents_df.copy()
        self.embedding_model_name = embedding_model_name
        self.embeddings_generated = False
        self._ensure_embeddings_column()

        if not GEMINI_API_KEY:
            logging.warning("RAG System: GEMINI_API_KEY not set. Embeddings will not be generated.")
            return

        try:
            self._precompute_embeddings()
        except Exception as e:
            logging.error(f"Error during RAG embedding precomputation: {e}", exc_info=True)
            return

        # Only report success when something retrievable actually exists;
        # placeholder zero vectors do not count.
        self.embeddings_generated = any(
            self._is_valid_embedding(value) for value in self.documents_df['Embeddings']
        )
        if self.embeddings_generated:
            logging.info("AdvancedRAGSystem Initialized and embeddings precomputed.")
        else:
            logging.warning("AdvancedRAGSystem initialized, but no usable document embeddings were produced.")

    @staticmethod
    def _is_valid_embedding(value) -> bool:
        """Return True for a non-empty list that is not (numerically) a zero vector."""
        return isinstance(value, list) and len(value) > 0 and sum(abs(x) for x in value) > 1e-6

    def _ensure_embeddings_column(self) -> None:
        """Add an empty object-typed 'Embeddings' column if the frame has none."""
        if 'Embeddings' not in self.documents_df.columns:
            self.documents_df['Embeddings'] = pd.Series(dtype='object')

    def _embed_fn(self, title: str, text: str) -> list[float]:
        """Embed one document; return a zero vector if the call is impossible or fails.

        This must not depend on ``self.embeddings_generated``: it is the very
        function that produces the embeddings during construction, before that
        flag can be set.
        """
        zero_vector = [0.0] * self.FALLBACK_EMBEDDING_DIM
        try:
            if not hasattr(genai, 'embed_content'):  # Check if genai is usable
                logging.warning("genai.embed_content not available. Returning zero vector.")
                return zero_vector
            embedding_result = genai.embed_content(
                model=self.embedding_model_name,
                content=text,
                task_type="retrieval_document",
                title=title,
            )
            return embedding_result["embedding"]
        except Exception as e:
            logging.error(f"Error embedding content '{title}': {e}", exc_info=True)
            return zero_vector

    def _precompute_embeddings(self):
        """Fill in embeddings for every row that does not already hold a usable one."""
        self._ensure_embeddings_column()
        docs = self.documents_df
        # Rows whose cell is missing, not a list, or a zero vector are (re)embedded.
        refreshed = [
            current if self._is_valid_embedding(current) else self._embed_fn(title, text)
            for current, title, text in zip(docs['Embeddings'], docs['Title'], docs['Text'])
        ]
        # Assign the whole column at once; each cell holds one list.
        self.documents_df['Embeddings'] = pd.Series(refreshed, index=docs.index, dtype='object')
        logging.info("Embeddings precomputation finished.")

    def retrieve_relevant_info(self, query_text: str, top_k: int = 2) -> str:
        """Return the ``top_k`` best-matching documents formatted for a prompt.

        Args:
            query_text: Text to embed and match against the documents.
            top_k: Maximum number of passages to return.

        Returns:
            One ``[RAG Context from: '<title>']`` section per retrieved
            document, or a single ``[RAG Context]`` placeholder message when
            retrieval is inactive, finds nothing, or fails. Never raises.
        """
        if (
            not self.embeddings_generated
            or not hasattr(genai, 'embed_content')
            or 'Embeddings' not in self.documents_df.columns
            or self.documents_df['Embeddings'].isnull().all()
        ):
            logging.warning("RAG System: Cannot retrieve info due to missing API key, genai functions, or empty/missing embeddings.")
            return "\n[RAG Context]\nNo specific pre-defined context found (RAG system inactive or no embeddings).\n"

        try:
            query_embedding_result = genai.embed_content(
                model=self.embedding_model_name,
                content=query_text,
                task_type="retrieval_query",
            )
            query_embedding = np.array(query_embedding_result["embedding"])

            usable_mask = self.documents_df['Embeddings'].apply(self._is_valid_embedding).astype(bool)
            valid_embeddings_df = self.documents_df[usable_mask]
            if valid_embeddings_df.empty:
                logging.warning("No valid document embeddings found for RAG.")
                return "\n[RAG Context]\nNo valid document embeddings available for retrieval.\n"

            document_embeddings = np.stack([np.asarray(vec) for vec in valid_embeddings_df['Embeddings']])
            if query_embedding.shape[0] != document_embeddings.shape[1]:
                logging.error(f"Query embedding dim ({query_embedding.shape[0]}) != Document embedding dim ({document_embeddings.shape[1]})")
                return "\n[RAG Context]\nEmbedding dimension mismatch.\n"

            scores = np.dot(document_embeddings, query_embedding)

            # Never ask for more passages than exist; non-positive requests yield nothing.
            actual_top_k = min(top_k, len(valid_embeddings_df))
            if actual_top_k <= 0:
                return "\n[RAG Context]\nNo documents to retrieve from.\n"

            # Highest score first; the stable sort keeps earlier documents ahead on ties.
            best_positions = np.argsort(-scores, kind="stable")[:actual_top_k]

            relevant_passages = "".join(
                f"\n[RAG Context from: '{valid_embeddings_df.iloc[pos]['Title']}']\n{valid_embeddings_df.iloc[pos]['Text']}\n"
                for pos in best_positions
            )
            logging.info(f"RAG System retrieved: {relevant_passages[:200]}...")
            return relevant_passages if relevant_passages else "\n[RAG Context]\nNo highly relevant passages found.\n"
        except Exception as e:
            logging.error(f"Error retrieving relevant info from RAG: {e}", exc_info=True)
            return f"\n[RAG Context]\nError during RAG retrieval: {str(e)}\n"
# --- PandasLLM Class (Gemini-Powered) ---
class PandasLLM:
    """Sends prompts to a Gemini model and optionally executes the returned pandas code.

    With ``force_sandbox=True`` the model's reply is expected to contain a
    ```` ```python ```` fenced block; that code is executed with a reduced set of
    built-ins and its printed output becomes the answer. With
    ``force_sandbox=False`` the raw model text is returned.

    SECURITY NOTE: executing model-generated code with a trimmed built-ins table
    is NOT a real sandbox (object introspection can still reach the interpreter).
    Use process-level isolation if untrusted users can influence the prompt.
    """

    # Built-in names withheld from generated code. 'exit'/'quit' raise SystemExit
    # (not an Exception subclass) and would escape the error handling below.
    _UNSAFE_BUILTINS = (
        'eval', 'exec', 'open', 'compile', 'input', 'memoryview', 'vars',
        'globals', 'locals', '__import__', 'exit', 'quit', 'breakpoint',
    )

    def __init__(self, llm_model_name: str, generation_config_params: dict, safety_settings: dict, data_privacy=True, force_sandbox=True):
        self.llm_model_name = llm_model_name
        self.generation_config_params = generation_config_params
        self.safety_settings = safety_settings
        self.data_privacy = data_privacy
        self.force_sandbox = force_sandbox  # If True, LLM must output Python code to be exec'd
        if not GEMINI_API_KEY:
            logging.warning("PandasLLM: GEMINI_API_KEY not set. LLM functionalities will be limited.")
            self.model = None
        else:
            try:
                self.model = genai.GenerativeModel(
                    self.llm_model_name,
                    safety_settings=self.safety_settings
                )
                logging.info(f"PandasLLM Initialized with Gemini model '{self.llm_model_name}'. data_privacy={data_privacy}, force_sandbox={force_sandbox}")
            except Exception as e:
                logging.error(f"Failed to initialize GenerativeModel '{self.llm_model_name}': {e}", exc_info=True)
                self.model = None

    @staticmethod
    def _format_history(history: list) -> list:
        """Convert ``{'role', 'content'}`` chat entries to Gemini ``{'role', 'parts'}`` dicts."""
        formatted = []
        for entry in history:
            role = entry.get("role", "user")  # Default to user if role not specified
            if role == "assistant":
                role = "model"  # Gemini uses "model" for assistant
            formatted.append({"role": role, "parts": [{"text": entry.get("content", "")}]})
        return formatted

    @staticmethod
    def _extract_response_text(response) -> str:
        """Return the reply text, or a '# Error: ...' message if the reply has none."""
        try:
            llm_output = response.text or ""
        except (ValueError, AttributeError):
            # The `.text` accessor may be missing or may raise when the reply
            # carries no text parts (NOTE(review): SDK-dependent); fall through
            # to inspecting the candidates instead of failing the whole call.
            llm_output = ""
        if llm_output:
            return llm_output

        candidates = getattr(response, 'candidates', None)
        if not candidates:
            logging.warning("Gemini API response structure not recognized or empty.")
            return "# Error: The AI model returned an unexpected or empty response structure."

        candidate = candidates[0]
        if candidate.content and candidate.content.parts:
            llm_output = "".join(part.text for part in candidate.content.parts if hasattr(part, 'text'))
        if llm_output:
            return llm_output

        # No text at all: explain why generation stopped.
        finish_reason_val = candidate.finish_reason
        finish_reason = getattr(finish_reason_val, 'name', str(finish_reason_val))  # Handle enum or string
        logging.warning(f"No text content in response candidate. Finish reason: {finish_reason}")
        if finish_reason == "SAFETY":
            return f"# Error: Response generation stopped due to safety reasons ({finish_reason})."
        if finish_reason == "RECITATION":
            return f"# Error: Response generation stopped due to recitation policy ({finish_reason})."
        return f"# Error: The AI model returned an empty response. Finish reason: {finish_reason}."

    def _describe_api_error(self, error: Exception) -> str:
        """Map an API exception to a user-facing '# Error: ...' message."""
        details = str(error)
        lowered = details.lower()
        if "API_KEY_INVALID" in details or "API key not valid" in details:
            return "# Error: Gemini API key is not valid. Please check your GEMINI_API_KEY environment variable."
        # Unknown models are reported as 400 or 404 depending on the endpoint.
        if ("400" in details or "404" in details) and "model" in lowered and ("not found" in lowered or "does not exist" in lowered):
            return f"# Error: Gemini Model '{self.llm_model_name}' not found or not accessible with your API key. Check model name and permissions."
        if "DeadlineExceeded" in details or "504" in details:
            return "# Error: The request to Gemini API timed out. Please try again later."
        if "PermissionDenied" in details or "403" in details:
            return "# Error: Permission denied. Check if your API key has access to the model or required services."
        return f"# Error: An unexpected error occurred while contacting the AI model: {type(error).__name__} - {details[:100]}..."

    async def _call_gemini_api_async(self, prompt_text: str, history: list = None) -> str:
        """Send ``prompt_text`` (preceded by ``history``, if any) to the model.

        Returns:
            The model's text, or a string starting with ``# Error:`` describing
            what went wrong. This method does not raise.
        """
        if not self.model:
            logging.error("PandasLLM: Gemini model not initialized. Cannot call API.")
            return "# Error: Gemini model not available. Check API key and configuration."

        # The current prompt is always the last message; earlier turns go before it.
        contents_for_api = [{"role": "user", "parts": [{"text": prompt_text}]}]
        if history:
            contents_for_api = self._format_history(history) + contents_for_api

        try:
            gen_config_obj = genai_types.GenerateContentConfig(**self.generation_config_params)
        except Exception as e:
            # If the typed config object cannot be built, pass the plain dict instead.
            logging.error(f"Error creating GenerateContentConfig: {e}. Using dict directly.")
            gen_config_obj = self.generation_config_params

        logging.info(f"\n--- Calling Gemini API with prompt (first 500 chars of last message): ---\n{contents_for_api[-1]['parts'][0]['text'][:500]}...\n-------------------------------------------------------\n")

        try:
            # The SDK call is blocking; run it in a worker thread so the event loop stays free.
            response = await asyncio.to_thread(
                self.model.generate_content,
                contents=contents_for_api,
                generation_config=gen_config_obj,
            )

            if response.prompt_feedback and response.prompt_feedback.block_reason:
                reason = response.prompt_feedback.block_reason
                reason_name = getattr(reason, 'name', str(reason))  # Handle if reason is enum or string
                logging.warning(f"Gemini API call blocked by prompt feedback: {reason_name}")
                return f"# Error: Prompt blocked due to content policy: {reason_name}."

            llm_output = self._extract_response_text(response)
            logging.info(f"--- Gemini API Response (first 300 chars): ---\n{llm_output[:300]}...\n--------------------------------------------------\n")
            return llm_output
        except AttributeError as ae:  # Catch issues with dummy genai objects if API key missing
            logging.error(f"AttributeError during Gemini call (likely due to missing API key/dummy objects): {ae}", exc_info=True)
            return f"# Error (Attribute): {type(ae).__name__} - {ae}. Check if GEMINI_API_KEY is set and google.genai library is correctly installed."
        except Exception as e:
            logging.error(f"Error calling Gemini API: {e}", exc_info=True)
            return self._describe_api_error(e)

    @staticmethod
    def _extract_python_code(llm_response_text: str) -> str:
        """Return the body of the first ```python fenced block, or '' if there is none."""
        fence = "```python"
        start = llm_response_text.find(fence)
        if start == -1:
            return ""
        remainder = llm_response_text[start + len(fence):]
        if remainder.startswith("\n"):
            # Well-formed fence: code runs up to the closing fence on its own line.
            return remainder[1:].split("\n```", 1)[0]
        # Slightly malformed fence (no newline after ```python): cut at the next fence
        # and trim one surrounding newline on each side.
        code = remainder.split("```", 1)[0]
        if code.startswith("\n"):
            code = code[1:]
        if code.endswith("\n"):
            code = code[:-1]
        return code

    def _run_in_sandbox(self, code_to_execute: str, dataframes_dict: dict) -> str:
        """Execute generated code with restricted built-ins and return what it printed."""
        import builtins
        from contextlib import redirect_stdout
        from io import StringIO

        # Use the `builtins` module directly: the `__builtins__` global is a module
        # in `__main__` but a plain dict in imported modules.
        safe_builtins = {
            name: obj
            for name, obj in vars(builtins).items()
            if not name.startswith('_') and name not in self._UNSAFE_BUILTINS
        }

        # One shared namespace for globals and locals: with a separate locals dict
        # the code would run as if inside a class body, and functions, lambdas and
        # comprehensions it defines could not see its own top-level variables.
        namespace = {'pd': pd, 'np': np, '__builtins__': safe_builtins}
        for name, df_instance in dataframes_dict.items():
            namespace[f"df_{name}"] = df_instance  # e.g. df_follower_stats, df_posts

        captured_output = StringIO()
        try:
            with redirect_stdout(captured_output):  # restores stdout even on error
                exec(code_to_execute, namespace)
        except Exception as e:
            error_msg = f"# Error executing LLM-generated code:\n# {type(e).__name__}: {str(e)}\n# --- Code that caused error: ---\n{textwrap.indent(code_to_execute, '# ')}"
            logging.error(error_msg, exc_info=False)  # exc_info=False to avoid huge traceback in Gradio UI
            return error_msg

        output_val = captured_output.getvalue()
        final_output_str = output_val if output_val else "# Code executed successfully, but no explicit print() output was generated by the code."
        logging.info(f"--- Sandbox Execution Output: ---\n{final_output_str}\n-------------------------\n")
        return final_output_str

    async def query(self, prompt_with_query_and_context: str, dataframes_dict: dict, history: list = None) -> str:
        """
        Sends a prompt to the LLM and optionally executes the returned Python code in a sandbox.

        dataframes_dict: Keys are 'base_name' (e.g., 'profiles'), values are pd.DataFrame.
            In exec, they are available as 'df_base_name'.
        history: Optional chat history for conversational context.

        Returns the printed output of the executed code (sandbox mode), a
        diagnostic message when no code could be extracted or execution failed,
        or the raw LLM text when ``force_sandbox`` is False.
        """
        llm_response_text = await self._call_gemini_api_async(prompt_with_query_and_context, history)

        if not self.force_sandbox:
            # Not force_sandbox, return LLM text directly
            return llm_response_text

        code_to_execute = self._extract_python_code(llm_response_text)
        if not code_to_execute:
            if llm_response_text.startswith("# Error:"):
                error_prefix = "An error occurred during LLM call."
            else:
                error_prefix = "Could not extract Python code from LLM response."
            logging.warning(f"Problem with LLM response for sandbox: {error_prefix}")
            # Build the message directly. The raw reply must never be spliced into
            # executed source: braces in it would be evaluated as f-string expressions.
            final_output_str = f"{error_prefix}\nRaw LLM Response (may be truncated):\n{str(llm_response_text)[:1000]}\n"
            logging.warning(f"--- Sandbox Fallback Output (No Code Executed): ---\n{final_output_str}\n-------------------------\n")
            return final_output_str

        logging.info(f"\n--- Code to Execute (from LLM, if sandbox): ---\n{code_to_execute}\n------------------------------------------------\n")
        return self._run_in_sandbox(code_to_execute, dataframes_dict)
# --- Employer Branding Agent ---
class EmployerBrandingAgent:
    """Conversational front end combining schema context, RAG context and PandasLLM.

    For each user query the agent builds a prompt (role, instructions, DataFrame
    schemas, retrieved knowledge-base passages, chain-of-thought hints), sends it
    through ``PandasLLM`` and records the exchange in ``chat_history``.
    """

    # Number of user/assistant pairs retained in chat_history.
    MAX_HISTORY_TURNS = 5
    # Marker that prefixes every passage actually retrieved by the RAG system.
    # Placeholder/diagnostic messages use the bare "[RAG Context]" header instead.
    _RAG_HIT_MARKER = "[RAG Context from:"

    def __init__(self, llm_model_name: str, generation_config_params: dict, safety_settings: dict,
                 all_dataframes: dict, rag_documents_df: pd.DataFrame, embedding_model_name: str,
                 data_privacy=True, force_sandbox=True):
        self.pandas_llm = PandasLLM(llm_model_name, generation_config_params, safety_settings, data_privacy, force_sandbox)
        self.rag_system = AdvancedRAGSystem(rag_documents_df, embedding_model_name)
        self.all_dataframes = all_dataframes  # Keys are 'base_name', values are pd.DataFrame
        self.schemas_representation = get_all_schemas_representation(self.all_dataframes)
        self.chat_history = []  # Stores conversation history for this agent instance
        logging.info("EmployerBrandingAgent Initialized.")

    def _build_prompt(self, user_query: str, role="Employer Branding Analyst", task_decomposition_hint=None, cot_hint=True) -> str:
        """Assemble the full prompt for one user query.

        Args:
            user_query: The question asked by the user.
            role: Persona the model is asked to adopt.
            task_decomposition_hint: Optional extra guidance appended after the query.
            cot_hint: Whether to append step-by-step instructions.

        Returns:
            The prompt text. Knowledge-base context is included only when the
            RAG system returned actual passages.
        """
        sandbox_mode = self.pandas_llm.force_sandbox

        # Base prompt
        parts = [
            f"You are a helpful and expert '{role}'. Your primary goal is to assist with analyzing LinkedIn-related data using Pandas DataFrames.\n",
            "You will be provided with schemas for available Pandas DataFrames and a user query.\n",
        ]

        if self.pandas_llm.data_privacy:
            parts.append("IMPORTANT: Be mindful of data privacy. Do not output raw Personally Identifiable Information (PII) like names or specific user details unless explicitly asked and absolutely necessary for the query. Summarize or aggregate data where possible.\n")

        if sandbox_mode:
            parts.extend([
                "Your main task is to GENERATE PYTHON CODE using the Pandas library to answer the user query based on the provided DataFrames. Output ONLY the Python code block.\n",
                "The available DataFrames are already loaded and can be accessed by their dictionary keys prefixed with 'df_' (e.g., df_follower_stats, df_posts) within the execution environment.\n",
                "Example of accessing a DataFrame: `df_follower_stats['country']`.\n",
                "Your Python code MUST include `print()` statements for any results, DataFrames, or values you want to display. The output of these print statements will be the final answer.\n",
                "If a column contains lists (e.g., 'skills' in a hypothetical 'df_employees'), you might need to use methods like `.explode()` or `.apply(pd.Series)` or `.apply(lambda x: ...)` for analysis.\n",
                "If the query is ambiguous or requires clarification, ask for it instead of making assumptions. If the query cannot be answered with the given data, state that clearly.\n",
                "If the query is not about data analysis or code generation (e.g. 'hello', 'how are you?'), respond politely and briefly, do not attempt to generate code.\n",
                "Structure your code clearly. Add comments (#) to explain each step of your logic.\n",
            ])
        else:  # Textual response mode
            parts.append("Your task is to analyze the data and provide a comprehensive textual answer to the user query. You can explain your reasoning step-by-step.\n")

        parts.append("\n--- AVAILABLE DATA AND SCHEMAS ---\n")
        parts.append(self.schemas_representation)

        # RAG context: include only real retrieved passages. Those are headed
        # "[RAG Context from: '<title>']", whereas inactive/empty/error results
        # carry the bare "[RAG Context]" header and must stay out of the prompt.
        rag_context = self.rag_system.retrieve_relevant_info(user_query)
        if rag_context and self._RAG_HIT_MARKER in rag_context:
            parts.append(f"\n--- ADDITIONAL CONTEXT (from internal knowledge base, consider this information) ---\n{rag_context}\n")

        parts.append(f"\n--- USER QUERY ---\n{user_query}\n")

        if task_decomposition_hint:
            parts.append(f"\n--- GUIDANCE FOR ANALYSIS (Task Decomposition) ---\n{task_decomposition_hint}\n")

        if cot_hint:
            if sandbox_mode:
                parts.extend([
                    "\n--- INSTRUCTIONS FOR PYTHON CODE GENERATION (Chain of Thought) ---\n",
                    "1. Understand the query: What specific information is requested?\n",
                    "2. Identify relevant DataFrame(s) and column(s) from the schemas provided.\n",
                    "3. Plan the steps: Outline the Pandas operations needed (filtering, grouping, aggregation, merging, etc.) as comments in your code.\n",
                    "4. Write the code: Implement the steps using Pandas. Remember to use `df_name_of_dataframe` (e.g. `df_follower_stats`).\n",
                    "5. Ensure output: Use `print()` for all results that should be displayed. For DataFrames, you can print the DataFrame directly, or `df.to_string()` if it's large.\n",
                    "6. Review: Check for correctness, efficiency, and adherence to the prompt (especially the `print()` requirement).\n",
                    "7. Generate ONLY the Python code block starting with ```python and ending with ```. No explanations outside the code block's comments.\n",
                ])
            else:  # Textual CoT
                parts.extend([
                    "\n--- INSTRUCTIONS FOR RESPONSE (Chain of Thought) ---\n",
                    "1. Understand the query fully.\n",
                    "2. Identify the relevant data sources (DataFrames and columns).\n",
                    "3. Explain your analytical approach step-by-step.\n",
                    "4. Perform the analysis (mentally or by outlining the steps).\n",
                    "5. Present the findings clearly and concisely. If you performed calculations, show or describe them.\n",
                    "6. If applicable, incorporate insights from the 'ADDITIONAL CONTEXT' (RAG system).\n",
                ])

        return "".join(parts)

    async def process_query(self, user_query: str, role="Employer Branding Analyst", task_decomposition_hint=None, cot_hint=True) -> str:
        """Answer one user query and record the exchange in ``chat_history``.

        The prompt is rebuilt on every call; earlier turns are passed to the LLM
        as history. If building the prompt or querying the LLM raises, the
        pending user message is removed again (so the history keeps strictly
        alternating user/assistant turns) and the exception is re-raised.
        """
        logging.info(f"\n=== Processing Query for Role: {role} ===")
        logging.info(f"User Query: {user_query}")

        # Add user query to chat history
        self.chat_history.append({"role": "user", "content": user_query})

        try:
            full_prompt = self._build_prompt(user_query, role, task_decomposition_hint, cot_hint)
            # Pass history excluding the current query (it is already part of full_prompt).
            response_text = await self.pandas_llm.query(full_prompt, self.all_dataframes, history=self.chat_history[:-1])
        except Exception:
            # Do not leave an unanswered user turn behind.
            self.chat_history.pop()
            raise

        # Add assistant response to chat history
        self.chat_history.append({"role": "assistant", "content": response_text})

        # Limit history size to avoid overly long prompts in future turns.
        max_messages = self.MAX_HISTORY_TURNS * 2
        if len(self.chat_history) > max_messages:
            self.chat_history = self.chat_history[-max_messages:]

        return response_text

    def update_dataframes(self, new_dataframes: dict):
        """Updates the agent's DataFrames and their schema representation."""
        self.all_dataframes = new_dataframes
        self.schemas_representation = get_all_schemas_representation(self.all_dataframes)
        logging.info("EmployerBrandingAgent DataFrames updated.")

    def clear_chat_history(self):
        """Forget all previous turns of the conversation."""
        self.chat_history = []
        logging.info("EmployerBrandingAgent chat history cleared.")