# eb_agent_module.py
import pandas as pd
import json
import os
import asyncio
import logging
import numpy as np
import textwrap

# Attempt to import Google Generative AI and related types
try:
    from google import generativeai as genai  # Renamed for clarity to avoid conflict
    from google.generativeai import types as genai_types
    # from google.generativeai import GenerationConfig  # For direct use if needed
    # from google.generativeai.types import HarmCategory, HarmBlockThreshold, SafetySetting  # For direct use
    _GENAI_AVAILABLE = True  # Real SDK imported successfully
except ImportError:
    print("Google Generative AI library not found. Please install it: pip install google-generativeai")
    _GENAI_AVAILABLE = False  # Fall back to the dummy stand-ins defined below
    # Define dummy classes/functions if the import fails, to allow the rest of the script to be parsed
    class genai:  # type: ignore
        @staticmethod
        def configure(api_key):
            print(f"Dummy genai.configure called with API key: {'SET' if api_key else 'NOT SET'}")

        # Dummy Client and related structures
        class Client:
            def __init__(self, api_key=None):
                self.api_key = api_key
                self.models = self._Models()
                print(f"Dummy genai.Client initialized {'with' if api_key else 'without'} API key.")

            class _Models:
                async def generate_content_async(self, model=None, contents=None, generation_config=None, safety_settings=None, stream=False):  # Matched real signature better
                    print(f"Dummy genai.Client.models.generate_content_async called for model: {model} with config: {generation_config}, safety_settings: {safety_settings}, stream: {stream}")

                    class DummyPart:
                        def __init__(self, text): self.text = text

                    class DummyContent:
                        def __init__(self): self.parts = [DummyPart("# Dummy response from dummy client's async generate_content")]

                    class DummyCandidate:
                        def __init__(self):
                            self.content = DummyContent()
                            self.finish_reason = genai_types.FinishReason.STOP  # Use dummy FinishReason
                            self.safety_ratings = []
                            self.token_count = 0
                            self.index = 0

                    class DummyResponse:
                        def __init__(self):
                            self.candidates = [DummyCandidate()]
                            self.prompt_feedback = self._PromptFeedback()  # Use dummy PromptFeedback
                            self.text = "# Dummy response text from dummy client's async generate_content"  # for easier access

                        class _PromptFeedback:  # Nested dummy class
                            def __init__(self):
                                self.block_reason = None
                                self.safety_ratings = []

                    return DummyResponse()
                def generate_content(self, model=None, contents=None, generation_config=None, safety_settings=None, stream=False):  # Matched real signature better
                    print(f"Dummy genai.Client.models.generate_content called for model: {model} with config: {generation_config}, safety_settings: {safety_settings}, stream: {stream}")
                    # Re-using the async dummy structure for simplicity
                    class DummyPart:
                        def __init__(self, text): self.text = text

                    class DummyContent:
                        def __init__(self): self.parts = [DummyPart("# Dummy response from dummy client's generate_content")]

                    class DummyCandidate:
                        def __init__(self):
                            self.content = DummyContent()
                            self.finish_reason = genai_types.FinishReason.STOP  # Use dummy FinishReason
                            self.safety_ratings = []
                            self.token_count = 0
                            self.index = 0

                    class DummyResponse:
                        def __init__(self):
                            self.candidates = [DummyCandidate()]
                            self.prompt_feedback = self._PromptFeedback()  # Use dummy PromptFeedback
                            self.text = "# Dummy response text from dummy client's generate_content"

                        class _PromptFeedback:
                            def __init__(self):
                                self.block_reason = None
                                self.safety_ratings = []

                    return DummyResponse()
        @staticmethod
        def GenerativeModel(model_name, generation_config=None, safety_settings=None, system_instruction=None):  # Matched real signature
            print(f"Dummy genai.GenerativeModel called for model: {model_name} with config: {generation_config}, safety: {safety_settings}, system_instruction: {system_instruction}")

            class DummyGenerativeModel:
                def __init__(self, model_name_in, generation_config_in, safety_settings_in, system_instruction_in):
                    self.model_name = model_name_in
                    self.generation_config = generation_config_in
                    self.safety_settings = safety_settings_in
                    self.system_instruction = system_instruction_in

                async def generate_content_async(self, contents, stream=False):  # Matched real signature
                    print(f"Dummy GenerativeModel.generate_content_async called for {self.model_name}")
                    # Simplified response, similar to Client's dummy
                    model_name = self.model_name  # Captured so the nested dummy classes can reference it

                    class DummyPart:
                        def __init__(self, text): self.text = text

                    class DummyContent:
                        def __init__(self): self.parts = [DummyPart(f"# Dummy response from dummy GenerativeModel ({model_name})")]

                    class DummyCandidate:
                        def __init__(self):
                            self.content = DummyContent()
                            self.finish_reason = genai_types.FinishReason.STOP
                            self.safety_ratings = []

                    class DummyResponse:
                        def __init__(self):
                            self.candidates = [DummyCandidate()]
                            self.prompt_feedback = None
                            self.text = f"# Dummy response text from dummy GenerativeModel ({model_name})"

                    return DummyResponse()

                def generate_content(self, contents, stream=False):  # Matched real signature
                    print(f"Dummy GenerativeModel.generate_content called for {self.model_name}")
                    # Simplified response, similar to Client's dummy
                    model_name = self.model_name  # Captured so the nested dummy classes can reference it

                    class DummyPart:
                        def __init__(self, text): self.text = text

                    class DummyContent:
                        def __init__(self): self.parts = [DummyPart(f"# Dummy response from dummy GenerativeModel ({model_name})")]

                    class DummyCandidate:
                        def __init__(self):
                            self.content = DummyContent()
                            self.finish_reason = genai_types.FinishReason.STOP
                            self.safety_ratings = []

                    class DummyResponse:
                        def __init__(self):
                            self.candidates = [DummyCandidate()]
                            self.prompt_feedback = None
                            self.text = f"# Dummy response text from dummy GenerativeModel ({model_name})"

                    return DummyResponse()

            return DummyGenerativeModel(model_name, generation_config, safety_settings, system_instruction)

        @staticmethod
        def embed_content(model, content, task_type, title=None):
            print(f"Dummy genai.embed_content called for model: {model}, task_type: {task_type}, title: {title}")
            # Ensure the dummy embedding matches typical dimensions (e.g., 768 for many models)
            return {"embedding": [0.1] * 768}
    class genai_types:  # type: ignore
        # Using dicts for dummy GenerationConfig and SafetySetting for simplicity
        @staticmethod
        def GenerationConfig(**kwargs):  # The dummy simply returns the kwargs as a dict
            print(f"Dummy genai_types.GenerationConfig created with: {kwargs}")
            return dict(kwargs)

        @staticmethod
        def SafetySetting(category, threshold):
            print(f"Dummy SafetySetting created: category={category}, threshold={threshold}")
            return {"category": category, "threshold": threshold}  # Return a dict for the dummy

        # Dummy exception types so `except genai_types.*Exception` clauses still resolve
        class BlockedPromptException(Exception):
            pass

        class StopCandidateException(Exception):
            pass

        # Dummy Enums (can be simple string attributes)
        class HarmCategory:
            HARM_CATEGORY_UNSPECIFIED = "HARM_CATEGORY_UNSPECIFIED"
            HARM_CATEGORY_HARASSMENT = "HARM_CATEGORY_HARASSMENT"
            HARM_CATEGORY_HATE_SPEECH = "HARM_CATEGORY_HATE_SPEECH"
            HARM_CATEGORY_SEXUALLY_EXPLICIT = "HARM_CATEGORY_SEXUALLY_EXPLICIT"
            HARM_CATEGORY_DANGEROUS_CONTENT = "HARM_CATEGORY_DANGEROUS_CONTENT"

        class HarmBlockThreshold:
            BLOCK_NONE = "BLOCK_NONE"
            BLOCK_LOW_AND_ABOVE = "BLOCK_LOW_AND_ABOVE"
            BLOCK_MEDIUM_AND_ABOVE = "BLOCK_MEDIUM_AND_ABOVE"
            BLOCK_ONLY_HIGH = "BLOCK_ONLY_HIGH"

        class FinishReason:  # Added dummy FinishReason
            FINISH_REASON_UNSPECIFIED = "FINISH_REASON_UNSPECIFIED"
            STOP = "STOP"
            MAX_TOKENS = "MAX_TOKENS"
            SAFETY = "SAFETY"
            RECITATION = "RECITATION"
            OTHER = "OTHER"

        # Placeholder for other types if needed by the script
        # class BlockReason:
        #     SAFETY = "SAFETY"

# --- Configuration ---
GEMINI_API_KEY = os.getenv('GEMINI_API_KEY', "")

# Recommended: use a standard, publicly available model name.
LLM_MODEL_NAME = "gemini-2.0-flash"
GEMINI_EMBEDDING_MODEL_NAME = "gemini-embedding-exp-03-07"

# Base generation configuration for the LLM
GENERATION_CONFIG_PARAMS = {
    "temperature": 0.3,  # Slightly increased for more varied insights, adjust as needed
    "top_p": 1.0,
    "top_k": 32,
    "max_output_tokens": 8192,  # Increased for potentially longer code with comments and insights
    # "candidate_count": 1,  # Default is 1, explicitly setting it
}
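
# Illustrative sketch (not executed here): GENERATION_CONFIG_PARAMS is unpacked into the SDK's
# GenerationConfig when the model is created, roughly as
#   config = genai_types.GenerationConfig(**GENERATION_CONFIG_PARAMS)
#   model = genai.GenerativeModel(model_name=LLM_MODEL_NAME, generation_config=config)
# The keyword names above follow this module's own usage; consult the installed SDK version
# if a parameter is rejected.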

# Default safety settings list for Gemini
try:
    DEFAULT_SAFETY_SETTINGS = [
        genai_types.SafetySetting(
            category=genai_types.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
            threshold=genai_types.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
        ),
        genai_types.SafetySetting(
            category=genai_types.HarmCategory.HARM_CATEGORY_HARASSMENT,
            threshold=genai_types.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
        ),
        genai_types.SafetySetting(
            category=genai_types.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
            threshold=genai_types.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
        ),
        genai_types.SafetySetting(
            category=genai_types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
            threshold=genai_types.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
        ),
    ]
except AttributeError as e:
    logging.warning(f"Could not define DEFAULT_SAFETY_SETTINGS using real genai_types: {e}. Using placeholder list of dicts.")
    DEFAULT_SAFETY_SETTINGS = [
        {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
        {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
        {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
        {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
    ]

# Logging setup
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(module)s - %(filename)s:%(lineno)d - %(message)s')

if GEMINI_API_KEY:
    try:
        genai.configure(api_key=GEMINI_API_KEY)
        logging.info("Gemini API key configured globally.")
    except Exception as e:
        logging.error(f"Failed to configure Gemini API globally: {e}", exc_info=True)
else:
    logging.warning("GEMINI_API_KEY environment variable not set. The agent will fall back to dummy responses if the real genai library is unavailable or API calls fail.")

# --- RAG Documents Definition (Example) ---
rag_documents_data = {
    'Title': [
        "Employer Branding Best Practices 2024",
        "Attracting Tech Talent in Competitive Markets",
        "The Power of Employee Advocacy",
        "Understanding Gen Z Workforce Expectations"
    ],
    'Text': [
        "Focus on authentic employee stories and showcase company culture. Highlight diversity and inclusion initiatives. Use video content for higher engagement. Clearly articulate your Employee Value Proposition (EVP).",
        "Tech candidates value challenging projects, continuous learning opportunities, and a flexible work environment. Competitive compensation and modern tech stacks are crucial. Highlight your company's innovation and impact.",
        "Encourage employees to share their positive experiences on social media. Provide them with shareable content and guidelines. Employee-generated content is often perceived as more trustworthy than corporate messaging.",
        "Gen Z values purpose-driven work, transparency, mental health support, and opportunities for growth. They are digital natives and expect seamless online application processes. They also care deeply about social responsibility."
    ]
}
df_rag_documents = pd.DataFrame(rag_documents_data)

# --- Schema Representation ---
def get_schema_representation(df_name: str, df: pd.DataFrame) -> str:
    if not isinstance(df, pd.DataFrame):
        return f"Schema for item '{df_name}': Not a DataFrame.\n"
    if df.empty:
        return f"Schema for DataFrame 'df_{df_name}': Empty (no columns or rows).\n"
    schema_str = f"DataFrame 'df_{df_name}':\n"
    schema_str += f"  Columns: {df.columns.tolist()}\n"
    schema_str += f"  Shape: {df.shape}\n"
    # Add dtypes for more clarity if needed:
    # schema_str += "  Data Types:\n"
    # for col in df.columns:
    #     schema_str += f"    {col}: {df[col].dtype}\n"
    # Sample data (first 2 rows), indented for better readability in the prompt
    sample_str = df.head(2).to_string()
    indented_sample = "\n".join(["    " + line for line in sample_str.splitlines()])
    schema_str += f"  Sample Data (first 2 rows):\n{indented_sample}\n"
    return schema_str


def get_all_schemas_representation(dataframes_dict: dict) -> str:
    if not dataframes_dict:
        return "No DataFrames provided.\n"
    return "".join(get_schema_representation(name, df) for name, df in dataframes_dict.items())

# --- Advanced RAG System ---
class AdvancedRAGSystem:
    def __init__(self, documents_df: pd.DataFrame, embedding_model_name: str):
        self.embedding_model_name = embedding_model_name
        self.documents_df = documents_df.copy()
        self.embeddings_generated = False
        # Embeddings can only be computed when the real SDK (not the dummy fallback) is importable.
        self.client_available = _GENAI_AVAILABLE and hasattr(genai, 'embed_content')
        if GEMINI_API_KEY and self.client_available:
            try:
                self._precompute_embeddings()
                self.embeddings_generated = True
                logging.info(f"RAG embeddings precomputed using '{self.embedding_model_name}'.")
            except Exception as e:
                logging.error(f"RAG precomputation error: {e}", exc_info=True)
        else:
            logging.warning(f"RAG embeddings not precomputed. GEMINI_API_KEY set: {bool(GEMINI_API_KEY)}, genai.embed_content available: {self.client_available}.")

    def _embed_fn(self, title: str, text: str) -> list[float]:
        # Note: this runs during _precompute_embeddings, i.e. before self.embeddings_generated
        # is set, so it must gate on client availability rather than that flag.
        if not self.client_available:
            return [0.0] * 768  # Default dimension, adjust if your model differs
        try:
            # Ensure content is not empty
            content_to_embed = text if text else title
            if not content_to_embed:
                logging.warning(f"Cannot embed '{title}' because both title and text are empty.")
                return [0.0] * 768
            embedding_result = genai.embed_content(
                model=self.embedding_model_name,
                content=content_to_embed,
                task_type="retrieval_document",
                title=title if title else None  # Pass title only if it exists
            )
            return embedding_result["embedding"]
        except Exception as e:
            logging.error(f"Error in _embed_fn for '{title}': {e}", exc_info=True)
            return [0.0] * 768

    def _precompute_embeddings(self):
        if 'Embeddings' not in self.documents_df.columns:
            self.documents_df['Embeddings'] = pd.Series(dtype='object')
        # Ensure there's text to embed
        mask = (self.documents_df['Text'].notna() & (self.documents_df['Text'] != '')) | \
               (self.documents_df['Title'].notna() & (self.documents_df['Title'] != ''))
        if not mask.any():
            logging.warning("No content found in 'Text' or 'Title' columns to generate embeddings.")
            return
        self.documents_df.loc[mask, 'Embeddings'] = self.documents_df[mask].apply(
            lambda row: self._embed_fn(row.get('Title', ''), row.get('Text', '')), axis=1
        )
        logging.info(f"Applied embedding function to {mask.sum()} rows.")

    def retrieve_relevant_info(self, query_text: str, top_k: int = 2) -> str:
        if not self.client_available:
            return "\n[RAG Context]\nEmbedding client not available. Cannot retrieve RAG context.\n"
        if not self.embeddings_generated or 'Embeddings' not in self.documents_df.columns or self.documents_df['Embeddings'].isnull().all():
            return "\n[RAG Context]\nEmbeddings not generated or all are null. No RAG context available.\n"
        try:
            query_embedding_response = genai.embed_content(
                model=self.embedding_model_name,
                content=query_text,
                task_type="retrieval_query"
            )
            query_embedding = np.array(query_embedding_response["embedding"])
            valid_embeddings_df = self.documents_df.dropna(subset=['Embeddings'])
            valid_embeddings_df = valid_embeddings_df[valid_embeddings_df['Embeddings'].apply(lambda x: isinstance(x, (list, np.ndarray)) and len(x) > 0)]
            if valid_embeddings_df.empty:
                return "\n[RAG Context]\nNo valid document embeddings for RAG.\n"
            document_embeddings = np.stack(valid_embeddings_df['Embeddings'].apply(np.array).values)
            if query_embedding.shape[0] != document_embeddings.shape[1]:
                logging.error(f"Embedding dimension mismatch. Query: {query_embedding.shape[0]}, Docs: {document_embeddings.shape[1]}")
                return "\n[RAG Context]\nEmbedding dimension mismatch. Cannot calculate similarity.\n"
            dot_products = np.dot(document_embeddings, query_embedding)
            # Take the top_k documents by dot-product similarity; if fewer valid documents exist, take all of them.
            num_to_retrieve = min(top_k, len(valid_embeddings_df))
            if num_to_retrieve == 0:  # Should already be caught by the valid_embeddings_df.empty check
                return "\n[RAG Context]\nNo relevant passages found (num_to_retrieve is 0).\n"
            idx = np.argsort(dot_products)[-num_to_retrieve:][::-1]  # Top N indices, descending by similarity
            relevant_passages = ""
            for i in idx:
                if i < len(valid_embeddings_df):  # Defensive check
                    doc = valid_embeddings_df.iloc[i]
                    relevant_passages += f"\n[RAG Context from: '{doc['Title']}']\n{doc['Text']}\n"
                else:
                    logging.warning(f"Index {i} out of bounds for valid_embeddings_df (len {len(valid_embeddings_df)})")
            return relevant_passages if relevant_passages else "\n[RAG Context]\nNo relevant passages found after similarity search.\n"
        except Exception as e:
            logging.error(f"Error in RAG retrieve_relevant_info: {e}", exc_info=True)
            return f"\n[RAG Context]\nError during RAG retrieval: {type(e).__name__} - {e}\n"

# --- PandasLLM Class (Gemini-Powered) ---
class PandasLLM:
    def __init__(self, llm_model_name: str,
                 generation_config_dict: dict,
                 safety_settings_list: list,
                 data_privacy=True, force_sandbox=True):
        self.llm_model_name = llm_model_name
        self.generation_config_dict = generation_config_dict
        self.safety_settings_list = safety_settings_list
        self.data_privacy = data_privacy
        self.force_sandbox = force_sandbox
        self.generative_model = None  # Will hold the GenerativeModel instance

        if not GEMINI_API_KEY:
            logging.warning("PandasLLM: GEMINI_API_KEY not set. Using the dummy model if the real 'genai' library is unavailable.")
            # Even without an API key, initialize the dummy model so calls do not crash.
            if not _GENAI_AVAILABLE and hasattr(genai, 'GenerativeModel'):
                self.generative_model = genai.GenerativeModel(
                    model_name=self.llm_model_name,
                    generation_config=genai_types.GenerationConfig(**self.generation_config_dict) if self.generation_config_dict else None,
                    safety_settings=self.safety_settings_list
                )
                logging.info(f"PandasLLM: Initialized with DUMMY genai.GenerativeModel for '{self.llm_model_name}'.")
        else:  # GEMINI_API_KEY is set
            try:
                # Use genai_types.GenerationConfig for the real API
                config_for_model = genai_types.GenerationConfig(**self.generation_config_dict) if self.generation_config_dict else None
                self.generative_model = genai.GenerativeModel(
                    model_name=self.llm_model_name,  # The SDK handles the "models/" prefix
                    generation_config=config_for_model,
                    safety_settings=self.safety_settings_list
                    # system_instruction can be added here if needed globally for this model
                )
                logging.info(f"PandasLLM: Initialized with REAL genai.GenerativeModel for '{self.llm_model_name}'.")
            except Exception as e:
                logging.error(f"Failed to initialize PandasLLM with genai.GenerativeModel: {e}", exc_info=True)
                # Fall back to the dummy model (if that is what was imported) to prevent crashes.
                if not _GENAI_AVAILABLE and hasattr(genai, 'GenerativeModel'):
                    self.generative_model = genai.GenerativeModel(model_name=self.llm_model_name)  # Basic dummy
                    logging.warning("PandasLLM: Falling back to DUMMY genai.GenerativeModel due to real initialization error.")

    async def _call_gemini_api_async(self, prompt_text: str, history: list = None) -> str:
        if not self.generative_model:
            logging.error("PandasLLM: GenerativeModel not available (or not initialized). Cannot call API.")
            return "# Error: Gemini model not available for API call."

        # The Gemini API expects 'contents' as a list of role/parts dicts that alternate between
        # 'user' and 'model' turns, with the current user prompt as the final entry.
        gemini_history = []
        if history:
            for entry in history:
                role = "model" if entry.get("role") in ("assistant", "model") else "user"
                # History entries may store their text under 'content' or under Gemini-style 'parts'.
                text = entry.get("content")
                if text is None and entry.get("parts"):
                    text = entry["parts"][0].get("text", "")
                gemini_history.append({"role": role, "parts": [{"text": text or ""}]})

        # Add the current prompt as the last user message.
        current_content = [{"role": "user", "parts": [{"text": prompt_text}]}]
        contents_for_api = gemini_history + current_content  # Full conversation for this generate_content call

        logging.info(f"\n--- Calling Gemini API (model: {self.llm_model_name}) ---\nContent (last part): {contents_for_api[-1]['parts'][0]['text'][:200]}...\n")
        try:
            # The GenerativeModel instance already carries generation_config and safety_settings,
            # so only the 'contents' are passed here.
            response = await self.generative_model.generate_content_async(
                contents=contents_for_api,
            )

            if hasattr(response, 'prompt_feedback') and response.prompt_feedback and \
               hasattr(response.prompt_feedback, 'block_reason') and response.prompt_feedback.block_reason:
                block_reason_val = response.prompt_feedback.block_reason
                # Try to get the enum name if available
                try:
                    block_reason_str = genai_types.BlockedReason(block_reason_val).name
                except Exception:
                    block_reason_str = str(block_reason_val)
                logging.warning(f"Prompt blocked by API. Reason: {block_reason_str}. Ratings: {response.prompt_feedback.safety_ratings}")
                return f"# Error: Prompt blocked by API. Reason: {block_reason_str}."

            llm_output = ""
            # Standard way to get text from a Gemini response
            if hasattr(response, 'text') and isinstance(response.text, str):
                llm_output = response.text
            elif response.candidates:
                candidate = response.candidates[0]
                if candidate.content and candidate.content.parts:
                    llm_output = "".join(part.text for part in candidate.content.parts if hasattr(part, 'text'))
                if not llm_output and candidate.finish_reason:
                    finish_reason_val = candidate.finish_reason
                    try:
                        # For the real API finish_reason is an enum member; for the dummy it may already be a string.
                        finish_reason_str = str(finish_reason_val)
                        if hasattr(genai_types.FinishReason, '_enum_map_') and finish_reason_val in genai_types.FinishReason._enum_map_:
                            finish_reason_str = genai_types.FinishReason(finish_reason_val).name
                    except Exception as fre:
                        logging.debug(f"Could not get FinishReason name: {fre}")
                        finish_reason_str = str(finish_reason_val)

                    # Check whether generation was stopped for safety reasons
                    if finish_reason_str == "SAFETY":
                        safety_messages = []
                        if candidate.safety_ratings:
                            for rating in candidate.safety_ratings:
                                cat_name = rating.category.name if hasattr(rating.category, 'name') else str(rating.category)
                                prob_name = rating.probability.name if hasattr(rating.probability, 'name') else str(rating.probability)
                                safety_messages.append(f"Category: {cat_name}, Probability: {prob_name}")
                        logging.warning(f"Content generation stopped due to safety. Finish reason: {finish_reason_str}. Details: {'; '.join(safety_messages)}")
                        return f"# Error: Content generation stopped by API due to safety. Finish Reason: {finish_reason_str}. Details: {'; '.join(safety_messages)}"

                    logging.warning(f"Empty response from LLM. Finish reason: {finish_reason_str}.")
                    return f"# Error: LLM returned an empty response. Finish reason: {finish_reason_str}."
            else:
                logging.error(f"Unexpected API response structure: {str(response)[:500]}")
                return f"# Error: Unexpected API response structure: {str(response)[:200]}"

            return llm_output
        except genai_types.BlockedPromptException as bpe:  # Specific exception for blocked prompts
            logging.error(f"Prompt was blocked by the API (BlockedPromptException): {bpe}", exc_info=True)
            return f"# Error: Your prompt was blocked by the API. Please revise. Details: {bpe.prompt_feedback}"
        except genai_types.StopCandidateException as sce:  # Specific exception for a stopped candidate
            logging.error(f"Candidate generation stopped (StopCandidateException): {sce}", exc_info=True)
            return f"# Error: Content generation was stopped. Details: {sce.candidate}"
        except Exception as e:
            logging.error(f"Error calling Gemini API: {e}", exc_info=True)
            return f"# Error during API call: {type(e).__name__} - {str(e)[:100]}."

    async def query(self, prompt_with_query_and_context: str, dataframes_dict: dict, history: list = None) -> str:
        llm_response_text = await self._call_gemini_api_async(prompt_with_query_and_context, history)
        if not self.force_sandbox:
            return llm_response_text

        code_to_execute = ""
        # Robust code extraction: handle both ```python\nCODE\n``` and ```pythonCODE``` variants
        if "```python" in llm_response_text:
            try:
                code_block_match = llm_response_text.split("```python\n", 1)
                if len(code_block_match) > 1:
                    code_to_execute = code_block_match[1].split("\n```", 1)[0]
                else:  # Try without a newline after ```python
                    code_block_match = llm_response_text.split("```python", 1)
                    if len(code_block_match) > 1:
                        code_to_execute = code_block_match[1].split("```", 1)[0]
                        if code_to_execute.startswith("\n"):  # Remove leading newline if present
                            code_to_execute = code_to_execute[1:]
            except IndexError:
                code_to_execute = ""  # Should not happen with this split logic

        if llm_response_text.startswith("# Error:") or not code_to_execute.strip():
            logging.warning(f"LLM response is an error, or no valid Python code block found. Raw LLM response: {llm_response_text}")
            # If the LLM returned an error, a polite comment-only refusal, or plain natural language,
            # pass its response through directly instead of executing anything.
            if not code_to_execute.strip() and not llm_response_text.startswith("# Error:"):
                if "```" not in llm_response_text and len(llm_response_text.strip()) > 0:  # Heuristic for non-code text
                    logging.info(f"LLM produced text output instead of Python code in sandbox mode. Passing through: {llm_response_text}")
                    return llm_response_text
            return llm_response_text

        logging.info(f"\n--- Code to Execute (extracted from LLM response): ---\n{code_to_execute}\n----------------------\n")

        from io import StringIO
        import sys
        old_stdout = sys.stdout
        sys.stdout = captured_output = StringIO()

        # Prepare globals for exec. DataFrames are exposed with a 'df_' prefix, as described in the prompt.
        exec_globals = {'pd': pd, 'np': np}
        if dataframes_dict:
            for name, df_instance in dataframes_dict.items():
                if isinstance(df_instance, pd.DataFrame):
                    exec_globals[f"df_{name}"] = df_instance
                else:
                    logging.warning(f"Item '{name}' in dataframes_dict is not a DataFrame. Skipping for exec_globals.")
        try:
            # Use a single namespace for globals and locals so top-level functions defined by the
            # generated code can see each other and the df_* variables.
            exec(code_to_execute, exec_globals)
            final_output_str = captured_output.getvalue()
            if not final_output_str.strip():
                logging.info("Code executed successfully, but no explicit print() output was generated by the LLM's code.")
                # Check whether the code was just comments or an empty block
                if not any(line.strip() and not line.strip().startswith("#") for line in code_to_execute.splitlines()):
                    return "# LLM generated only comments or an empty code block. No output produced."
                return "# Code executed successfully, but it did not produce any printed output. Please ensure the LLM's Python code includes print() statements for the desired results, insights, or answers."
            return final_output_str
        except Exception as e:
            logging.error(f"Sandbox Execution Error: {e}\nCode was:\n{code_to_execute}", exc_info=True)
            # Comment out the offending code so it can be shown safely in the returned message
            indented_code = textwrap.indent(code_to_execute, '# ', predicate=lambda line: True)
            return f"# Sandbox Execution Error: {type(e).__name__}: {e}\n# --- Code that caused error: ---\n{indented_code}"
        finally:
            sys.stdout = old_stdout
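
# Illustrative sketch of the sandbox contract (hypothetical values, not executed at import):
# the LLM is expected to answer with a fenced block such as
#   ```python
#   top_theme = df_posts.groupby('theme')['engagement_rate'].mean().idxmax()
#   print(f"Key Insight: '{top_theme}' posts have the highest average engagement rate.")
#   ```
# PandasLLM.query() extracts that block, exec()s it with the df_* variables in scope, and
# returns whatever the code print()ed as the agent's answer.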

# --- Employer Branding Agent ---
class EmployerBrandingAgent:
    def __init__(self, llm_model_name: str,
                 generation_config_dict: dict,
                 safety_settings_list: list,
                 all_dataframes: dict,
                 rag_documents_df: pd.DataFrame,
                 embedding_model_name: str,
                 data_privacy=True, force_sandbox=True):
        self.pandas_llm = PandasLLM(
            llm_model_name,
            generation_config_dict,
            safety_settings_list,
            data_privacy,
            force_sandbox
        )
        self.rag_system = AdvancedRAGSystem(rag_documents_df, embedding_model_name)
        self.all_dataframes = all_dataframes if all_dataframes else {}
        self.schemas_representation = get_all_schemas_representation(self.all_dataframes)
        self.chat_history = []
        logging.info("EmployerBrandingAgent Initialized.")

    def _build_prompt(self, user_query: str, role="Employer Branding Analyst & Strategist", task_decomposition_hint=None, cot_hint=True) -> str:
        # The opening instruction sets the persona and overall goal (it could also be passed to
        # GenerativeModel as a system_instruction if preferred).
        prompt = f"You are a highly skilled '{role}'. Your primary goal is to provide actionable employer branding insights and strategic recommendations by analyzing provided data (Pandas DataFrames) and contextual information (RAG documents).\n"
        prompt += "You will be provided with schemas for available Pandas DataFrames and a user query.\n"

        if self.pandas_llm.data_privacy:
            prompt += "IMPORTANT: Adhere to data privacy. Do not output raw Personally Identifiable Information (PII) like individual names or specific user contact details. Summarize, aggregate, or anonymize data in your insights.\n"

        if self.pandas_llm.force_sandbox:
            prompt += "\n--- TASK: PYTHON CODE GENERATION FOR INSIGHTS ---\n"
            prompt += "Your main task is to GENERATE PYTHON CODE. This code should use the Pandas library to analyze the provided DataFrames and incorporate insights from any RAG context. The code's `print()` statements MUST output the final textual insights, analyses, or answers to the user's query.\n"
            prompt += "Output ONLY the Python code block, starting with ```python and ending with ```.\n"
            prompt += "The available DataFrames are already loaded and can be accessed by their dictionary keys prefixed with 'df_' (e.g., df_follower_stats, df_posts) within the execution environment.\n"
            prompt += "Example of accessing a DataFrame: `df_follower_stats['country']`.\n"
            prompt += "\n--- CRITICAL INSTRUCTIONS FOR PYTHON CODE OUTPUT ---\n"
            prompt += "1. **Print Insights, Not Just Data:** Your Python code's `print()` statements are the agent's final response. These prints MUST articulate clear, actionable insights or answers. Do NOT just print raw DataFrames or intermediate variables unless the query *specifically* asks for a table of data (e.g., 'Show me the first 5 posts').\n"
            prompt += "   Example of a good insight print: `print(f'Key Insight: Content related to {top_theme} receives {avg_engagement_increase}% higher engagement, suggesting a focus on this area.')`\n"
            prompt += "   Example of what to AVOID for insight queries: `print(df_analysis_result)` (unless df_analysis_result is the specific table requested).\n"
            prompt += "2. **Synthesize with RAG Context:** If RAG context is provided, weave takeaways from it into your printed insights. Example: `print(f'Data shows X. Combined with RAG best practice Y, we recommend Z.')`\n"
            prompt += "3. **Structure and Comments:** Write clean, commented Python code. Explain your logic for each step.\n"
            prompt += "4. **Handle Ambiguity/Errors in Code:**\n"
            prompt += "   - If the query is ambiguous, `print()` a clarifying question as a string. Do not generate analysis code.\n"
            prompt += "   - If the query cannot be answered with the given data/schemas, `print()` a statement explaining this. Example: `print('Insight: Cannot determine X as the required data Y is not available in the provided DataFrames.')`\n"
            prompt += "   - For non-analytical queries (e.g., 'hello'), respond politely with a `print()` statement. Example: `print('Hello! How can I assist with your employer branding data analysis today?')`\n"
            prompt += "5. **Function Usage:** If you define functions, ENSURE they are called and their results (or insights derived from them) are `print()`ed.\n"
            prompt += "6. **DataFrame Naming:** Remember to use the `df_` prefix for DataFrame names in your code (e.g., `df_your_data`).\n"
        else:  # Not force_sandbox: the LLM provides a direct textual answer
            prompt += "\n--- TASK: DIRECT TEXTUAL INSIGHT GENERATION ---\n"
            prompt += "Your task is to analyze the data (described by schemas) and RAG context, then provide a comprehensive textual answer with actionable insights and strategic recommendations. Explain your reasoning step-by-step.\n"

        prompt += "\n--- AVAILABLE DATA AND SCHEMAS ---\n"
        if self.schemas_representation.strip() == "No DataFrames provided.":
            prompt += "No specific DataFrames are currently loaded. Please rely on general knowledge and any provided RAG context for your response, or ask for data to be loaded.\n"
        else:
            prompt += self.schemas_representation

        rag_context = self.rag_system.retrieve_relevant_info(user_query)
        # Only append the RAG context if it looks meaningful (i.e. not an error or empty-result message)
        meaningful_rag_keywords = ["Error", "No valid", "No relevant", "Cannot retrieve", "not available", "not generated"]
        is_meaningful_rag = bool(rag_context.strip()) and not any(keyword in rag_context for keyword in meaningful_rag_keywords)
        if is_meaningful_rag:
            prompt += f"\n--- ADDITIONAL CONTEXT (from Employer Branding Knowledge Base - consider this for your insights) ---\n{rag_context}\n"
        else:
            prompt += "\n--- ADDITIONAL CONTEXT (from Employer Branding Knowledge Base) ---\nNo specific pre-defined context found highly relevant to this query, or RAG system encountered an issue. Rely on general knowledge and DataFrame analysis.\n"

        prompt += f"\n--- USER QUERY ---\n{user_query}\n"
        if task_decomposition_hint:
            prompt += f"\n--- GUIDANCE FOR ANALYSIS (Task Decomposition) ---\n{task_decomposition_hint}\n"

        if cot_hint:
            if self.pandas_llm.force_sandbox:
                prompt += "\n--- THOUGHT PROCESS FOR PYTHON CODE GENERATION (Follow these steps) ---\n"
                prompt += "1. **Understand Query & Goal:** What specific employer branding insight or answer is the user seeking?\n"
                prompt += "2. **Identify Data Sources:** Which DataFrame(s) and column(s) are relevant? Is there RAG context to incorporate?\n"
                prompt += "3. **Plan Analysis (Mental Outline / Code Comments):**\n"
                prompt += "   a. What calculations, aggregations, or transformations are needed?\n"
                prompt += "   b. How will RAG context be integrated into the final printed insight?\n"
                prompt += "   c. What is the exact textual insight/answer to be `print()`ed?\n"
                prompt += "4. **Write Python Code:** Implement the plan. Use `df_name_of_dataframe`.\n"
                prompt += "5. **CRITICAL - Formulate and `print()` Insights:** Construct the final textual insight(s) as strings and use `print()` statements for them. These prints are the agent's entire response. Ensure they are clear, actionable, and directly address the user's query, incorporating RAG if applicable.\n"
                prompt += "6. **Review Code:** Check for correctness, clarity, and adherence to ALL instructions, especially the `print()` requirements for insightful text.\n"
                prompt += "7. **Final Output:** Ensure ONLY the Python code block (```python...```) is generated.\n"
            else:  # Not force_sandbox
                prompt += "\n--- THOUGHT PROCESS FOR DIRECT TEXTUAL RESPONSE (Follow these steps) ---\n"
                prompt += "1. **Understand Query & Goal:** What specific employer branding insight or answer is the user seeking?\n"
                prompt += "2. **Identify Data Sources:** Analyze the DataFrame schemas. Consider relevant RAG context.\n"
                prompt += "3. **Formulate Insights:** Synthesize information from data and RAG to derive key insights and recommendations.\n"
                prompt += "4. **Structure Response:** Provide a step-by-step explanation of your analysis, followed by the clear, actionable insights and strategic advice.\n"

        return prompt

    async def process_query(self, user_query: str, role="Employer Branding Analyst & Strategist", task_decomposition_hint=None, cot_hint=True) -> str:
        # Capture the history *before* this turn for the LLM call, then record the new user
        # message so future turns (and RAG) can see it.
        current_turn_history_for_llm = self.chat_history[:]
        self.chat_history.append({"role": "user", "parts": [{"text": user_query}]})  # Gemini-style 'parts' entry

        full_prompt = self._build_prompt(user_query, role, task_decomposition_hint, cot_hint)
        logging.info(f"Built prompt for user query: {user_query[:100]}...")

        # Pass the history *before* the current user query to the LLM
        response_text = await self.pandas_llm.query(full_prompt, self.all_dataframes, history=current_turn_history_for_llm)
        self.chat_history.append({"role": "model", "parts": [{"text": response_text}]})

        MAX_HISTORY_TURNS = 5  # Each turn consists of one user and one model message
        if len(self.chat_history) > MAX_HISTORY_TURNS * 2:
            # Keep only the most recent turns; the history alternates [user1, model1, user2, model2, ...]
            self.chat_history = self.chat_history[-(MAX_HISTORY_TURNS * 2):]
            logging.info(f"Chat history truncated to last {MAX_HISTORY_TURNS} turns.")
        return response_text

    def update_dataframes(self, new_dataframes: dict):
        self.all_dataframes = new_dataframes if new_dataframes else {}
        self.schemas_representation = get_all_schemas_representation(self.all_dataframes)
        logging.info(f"EmployerBrandingAgent DataFrames updated. New schemas: {self.schemas_representation[:200]}...")
        # RAG embeddings are independent of these DataFrames, so they are not recomputed here.

    def clear_chat_history(self):
        self.chat_history = []
        logging.info("EmployerBrandingAgent chat history cleared.")

# --- Example Usage (Conceptual - for testing the module structure) ---
async def main_test():
    logging.info("Starting main_test for EmployerBrandingAgent...")

    # Dummy DataFrames for testing
    followers_data = {
        'date': pd.to_datetime(['2023-01-01', '2023-01-02', '2023-01-01', '2023-01-03']),
        'country': ['USA', 'USA', 'Canada', 'UK'],
        'new_followers': [10, 12, 5, 8]
    }
    df_follower_stats = pd.DataFrame(followers_data)

    posts_data = {
        'post_id': [1, 2, 3, 4],
        'post_date': pd.to_datetime(['2023-01-01', '2023-01-01', '2023-01-02', '2023-01-03']),
        'theme': ['Culture', 'Tech', 'Culture', 'Jobs'],
        'impressions': [1000, 1500, 1200, 2000],
        'engagements': [50, 90, 60, 120]
    }
    df_posts = pd.DataFrame(posts_data)
    df_posts['engagement_rate'] = df_posts['engagements'] / df_posts['impressions']

    test_dataframes = {
        "follower_stats": df_follower_stats,
        "posts": df_posts,
        "empty_df": pd.DataFrame(),  # Test empty DataFrame representation
        "non_df_item": "This is not a dataframe"  # Test non-DataFrame item
    }

    # Initialize the agent. Ensure GEMINI_API_KEY is set in your environment for real calls.
    if not GEMINI_API_KEY:
        logging.warning("GEMINI_API_KEY not found in environment. Testing with dummy/mocked functionality.")

    agent = EmployerBrandingAgent(
        llm_model_name=LLM_MODEL_NAME,
        generation_config_dict=GENERATION_CONFIG_PARAMS,
        safety_settings_list=DEFAULT_SAFETY_SETTINGS,
        all_dataframes=test_dataframes,
        rag_documents_df=df_rag_documents,  # Using the example RAG data
        embedding_model_name=GEMINI_EMBEDDING_MODEL_NAME,
        force_sandbox=True  # True tests code generation; False returns direct LLM text
    )
    logging.info(f"Schema representation:\n{agent.schemas_representation}")

    queries = [
        "What are the key trends in follower growth by country based on the first few days of January 2023?",
        "Which post theme has the highest average engagement rate? Provide an insight.",
        "Hello there!",
        "Can you tell me the average salary for software engineers? (This should state data is not available)",
        "Summarize the best practices for attracting tech talent and combine it with an analysis of our top performing post themes."
    ]
    for query in queries:
        logging.info(f"\n\n--- Processing Query: {query} ---")
        response = await agent.process_query(user_query=query)
        logging.info(f"--- Agent Response for '{query}': ---\n{response}\n---------------------------\n")
        # Small delay when making actual API calls to avoid rate limits during testing
        if GEMINI_API_KEY:
            await asyncio.sleep(1)

    # Test updating the DataFrames
    new_posts_data = {
        'post_id': [5, 6], 'post_date': pd.to_datetime(['2023-01-04', '2023-01-05']),
        'theme': ['Innovation', 'Team'], 'impressions': [2500, 1800], 'engagements': [150, 100]
    }
    df_new_posts = pd.DataFrame(new_posts_data)
    df_new_posts['engagement_rate'] = df_new_posts['engagements'] / df_new_posts['impressions']

    updated_dataframes = {
        "follower_stats": df_follower_stats,  # unchanged
        "posts": pd.concat([df_posts, df_new_posts]),  # updated
        "company_values": pd.DataFrame({'value': ['Innovation', 'Collaboration'], 'description': ['...', '...']})  # new DataFrame
    }
    agent.update_dataframes(updated_dataframes)

    logging.info("\n--- Processing Query after DataFrame Update ---")
    response_after_update = await agent.process_query("What's the latest top performing post theme now?")
    logging.info(f"--- Agent Response for 'What's the latest top performing post theme now?': ---\n{response_after_update}\n---------------------------\n")

if __name__ == "__main__":
    # This allows running the test if the script is executed directly.
    # For real API calls, ensure GEMINI_API_KEY is set in your environment, e.g.
    #   export GEMINI_API_KEY="your_api_key_here"
    # In a Jupyter environment with its own event loop, call `await main_test()` instead.
    if GEMINI_API_KEY:  # Only run the full async test if an API key is likely present
        try:
            asyncio.run(main_test())
        except RuntimeError as e:
            if "asyncio.run() cannot be called from a running event loop" in str(e):
                print("Skipping asyncio.run(main_test()) as it seems to be in an existing event loop (e.g., Jupyter). Call 'await main_test()' instead if appropriate.")
            else:
                raise
    else:
        print("GEMINI_API_KEY not set. Skipping main_test(), which might make real API calls. The module can still be imported and used elsewhere.")