# eb_agent_module.py
import pandas as pd
import json
import os
import asyncio
import logging
import numpy as np
import textwrap
import re
# Attempt to import Google Generative AI and related types.
# If the library is missing, fall back to the dummy stand-ins defined below so
# the rest of the module can still be parsed and exercised without API access.
try:
    from google import generativeai as genai  # Conventional alias for the SDK
    from google.generativeai import types as genai_types
    _GENAI_IMPORTED = True
except ImportError:
    _GENAI_IMPORTED = False
    print("Google Generative AI library not found. Please install it: pip install google-generativeai")
    # Define dummy classes/functions so the rest of the script can be parsed
    # and run with canned responses.
class genai: # type: ignore
@staticmethod
def configure(api_key):
print(f"Dummy genai.configure called with API key: {'SET' if api_key else 'NOT SET'}")
# Dummy Client and related structures
class Client:
def __init__(self, api_key=None): # api_key is optional for Client constructor
self.api_key = api_key
self.models = self._Models() # This is the service client for models
print(f"Dummy genai.Client initialized {'with api_key' if api_key else '(global API key expected)'}.")
class _Models: # Represents the model service client
async def generate_content_async(self, model=None, contents=None, generation_config=None, safety_settings=None, stream=False, tools=None, tool_config=None): # Matched real signature better
print(f"Dummy genai.Client.models.generate_content_async called for model: {model} with config: {generation_config}, safety_settings: {safety_settings}, stream: {stream}")
class DummyPart:
def __init__(self, text): self.text = text
class DummyContent:
def __init__(self): self.parts = [DummyPart("# Dummy response from dummy client's async generate_content")]
class DummyCandidate:
def __init__(self):
self.content = DummyContent()
self.finish_reason = genai_types.FinishReason.STOP # Use dummy FinishReason
self.safety_ratings = []
self.token_count = 0
self.index = 0
class DummyResponse:
def __init__(self):
self.candidates = [DummyCandidate()]
self.prompt_feedback = self._PromptFeedback()
self.text = "# Dummy response text from dummy client's async generate_content"
class _PromptFeedback:
def __init__(self):
self.block_reason = None
self.safety_ratings = []
return DummyResponse()
def generate_content(self, model=None, contents=None, generation_config=None, safety_settings=None, stream=False, tools=None, tool_config=None): # Matched real signature better
print(f"Dummy genai.Client.models.generate_content called for model: {model} with config: {generation_config}, safety_settings: {safety_settings}, stream: {stream}")
# Re-using the async dummy structure for simplicity
class DummyPart:
def __init__(self, text): self.text = text
class DummyContent:
def __init__(self): self.parts = [DummyPart("# Dummy response from dummy client's generate_content")]
class DummyCandidate:
def __init__(self):
self.content = DummyContent()
self.finish_reason = genai_types.FinishReason.STOP
self.safety_ratings = []
self.token_count = 0
self.index = 0
class DummyResponse:
def __init__(self):
self.candidates = [DummyCandidate()]
self.prompt_feedback = self._PromptFeedback()
self.text = "# Dummy response text from dummy client's generate_content"
class _PromptFeedback:
def __init__(self):
self.block_reason = None
self.safety_ratings = []
return DummyResponse()
@staticmethod
def GenerativeModel(model_name, generation_config=None, safety_settings=None, system_instruction=None): # Kept for AdvancedRAGSystem if it uses it, or if user switches back
print(f"Dummy genai.GenerativeModel called for model: {model_name} (This might be unused if Client approach is preferred)")
class DummyGenerativeModel:
def __init__(self, model_name_in, generation_config_in, safety_settings_in, system_instruction_in):
self.model_name = model_name_in
async def generate_content_async(self, contents, stream=False):
class DummyPart:
def __init__(self, text): self.text = text
class DummyContent:
def __init__(self): self.parts = [DummyPart(f"# Dummy response from dummy GenerativeModel ({self.model_name})")]
class DummyCandidate:
def __init__(self):
self.content = DummyContent(); self.finish_reason = genai_types.FinishReason.STOP; self.safety_ratings = []
class DummyResponse:
def __init__(self):
self.candidates = [DummyCandidate()]; self.prompt_feedback = None; self.text = f"# Dummy GM response"
return DummyResponse()
return DummyGenerativeModel(model_name, generation_config, safety_settings, system_instruction)
@staticmethod
def embed_content(model, content, task_type, title=None):
print(f"Dummy genai.embed_content called for model: {model}, task_type: {task_type}, title: {title}")
return {"embedding": [0.1] * 768}
class genai_types: # type: ignore
@staticmethod
def GenerationConfig(**kwargs):
print(f"Dummy genai_types.GenerationConfig created with: {kwargs}")
return dict(kwargs)
@staticmethod
def SafetySetting(category, threshold):
print(f"Dummy SafetySetting created: category={category}, threshold={threshold}")
return {"category": category, "threshold": threshold}
class HarmCategory:
HARM_CATEGORY_UNSPECIFIED = "HARM_CATEGORY_UNSPECIFIED"; HARM_CATEGORY_HARASSMENT = "HARM_CATEGORY_HARASSMENT"; HARM_CATEGORY_HATE_SPEECH = "HARM_CATEGORY_HATE_SPEECH"; HARM_CATEGORY_SEXUALLY_EXPLICIT = "HARM_CATEGORY_SEXUALLY_EXPLICIT"; HARM_CATEGORY_DANGEROUS_CONTENT = "HARM_CATEGORY_DANGEROUS_CONTENT"
class HarmBlockThreshold:
BLOCK_NONE = "BLOCK_NONE"; BLOCK_LOW_AND_ABOVE = "BLOCK_LOW_AND_ABOVE"; BLOCK_MEDIUM_AND_ABOVE = "BLOCK_MEDIUM_AND_ABOVE"; BLOCK_ONLY_HIGH = "BLOCK_ONLY_HIGH"
class FinishReason:
FINISH_REASON_UNSPECIFIED = "UNSPECIFIED"; STOP = "STOP"; MAX_TOKENS = "MAX_TOKENS"; SAFETY = "SAFETY"; RECITATION = "RECITATION"; OTHER = "OTHER"
    # Dummy exception types referenced by the API-call error handlers below;
    # without these, evaluating the except clauses would raise AttributeError.
    class BlockedPromptException(Exception):
        pass
    class StopCandidateException(Exception):
        pass
    # Dummy for BlockedReason if needed by response parsing
    class BlockedReason:
        BLOCKED_REASON_UNSPECIFIED = "BLOCKED_REASON_UNSPECIFIED"
        SAFETY = "SAFETY"
        OTHER = "OTHER"
# --- Configuration ---
GEMINI_API_KEY = os.getenv('GEMINI_API_KEY', "")
# User-specified model names:
LLM_MODEL_NAME = "gemini-2.0-flash"
GEMINI_EMBEDDING_MODEL_NAME = "gemini-embedding-exp-03-07"
# Base generation configuration for the LLM
GENERATION_CONFIG_PARAMS = {
"temperature": 0.3,
"top_p": 1.0,
"top_k": 32,
"max_output_tokens": 8192,
}
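# NOTE: these parameters are wrapped into a genai_types.GenerationConfig
# (see PandasLLM._call_gemini_api_async) before each API call; adjust
# temperature/top_p/top_k here to trade determinism against creativity.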
# Default safety settings list for Gemini
try:
DEFAULT_SAFETY_SETTINGS = [
genai_types.SafetySetting(category=genai_types.HarmCategory.HARM_CATEGORY_HATE_SPEECH, threshold=genai_types.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE),
genai_types.SafetySetting(category=genai_types.HarmCategory.HARM_CATEGORY_HARASSMENT, threshold=genai_types.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE),
genai_types.SafetySetting(category=genai_types.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, threshold=genai_types.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE),
genai_types.SafetySetting(category=genai_types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, threshold=genai_types.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE),
]
except AttributeError as e:
logging.warning(f"Could not define DEFAULT_SAFETY_SETTINGS using real genai_types: {e}. Using placeholder list of dicts.")
DEFAULT_SAFETY_SETTINGS = [
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
]
# Logging setup
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(module)s - %(filename)s:%(lineno)d - %(message)s')
if GEMINI_API_KEY:
try:
genai.configure(api_key=GEMINI_API_KEY)
logging.info(f"Gemini API key configured globally.")
except Exception as e:
logging.error(f"Failed to configure Gemini API globally: {e}", exc_info=True)
else:
logging.warning("GEMINI_API_KEY environment variable not set. Agent will use dummy responses if real genai library is not fully mocked or if API calls fail.")
# --- RAG Documents Definition (Example) ---
rag_documents_data = {
'Title': ["Employer Branding Best Practices 2024", "Attracting Tech Talent", "Employee Advocacy", "Gen Z Expectations"],
'Text': ["Focus on authentic employee stories...", "Tech candidates value challenging projects...", "Encourage employees to share experiences...", "Gen Z values purpose-driven work..."]
}
df_rag_documents = pd.DataFrame(rag_documents_data)
# --- Schema Representation ---
def get_schema_representation(df_name: str, df: pd.DataFrame) -> str:
if not isinstance(df, pd.DataFrame): return f"Schema for item '{df_name}': Not a DataFrame.\n"
if df.empty: return f"Schema for DataFrame 'df_{df_name}': Empty.\n"
schema_str = f"DataFrame 'df_{df_name}':\n Columns: {df.columns.tolist()}\n Shape: {df.shape}\n"
if not df.empty: schema_str += f" Sample Data (first 2 rows):\n{textwrap.indent(df.head(2).to_string(), ' ')}\n"
else: schema_str += " Sample Data: DataFrame is empty.\n"
return schema_str
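# Illustrative output (column names and shape depend on the actual DataFrame):
#   DataFrame 'df_posts':
#    Columns: ['post_id', 'theme', 'engagement_rate']
#    Shape: (1, 3)
#    Sample Data (first 2 rows): ...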
def get_all_schemas_representation(dataframes_dict: dict) -> str:
if not dataframes_dict: return "No DataFrames provided.\n"
return "".join(get_schema_representation(name, df) for name, df in dataframes_dict.items())
# --- Advanced RAG System ---
class AdvancedRAGSystem:
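    """Embeds a small document corpus up front and retrieves the top-k passages
    by dot-product similarity against a query embedding."""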
def __init__(self, documents_df: pd.DataFrame, embedding_model_name: str):
self.embedding_model_name = embedding_model_name
self.documents_df = documents_df.copy()
self.embeddings_generated = False
        # Use the import flag rather than probing __qualname__: the dummy
        # embed_content only exists when the real library failed to import.
        self.client_available = _GENAI_IMPORTED and hasattr(genai, 'embed_content')
if GEMINI_API_KEY and self.client_available:
try:
self._precompute_embeddings()
self.embeddings_generated = True
logging.info(f"RAG embeddings precomputed using '{self.embedding_model_name}'.")
except Exception as e: logging.error(f"RAG precomputation error: {e}", exc_info=True)
else:
logging.warning(f"RAG embeddings not precomputed. Key: {bool(GEMINI_API_KEY)}, embed_content_ok: {self.client_available}.")
def _embed_fn(self, title: str, text: str) -> list[float]:
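        # Falls back to a 768-dim zero vector; some Gemini embedding models
        # return other dimensions, and any mismatch with the query embedding is
        # caught by the shape check in retrieve_relevant_info.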
if not self.client_available: return [0.0] * 768
try:
content_to_embed = text if text else title
if not content_to_embed: return [0.0] * 768
return genai.embed_content(model=self.embedding_model_name, content=content_to_embed, task_type="retrieval_document", title=title if title else None)["embedding"]
except Exception as e:
logging.error(f"Error in _embed_fn for '{title}': {e}", exc_info=True)
return [0.0] * 768
def _precompute_embeddings(self):
if 'Embeddings' not in self.documents_df.columns: self.documents_df['Embeddings'] = pd.Series(dtype='object')
mask = (self.documents_df['Text'].notna() & (self.documents_df['Text'] != '')) | (self.documents_df['Title'].notna() & (self.documents_df['Title'] != ''))
if not mask.any(): logging.warning("No content for RAG embeddings."); return
self.documents_df.loc[mask, 'Embeddings'] = self.documents_df[mask].apply(lambda row: self._embed_fn(row.get('Title', ''), row.get('Text', '')), axis=1)
logging.info(f"Applied RAG embedding function to {mask.sum()} rows.")
def retrieve_relevant_info(self, query_text: str, top_k: int = 2) -> str:
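        # Embeds the query, ranks the precomputed document embeddings by dot
        # product, and returns up to top_k passages wrapped in "[RAG Context]" markers.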
if not self.client_available: return "\n[RAG Context]\nEmbedding client not available.\n"
if not self.embeddings_generated or 'Embeddings' not in self.documents_df.columns or self.documents_df['Embeddings'].isnull().all():
return "\n[RAG Context]\nEmbeddings not ready for RAG.\n"
try:
query_embedding = np.array(genai.embed_content(model=self.embedding_model_name, content=query_text, task_type="retrieval_query")["embedding"])
valid_df = self.documents_df.dropna(subset=['Embeddings'])
valid_df = valid_df[valid_df['Embeddings'].apply(lambda x: isinstance(x, (list, np.ndarray)) and len(x) > 0)]
if valid_df.empty: return "\n[RAG Context]\nNo valid document embeddings.\n"
doc_embeddings = np.stack(valid_df['Embeddings'].apply(np.array).values)
if query_embedding.shape[0] != doc_embeddings.shape[1]: return "\n[RAG Context]\nEmbedding dimension mismatch.\n"
dot_products = np.dot(doc_embeddings, query_embedding)
num_to_retrieve = min(top_k, len(valid_df))
if num_to_retrieve == 0: return "\n[RAG Context]\nNo relevant passages found (num_to_retrieve is 0).\n"
idx = np.argsort(dot_products)[-num_to_retrieve:][::-1]
passages = "".join([f"\n[RAG Context from: '{valid_df.iloc[i]['Title']}']\n{valid_df.iloc[i]['Text']}\n" for i in idx if i < len(valid_df)])
return passages if passages else "\n[RAG Context]\nNo relevant passages found after search.\n"
except Exception as e:
logging.error(f"Error in RAG retrieve_relevant_info: {e}", exc_info=True)
return f"\n[RAG Context]\nError during RAG retrieval: {type(e).__name__} - {e}\n"
# --- PandasLLM Class (Gemini-Powered using genai.Client) ---
class PandasLLM:
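    """Thin wrapper around a Gemini client for DataFrame Q&A.

    In force_sandbox mode the model is asked to emit Python code, which is then
    executed locally against the provided DataFrames with stdout captured as
    the final answer; otherwise the model's text is returned directly.
    """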
def __init__(self, llm_model_name: str,
generation_config_dict: dict,
safety_settings_list: list,
data_privacy=True, force_sandbox=True):
self.llm_model_name = llm_model_name
self.generation_config_dict = generation_config_dict # Will be passed to API call
self.safety_settings_list = safety_settings_list # Will be passed to API call
self.data_privacy = data_privacy
self.force_sandbox = force_sandbox
self.client = None
self.model_service = None # This will be client.models
        # The dummy fallback always defines genai.Client; the installed
        # google-generativeai package may not, so guard the attribute lookup
        # instead of probing __qualname__.
        if not GEMINI_API_KEY and _GENAI_IMPORTED:
            logging.warning("PandasLLM: GEMINI_API_KEY not set while the real SDK is installed. API calls may fail unless credentials are supplied another way (e.g., via genai.configure() or the environment).")
        try:
            if hasattr(genai, 'Client'):
                self.client = genai.Client()  # API key is usually set via genai.configure or the environment
                self.model_service = self.client.models
                logging.info(f"PandasLLM: Initialized with genai.Client().models for '{self.llm_model_name}'.")
            else:
                logging.error("PandasLLM: the installed SDK does not expose genai.Client; model service not initialized.")
        except Exception as e:
            logging.error(f"Failed to initialize PandasLLM with genai.Client: {e}", exc_info=True)
async def _call_gemini_api_async(self, prompt_text: str, history: list = None) -> str:
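        # Converts chat history to Gemini's content format, invokes the model
        # service, and normalizes failures into "# Error: ..." strings that the
        # sandbox layer can surface verbatim.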
if not self.model_service:
logging.error("PandasLLM: Model service (client.models) not available. Cannot call API.")
return "# Error: Gemini model service not available for API call."
        gemini_history = []
        if history:
            for entry in history:
                role = "model" if entry.get("role") in ("assistant", "model") else "user"
                text = entry.get("content")
                if text is None and entry.get("parts"):
                    text = "".join(p.get("text", "") for p in entry["parts"] if isinstance(p, dict))
                gemini_history.append({"role": role, "parts": [{"text": text or ""}]})
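        # Accepted history entry shapes (both appear in this module):
        #   {"role": "user", "content": "..."}
        #   {"role": "model", "parts": [{"text": "..."}]}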
current_content = [{"role": "user", "parts": [{"text": prompt_text}]}]
contents_for_api = gemini_history + current_content
# Prepare model ID (e.g., "models/gemini-2.0-flash")
model_id_for_api = self.llm_model_name
if not model_id_for_api.startswith("models/"):
model_id_for_api = f"models/{model_id_for_api}"
# Prepare generation config object
api_generation_config = None
if self.generation_config_dict:
try:
api_generation_config = genai_types.GenerationConfig(**self.generation_config_dict)
except Exception as e_cfg:
logging.error(f"Error creating GenerationConfig object: {e_cfg}. Using dict as fallback.")
api_generation_config = self.generation_config_dict # Fallback to dict
logging.info(f"\n--- Calling Gemini API via Client (model: {model_id_for_api}) ---\nConfig: {api_generation_config}\nSafety: {bool(self.safety_settings_list)}\nContent (last part text): {contents_for_api[-1]['parts'][0]['text'][:100]}...\n")
try:
response = await self.model_service.generate_content_async(
model=model_id_for_api,
contents=contents_for_api,
generation_config=api_generation_config,
safety_settings=self.safety_settings_list
)
            # Parse the response: check prompt feedback for blocking, then
            # extract text from the first candidate.
if hasattr(response, 'prompt_feedback') and response.prompt_feedback and \
hasattr(response.prompt_feedback, 'block_reason') and response.prompt_feedback.block_reason:
block_reason_val = response.prompt_feedback.block_reason
block_reason_str = str(block_reason_val.name if hasattr(block_reason_val, 'name') else block_reason_val)
logging.warning(f"Prompt blocked by API. Reason: {block_reason_str}.")
return f"# Error: Prompt blocked by API. Reason: {block_reason_str}."
llm_output = ""
if hasattr(response, 'text') and isinstance(response.text, str):
llm_output = response.text
elif response.candidates:
candidate = response.candidates[0]
if candidate.content and candidate.content.parts:
llm_output = "".join(part.text for part in candidate.content.parts if hasattr(part, 'text'))
if not llm_output and candidate.finish_reason:
finish_reason_val = candidate.finish_reason
finish_reason_str = str(finish_reason_val.name if hasattr(finish_reason_val, 'name') else finish_reason_val)
if finish_reason_str == "SAFETY": # or candidate.finish_reason == genai_types.FinishReason.SAFETY:
logging.warning(f"Content generation stopped due to safety. Finish reason: {finish_reason_str}.")
return f"# Error: Content generation stopped by API due to safety. Finish Reason: {finish_reason_str}."
logging.warning(f"Empty response from LLM. Finish reason: {finish_reason_str}.")
return f"# Error: LLM returned an empty response. Finish reason: {finish_reason_str}."
else:
logging.error(f"Unexpected API response structure: {str(response)[:500]}")
return f"# Error: Unexpected API response structure: {str(response)[:200]}"
return llm_output
except genai_types.BlockedPromptException as bpe:
logging.error(f"Prompt blocked (BlockedPromptException): {bpe}", exc_info=True)
return f"# Error: Prompt blocked. Details: {bpe}"
except genai_types.StopCandidateException as sce:
logging.error(f"Candidate stopped (StopCandidateException): {sce}", exc_info=True)
return f"# Error: Content generation stopped. Details: {sce}"
except Exception as e:
logging.error(f"Error calling Gemini API via Client: {e}", exc_info=True)
return f"# Error during API call: {type(e).__name__} - {str(e)[:100]}."
async def query(self, prompt_with_query_and_context: str, dataframes_dict: dict, history: list = None) -> str:
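        # In sandbox mode, extract the generated ```python block and exec it
        # with each DataFrame injected as df_<name>; otherwise return the raw text.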
llm_response_text = await self._call_gemini_api_async(prompt_with_query_and_context, history)
if self.force_sandbox:
code_to_execute = ""
if "```python" in llm_response_text:
try:
code_block_match = llm_response_text.split("```python\n", 1)
if len(code_block_match) > 1: code_to_execute = code_block_match[1].split("\n```", 1)[0]
else:
code_block_match = llm_response_text.split("```python", 1)
if len(code_block_match) > 1:
code_to_execute = code_block_match[1].split("```", 1)[0]
if code_to_execute.startswith("\n"): code_to_execute = code_to_execute[1:]
except IndexError: code_to_execute = ""
if llm_response_text.startswith("# Error:") or not code_to_execute.strip():
logging.warning(f"LLM error or no code: {llm_response_text[:200]}")
if not code_to_execute.strip() and not llm_response_text.startswith("# Error:"):
if "```" not in llm_response_text and len(llm_response_text.strip()) > 0:
logging.info(f"LLM text output in sandbox mode: {llm_response_text[:200]}")
return llm_response_text
logging.info(f"\n--- Code to Execute: ---\n{code_to_execute}\n----------------------\n")
from io import StringIO
import sys
old_stdout, sys.stdout = sys.stdout, StringIO()
exec_globals = {'pd': pd, 'np': np}
if dataframes_dict:
for name, df_instance in dataframes_dict.items():
if isinstance(df_instance, pd.DataFrame): exec_globals[f"df_{name}"] = df_instance
else: logging.warning(f"Item '{name}' not a DataFrame.")
            try:
                # Execute in a single shared namespace so functions defined by the
                # generated code can see each other and the injected DataFrames.
                exec(code_to_execute, exec_globals)
final_output_str = sys.stdout.getvalue()
if not final_output_str.strip():
if not any(ln.strip() and not ln.strip().startswith("#") for ln in code_to_execute.splitlines()):
return "# LLM generated only comments or empty code. No output."
return "# Code executed, but no print() output. Ensure print() for results."
return final_output_str
except Exception as e:
logging.error(f"Sandbox Exec Error: {e}\nCode:\n{code_to_execute}", exc_info=True)
indented_code = textwrap.indent(code_to_execute, '# ')
return f"# Sandbox Exec Error: {type(e).__name__}: {e}\n# Code:\n{indented_code}"
finally: sys.stdout = old_stdout
else: return llm_response_text
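# Minimal usage sketch for PandasLLM alone (hypothetical df_posts; requires an
# event loop, e.g. asyncio.run):
#   llm = PandasLLM(LLM_MODEL_NAME, GENERATION_CONFIG_PARAMS, DEFAULT_SAFETY_SETTINGS)
#   output = asyncio.run(llm.query(prompt_text, {"posts": df_posts}))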
# --- Employer Branding Agent ---
class EmployerBrandingAgent:
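    """Orchestrates PandasLLM and AdvancedRAGSystem: builds prompts from
    DataFrame schemas plus retrieved RAG context and maintains a bounded
    chat history across turns."""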
def __init__(self, llm_model_name: str,
generation_config_dict: dict,
safety_settings_list: list,
all_dataframes: dict,
rag_documents_df: pd.DataFrame,
embedding_model_name: str,
data_privacy=True, force_sandbox=True):
self.pandas_llm = PandasLLM(llm_model_name, generation_config_dict, safety_settings_list, data_privacy, force_sandbox)
self.rag_system = AdvancedRAGSystem(rag_documents_df, embedding_model_name)
self.all_dataframes = all_dataframes if all_dataframes else {}
self.schemas_representation = get_all_schemas_representation(self.all_dataframes)
self.chat_history = []
logging.info("EmployerBrandingAgent Initialized (using Client API approach).")
def _build_prompt(self, user_query: str, role="Employer Branding Analyst & Strategist", task_decomposition_hint=None, cot_hint=True) -> str:
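        # Assembles role instructions, data schemas, RAG context, the user query,
        # and optional decomposition/chain-of-thought hints into one prompt string.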
prompt = f"You are a highly skilled '{role}'. Your goal is to provide actionable employer branding insights by analyzing Pandas DataFrames and RAG documents.\n"
if self.pandas_llm.data_privacy: prompt += "IMPORTANT: Adhere to data privacy. Summarize/aggregate PII.\n"
if self.pandas_llm.force_sandbox:
prompt += "\n--- TASK: PYTHON CODE GENERATION FOR INSIGHTS ---\n"
prompt += "GENERATE PYTHON CODE using Pandas. The code's `print()` statements MUST output final textual insights/answers.\n"
prompt += "Output ONLY the Python code block (```python ... ```).\n"
prompt += "Access DataFrames as 'df_name' (e.g., `df_follower_stats`).\n"
prompt += "\n--- CRITICAL INSTRUCTIONS FOR PYTHON CODE OUTPUT ---\n"
prompt += "1. **Print Insights, Not Just Data:** `print()` clear, actionable insights. NOT raw DataFrames unless specifically asked for a table.\n"
prompt += " Good: `print(f'Insight: Theme {top_theme} has {engagement_increase}% higher engagement.')`\n"
prompt += " Avoid: `print(df_result)` (for insight queries).\n"
prompt += "2. **Synthesize with RAG:** Weave RAG takeaways into printed insights. Ex: `print(f'Data shows X. RAG says Y. Recommend Z.')`\n"
prompt += "3. **Comments & Clarity:** Write clean, commented code.\n"
prompt += "4. **Handle Issues in Code:** If ambiguous, `print()` a question. If data unavailable, `print()` explanation. For non-analytical queries, `print()` polite reply.\n"
prompt += "5. **Function Usage:** Call functions and `print()` their (insightful) results.\n"
else: # Not force_sandbox
prompt += "\n--- TASK: DIRECT TEXTUAL INSIGHT GENERATION ---\n"
prompt += "Analyze data and RAG, then provide a comprehensive textual answer with insights. Explain step-by-step.\n"
prompt += "\n--- AVAILABLE DATA AND SCHEMAS ---\n"
prompt += self.schemas_representation if self.schemas_representation.strip() != "No DataFrames provided." else "No DataFrames loaded.\n"
rag_context = self.rag_system.retrieve_relevant_info(user_query)
        rag_error_indicators = ["Error", "No valid", "No relevant", "Cannot retrieve", "not available", "not generated"]
        is_meaningful_rag = bool(rag_context.strip()) and not any(indicator in rag_context for indicator in rag_error_indicators)
if is_meaningful_rag: prompt += f"\n--- RAG CONTEXT ---\n{rag_context}\n"
else: prompt += "\n--- RAG CONTEXT ---\nNo specific RAG context found or RAG error.\n"
prompt += f"\n--- USER QUERY ---\n{user_query}\n"
if task_decomposition_hint: prompt += f"\n--- GUIDANCE ---\n{task_decomposition_hint}\n"
if cot_hint:
if self.pandas_llm.force_sandbox:
prompt += "\n--- PYTHON CODE GENERATION THOUGHT PROCESS ---\n"
prompt += "1. Goal? 2. Data sources (DFs, RAG)? 3. Analysis plan (comments)? 4. Write Python code. 5. CRITICAL: Formulate & `print()` textual insights. 6. Review. 7. Output ONLY ```python ... ```.\n"
else: # Not force_sandbox
prompt += "\n--- TEXTUAL RESPONSE THOUGHT PROCESS ---\n"
prompt += "1. Goal? 2. Data sources? 3. Formulate insights (data + RAG). 4. Structure: explanation, then insights.\n"
return prompt
async def process_query(self, user_query: str, role="Employer Branding Analyst & Strategist", task_decomposition_hint=None, cot_hint=True) -> str:
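        # Snapshot the history *before* appending the new user turn so the LLM
        # receives prior turns without seeing the current query twice.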
current_turn_history_for_llm = self.chat_history[:]
self.chat_history.append({"role": "user", "parts": [{"text": user_query}]})
full_prompt = self._build_prompt(user_query, role, task_decomposition_hint, cot_hint)
logging.info(f"Built prompt for query: {user_query[:100]}...")
response_text = await self.pandas_llm.query(full_prompt, self.all_dataframes, history=current_turn_history_for_llm)
self.chat_history.append({"role": "model", "parts": [{"text": response_text}]})
MAX_HISTORY_TURNS = 5
if len(self.chat_history) > MAX_HISTORY_TURNS * 2:
self.chat_history = self.chat_history[-(MAX_HISTORY_TURNS * 2):]
logging.info(f"Chat history truncated.")
return response_text
def update_dataframes(self, new_dataframes: dict):
self.all_dataframes = new_dataframes if new_dataframes else {}
self.schemas_representation = get_all_schemas_representation(self.all_dataframes)
logging.info(f"Agent DataFrames updated. Schemas: {self.schemas_representation[:100]}...")
def clear_chat_history(self): self.chat_history = []; logging.info("Agent chat history cleared.")
# --- Example Usage (Conceptual) ---
async def main_test():
logging.info("Starting main_test for EmployerBrandingAgent...")
df_follower_stats = pd.DataFrame({'date': pd.to_datetime(['2023-01-01']), 'country': ['USA'], 'new_followers': [10]})
df_posts = pd.DataFrame({'post_id': [1], 'theme': ['Culture'], 'engagement_rate': [0.05]})
test_dataframes = {"follower_stats": df_follower_stats, "posts": df_posts}
if not GEMINI_API_KEY: logging.warning("GEMINI_API_KEY not set. Testing with dummy functionality.")
agent = EmployerBrandingAgent(LLM_MODEL_NAME, GENERATION_CONFIG_PARAMS, DEFAULT_SAFETY_SETTINGS, test_dataframes, df_rag_documents, GEMINI_EMBEDDING_MODEL_NAME, force_sandbox=True)
queries = ["Which post theme has the highest average engagement rate? Provide an insight.", "Hello!"]
for query in queries:
logging.info(f"\n\n--- Query: {query} ---")
response = await agent.process_query(user_query=query)
logging.info(f"--- Response for '{query}': ---\n{response}\n---------------------------\n")
if GEMINI_API_KEY: await asyncio.sleep(1)
if __name__ == "__main__":
if GEMINI_API_KEY:
try: asyncio.run(main_test())
except RuntimeError as e:
if "asyncio.run() cannot be called from a running event loop" in str(e): print("Skip asyncio.run in existing loop.")
else: raise
else: print("GEMINI_API_KEY not set. Skipping main_test().")