# eb_agent_module.py
import asyncio
import json
import logging
import os
import sys
import textwrap
from io import StringIO

import numpy as np
import pandas as pd
# Attempt to import Google Generative AI and related types
try:
    from google import genai
    from google.genai import types as genai_types
    GENAI_AVAILABLE = True
except ImportError:
    GENAI_AVAILABLE = False
    print("Google Generative AI library not found. Please install it: pip install google-generativeai")

    # Define dummy stand-ins if the import fails, so the rest of the module can still be parsed and exercised.
    class genai:  # type: ignore
        @staticmethod
        def configure(api_key):
            pass

        # The dummy Client returns a client object whose 'models' attribute exposes a
        # 'generate_content' method that simulates a minimal valid-looking response.
        @staticmethod
        def Client(api_key=None):
            class DummyPart:
                def __init__(self, text):
                    self.text = text

            class DummyContent:
                def __init__(self):
                    self.parts = [DummyPart("# Dummy response from dummy client")]

            class DummyCandidate:
                def __init__(self):
                    self.content = DummyContent()
                    self.finish_reason = "DUMMY"
                    self.safety_ratings = []

            class DummyResponse:
                def __init__(self):
                    self.candidates = [DummyCandidate()]
                    self.prompt_feedback = None

                @property
                def text(self):  # property (not a method) for compatibility with the real SDK's response.text
                    if self.candidates and self.candidates[0].content and self.candidates[0].content.parts:
                        return "".join(p.text for p in self.candidates[0].content.parts)
                    return ""

            class DummyModels:
                def generate_content(self, model=None, contents=None, generation_config=None, safety_settings=None):
                    print(f"Dummy genai.Client.models.generate_content called for model: {model}")
                    return DummyResponse()

            class DummyClient:
                def __init__(self):
                    self.models = DummyModels()

            if api_key:  # Only return a DummyClient if an API key is provided, mimicking the real client
                return DummyClient()
            return None  # Without an API key, client init may fail or return None

        @staticmethod
        def GenerativeModel(model_name):  # Keep dummy GenerativeModel for any other callers
            print(f"Dummy genai.GenerativeModel called for model: {model_name}")
            return None

        @staticmethod
        def embed_content(model, content, task_type, title=None):
            print(f"Dummy genai.embed_content called for model: {model}")
            return {"embedding": [0.1] * 768}

    class genai_types:  # type: ignore
        @staticmethod
        def GenerateContentConfig(**kwargs):
            return kwargs  # For the dummy, simply return the kwargs dict

        class BlockReason:
            SAFETY = "SAFETY"

        class HarmCategory:
            HARM_CATEGORY_UNSPECIFIED = "HARM_CATEGORY_UNSPECIFIED"
            HARM_CATEGORY_HARASSMENT = "HARM_CATEGORY_HARASSMENT"
            HARM_CATEGORY_HATE_SPEECH = "HARM_CATEGORY_HATE_SPEECH"
            HARM_CATEGORY_SEXUALLY_EXPLICIT = "HARM_CATEGORY_SEXUALLY_EXPLICIT"
            HARM_CATEGORY_DANGEROUS_CONTENT = "HARM_CATEGORY_DANGEROUS_CONTENT"

        class HarmBlockThreshold:
            BLOCK_NONE = "BLOCK_NONE"
# --- Configuration ---
GEMINI_API_KEY = os.getenv('GEMINI_API_KEY', "")
LLM_MODEL_NAME = "gemini-2.0-flash"  # Updated model name
GEMINI_EMBEDDING_MODEL_NAME = "gemini-embedding-exp-03-07"  # Updated embedding model name

# Generation configuration for the LLM
GENERATION_CONFIG_PARAMS = {
    "temperature": 0.2,
    "top_p": 1.0,
    "top_k": 32,
    "max_output_tokens": 4096,
}
# Safety settings for Gemini
# Ensure genai_types is the real one, or that the dummy has these attributes
try:
    DEFAULT_SAFETY_SETTINGS = {
        genai_types.HarmCategory.HARM_CATEGORY_HARASSMENT: genai_types.HarmBlockThreshold.BLOCK_NONE,
        genai_types.HarmCategory.HARM_CATEGORY_HATE_SPEECH: genai_types.HarmBlockThreshold.BLOCK_NONE,
        genai_types.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: genai_types.HarmBlockThreshold.BLOCK_NONE,
        genai_types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: genai_types.HarmBlockThreshold.BLOCK_NONE,
    }
except AttributeError:  # genai_types is the dummy and lacks these attributes; use a placeholder
    logging.warning("Could not define DEFAULT_SAFETY_SETTINGS using genai_types. Using placeholder.")
    DEFAULT_SAFETY_SETTINGS = {}
# Logging setup
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(module)s - %(message)s')

# Configure the Gemini API key globally if available
if GEMINI_API_KEY:
    try:
        genai.configure(api_key=GEMINI_API_KEY)
        logging.info(f"Gemini API key configured globally. Target model for generation: '{LLM_MODEL_NAME}', Embedding model: '{GEMINI_EMBEDDING_MODEL_NAME}'")
    except Exception as e:
        logging.error(f"Failed to configure Gemini API globally: {e}", exc_info=True)
else:
    logging.warning("GEMINI_API_KEY environment variable not set. LLM and embedding functionality will be limited.")
# --- RAG Documents Definition ---
rag_documents_data = {
    'Title': [
        "Employer Branding Best Practices 2024", "Attracting Tech Talent",
        "Understanding Company Culture", "Diversity and Inclusion in Hiring"
    ],
    'Text': [
        "Focus on authentic employee stories...", "Tech candidates value challenging projects...",
        "Company culture is defined by shared values...", "Promote diversity and inclusion by using inclusive language..."
    ]
}
df_rag_documents = pd.DataFrame(rag_documents_data)
# --- Schema Representation ---
def get_schema_representation(df_name: str, df: pd.DataFrame) -> str:
    """Build a textual schema description of one DataFrame for the LLM prompt."""
    if df.empty:
        return f"Schema for DataFrame 'df_{df_name}':\n - DataFrame is empty.\n"
    cols = df.columns.tolist()
    dtypes = df.dtypes.to_dict()
    schema_str = f"Schema for DataFrame 'df_{df_name}':\n"
    for col in cols:
        schema_str += f" - Column '{col}': {dtypes[col]}\n"
    for col in cols:
        if 'date' in col.lower() or 'time' in col.lower():
            schema_str += f" - Note: Column '{col}' seems to be date/time related...\n"
        if df[col].apply(type).eq(list).any() or df[col].apply(type).eq(dict).any():
            schema_str += f" - Note: Column '{col}' may contain list-like or dict-like data...\n"
        if df[col].dtype == 'object' and df[col].nunique() < 20 and df.shape[0] > 20:
            schema_str += f" - Note: Column '{col}' might be categorical...\n"
    schema_str += f"Sample of first 2 rows of 'df_{df_name}':\n{df.head(2).to_string()}\n"
    return schema_str
def get_all_schemas_representation(dataframes_dict: dict) -> str:
    """Concatenate the schema descriptions of all available DataFrames."""
    full_schema_str = "You have access to the following Pandas DataFrames...\n\n"
    for name, df_instance in dataframes_dict.items():
        full_schema_str += get_schema_representation(name, df_instance) + "\n"
    return full_schema_str
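
# Illustrative sketch of the schema helpers above (commented out so import stays side-effect
# free). The 'df_demo' DataFrame and the "demo" key are hypothetical, not part of this module:
#
#   df_demo = pd.DataFrame({"hire_date": ["2024-01-02", "2024-02-10"], "channel": ["referral", "ad"]})
#   print(get_all_schemas_representation({"demo": df_demo}))
#
# This would print the column/dtype listing for 'df_demo', the date/time note triggered by the
# 'hire_date' column name, and the two-row sample that gets embedded into the LLM prompt.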
# --- Advanced RAG System ---
class AdvancedRAGSystem:
    def __init__(self, documents_df: pd.DataFrame, embedding_model_name: str):
        self.embedding_model_name = embedding_model_name  # Store the model name
        self.documents_df = documents_df.copy()
        if 'Embeddings' not in self.documents_df.columns:
            self.documents_df['Embeddings'] = pd.Series(dtype='object')
        self.embeddings_generated = False
        if not GEMINI_API_KEY:
            logging.warning("RAG System: GEMINI_API_KEY not set. Embeddings will not be generated.")
            return
        try:
            # Only precompute when the real library imported and exposes embed_content
            # (the previous __qualname__-based dummy check was unreliable).
            if GENAI_AVAILABLE and hasattr(genai, 'embed_content'):
                self._precompute_embeddings()
                self.embeddings_generated = True
                logging.info("AdvancedRAGSystem initialized and embeddings precomputed.")
            else:
                logging.warning("AdvancedRAGSystem: Real genai.embed_content not available. Skipping embedding precomputation.")
        except Exception as e:
            logging.error(f"Error during RAG embedding precomputation: {e}", exc_info=True)
    def _embed_fn(self, title: str, text: str) -> list[float]:
        try:
            # Bail out when the real embed_content or the API key is unavailable.
            # (Note: this must not depend on self.embeddings_generated, which is only
            # set after precomputation completes.)
            if not (GENAI_AVAILABLE and GEMINI_API_KEY and hasattr(genai, 'embed_content')):
                logging.warning(f"genai.embed_content not available or no API key. Returning zero vector for title: {title}")
                return [0.0] * 768  # Default embedding size
            embedding_result = genai.embed_content(
                model=self.embedding_model_name,  # Use the stored model name
                content=text,
                task_type="retrieval_document",
                title=title
            )
            return embedding_result["embedding"]
        except Exception as e:
            logging.error(f"Error embedding content '{title}': {e}", exc_info=True)
            return [0.0] * 768
    def _precompute_embeddings(self):
        if 'Embeddings' not in self.documents_df.columns:
            self.documents_df['Embeddings'] = pd.Series(dtype='object')
        for index, row in self.documents_df.iterrows():
            current_embedding = row['Embeddings']
            # Treat missing or effectively-zero vectors as "not yet embedded".
            is_valid_embedding = isinstance(current_embedding, list) and len(current_embedding) > 0 and sum(abs(x) for x in current_embedding) > 1e-6
            if not is_valid_embedding:
                self.documents_df.at[index, 'Embeddings'] = self._embed_fn(row['Title'], row['Text'])
        logging.info("Embeddings precomputation finished (or skipped if dummy).")
    def retrieve_relevant_info(self, query_text: str, top_k: int = 2) -> str:
        # Retrieval needs precomputed embeddings and the real genai functions.
        if (not self.embeddings_generated or not GENAI_AVAILABLE
                or 'Embeddings' not in self.documents_df.columns
                or self.documents_df['Embeddings'].isnull().all()):
            logging.warning("RAG System: Cannot retrieve info. Conditions not met (API key, embeddings, or real genai functions).")
            return "\n[RAG Context]\nNo specific pre-defined context found (RAG system inactive or no embeddings).\n"
        try:
            query_embedding_result = genai.embed_content(
                model=self.embedding_model_name,  # Use the stored model name
                content=query_text,
                task_type="retrieval_query"
            )
            query_embedding = np.array(query_embedding_result["embedding"])
            valid_embeddings_df = self.documents_df.dropna(subset=['Embeddings'])
            valid_embeddings_df = valid_embeddings_df[valid_embeddings_df['Embeddings'].apply(
                lambda x: isinstance(x, list) and len(x) > 0 and sum(abs(val) for val in x) > 1e-6)]
            if valid_embeddings_df.empty:
                return "\n[RAG Context]\nNo valid document embeddings available for retrieval.\n"
            document_embeddings = np.stack(valid_embeddings_df['Embeddings'].apply(np.array).values)
            if query_embedding.shape[0] != document_embeddings.shape[1]:
                return "\n[RAG Context]\nEmbedding dimension mismatch.\n"
            # Rank documents by dot-product similarity against the query embedding.
            dot_products = np.dot(document_embeddings, query_embedding)
            actual_top_k = min(top_k, len(valid_embeddings_df))
            if actual_top_k == 0:
                return "\n[RAG Context]\nNo documents to retrieve from.\n"
            # argsort handles the top_k == 1 case too, so no separate argmax branch is needed.
            idx = np.argsort(dot_products)[-actual_top_k:][::-1]
            relevant_passages = ""
            for i_val in idx:
                passage_title = valid_embeddings_df.iloc[i_val]['Title']
                passage_text = valid_embeddings_df.iloc[i_val]['Text']
                relevant_passages += f"\n[RAG Context from: '{passage_title}']\n{passage_text}\n"
            return relevant_passages if relevant_passages else "\n[RAG Context]\nNo highly relevant passages found.\n"
        except Exception as e:
            logging.error(f"Error retrieving relevant info from RAG: {e}", exc_info=True)
            return f"\n[RAG Context]\nError during RAG retrieval: {str(e)}\n"
# --- PandasLLM Class (Gemini-Powered) ---
class PandasLLM:
    def __init__(self, llm_model_name: str, generation_config_params: dict,
                 safety_settings: dict,  # may or may not be accepted by client.models.generate_content
                 data_privacy=True, force_sandbox=True):
        self.llm_model_name = llm_model_name
        self.generation_config_params = generation_config_params
        self.safety_settings = safety_settings  # Store it; it might be usable
        self.data_privacy = data_privacy
        self.force_sandbox = force_sandbox
        self.client = None
        self.generative_model_service = None  # will hold client.models (or the client itself)
        if not GEMINI_API_KEY:
            logging.warning("PandasLLM: GEMINI_API_KEY not set. LLM functionality will be limited.")
        else:
            try:
                # Global genai.configure should already have been called; passing the key
                # explicitly to genai.Client keeps this robust either way.
                self.client = genai.Client(api_key=GEMINI_API_KEY)
                if self.client and hasattr(self.client, 'models') and hasattr(self.client.models, 'generate_content'):
                    self.generative_model_service = self.client.models
                    logging.info(f"PandasLLM initialized with genai.Client. Using client.models for '{self.llm_model_name}'.")
                elif self.client and hasattr(self.client, 'generate_content'):  # fallback: client itself exposes generate_content
                    self.generative_model_service = self.client
                    logging.info(f"PandasLLM initialized with genai.Client. Using client.generate_content for '{self.llm_model_name}'.")
                else:
                    logging.warning("PandasLLM: genai.Client initialized, but no suitable 'generate_content' method found on client or client.models. LLM calls may fail.")
            except AttributeError as ae:  # genai.Client itself missing (e.g. dummy module or a very old library)
                logging.error(f"Failed to initialize genai.Client: {ae}. The 'genai' module might be a dummy or the library is missing/old.", exc_info=True)
            except Exception as e:
                logging.error(f"Failed to initialize PandasLLM with genai.Client: {e}", exc_info=True)
    async def _call_gemini_api_async(self, prompt_text: str, history: list = None) -> str:
        if not self.generative_model_service:
            logging.error("PandasLLM: Generative model service (e.g., client.models or client) not initialized. Cannot call API.")
            return "# Error: Gemini client or service not available. Check API key and library installation."

        # Convert chat history into the API's contents format; the API expects 'model', not 'assistant'.
        contents_for_api = []
        if history:
            for entry in history:
                role = entry.get("role", "user")
                if role == "assistant":
                    role = "model"
                contents_for_api.append({"role": role, "parts": [{"text": entry.get("content", "")}]})
        contents_for_api.append({"role": "user", "parts": [{"text": prompt_text}]})

        generation_config_to_pass = self.generation_config_params
        # Depending on the client API, safety_settings may be a direct parameter or part of
        # the generation config; it is passed as a direct parameter here.
        safety_settings_to_pass = self.safety_settings

        # The API usually expects a fully qualified model id, e.g. 'models/gemini-2.0-flash'.
        # Compute it before the try block so the except handlers below can reference it safely.
        model_id_for_api = self.llm_model_name
        if not model_id_for_api.startswith("models/"):
            model_id_for_api = f"models/{model_id_for_api}"

        logging.info(f"\n--- Calling Gemini API via Client with prompt (first 500 chars of last message): ---\n{contents_for_api[-1]['parts'][0]['text'][:500]}...\n-------------------------------------------------------\n")
        try:
            # generate_content is synchronous, so run it in a worker thread.
            # The service could be client.models or the client itself.
            response = await asyncio.to_thread(
                self.generative_model_service.generate_content,
                model=model_id_for_api,
                contents=contents_for_api,
                generation_config=generation_config_to_pass,
                safety_settings=safety_settings_to_pass
            )

            if hasattr(response, 'prompt_feedback') and response.prompt_feedback and response.prompt_feedback.block_reason:
                reason = response.prompt_feedback.block_reason
                reason_name = getattr(reason, 'name', str(reason))
                logging.warning(f"Gemini API call blocked by prompt feedback: {reason_name}")
                return f"# Error: Prompt blocked due to content policy: {reason_name}."

            llm_output = ""
            if hasattr(response, 'text') and response.text:  # common for newer SDK responses
                llm_output = response.text
            elif hasattr(response, 'candidates') and response.candidates:
                candidate = response.candidates[0]
                if hasattr(candidate, 'content') and candidate.content and hasattr(candidate.content, 'parts') and candidate.content.parts:
                    llm_output = "".join(part.text for part in candidate.content.parts if hasattr(part, 'text'))
                if not llm_output and hasattr(candidate, 'finish_reason'):
                    finish_reason_val = candidate.finish_reason
                    finish_reason = getattr(finish_reason_val, 'name', str(finish_reason_val))
                    logging.warning(f"No text content in response candidate. Finish reason: {finish_reason}")
                    if finish_reason == "SAFETY":
                        return f"# Error: Response generation stopped due to safety reasons ({finish_reason})."
                    elif finish_reason == "RECITATION":
                        return f"# Error: Response generation stopped due to recitation policy ({finish_reason})."
                    return f"# Error: The AI model returned an empty response. Finish reason: {finish_reason}."
            else:
                logging.warning(f"Gemini API response structure not recognized or empty. Response: {response}")
                return "# Error: The AI model returned an unexpected or empty response structure."

            logging.info(f"--- Gemini API Response (first 300 chars): ---\n{llm_output[:300]}...\n--------------------------------------------------\n")
            return llm_output
        except AttributeError as ae:
            logging.error(f"AttributeError during Gemini client call: {ae}. The client object or its 'models' attribute may lack 'generate_content' or be None.", exc_info=True)
            return f"# Error (Attribute): {type(ae).__name__} - {ae}. Check client structure."
        except Exception as e:
            logging.error(f"Error calling Gemini API via Client: {e}", exc_info=True)
            if "API_KEY_INVALID" in str(e) or "API key not valid" in str(e):
                return "# Error: Gemini API key is not valid."
            if "PermissionDenied" in str(e) or "403" in str(e):
                return f"# Error: Permission denied for model '{model_id_for_api}' or service."
            # Check specifically for a model-not-found error
            if ("not found" in str(e).lower() or "does not exist" in str(e).lower()) and model_id_for_api in str(e):
                return f"# Error: Model '{model_id_for_api}' not found or not accessible with your API key via client."
            return f"# Error: An unexpected error occurred while contacting the AI model via Client: {type(e).__name__}."
    async def query(self, prompt_with_query_and_context: str, dataframes_dict: dict, history: list = None) -> str:
        llm_response_text = await self._call_gemini_api_async(prompt_with_query_and_context, history)
        if not self.force_sandbox:
            return llm_response_text

        # Extract the Python code block from the LLM response.
        code_to_execute = ""
        if "```python" in llm_response_text:
            try:
                code_to_execute = llm_response_text.split("```python\n", 1)[1].split("\n```", 1)[0]
            except IndexError:
                try:
                    code_to_execute = llm_response_text.split("```python", 1)[1].split("```", 1)[0]
                    code_to_execute = code_to_execute.strip("\n")
                except IndexError:
                    code_to_execute = ""
                    logging.warning("Could not extract Python code using primary or secondary split method.")

        llm_response_text_for_sandbox_error = ""
        if llm_response_text.startswith("# Error:") or not code_to_execute:
            error_prefix = "LLM did not return valid Python code or an error occurred."
            if llm_response_text.startswith("# Error:"):
                error_prefix = "An error occurred during LLM call."
            elif not code_to_execute:
                error_prefix = "Could not extract Python code from LLM response."
            safe_llm_response = str(llm_response_text).replace("'''", "'").replace('"""', '"')
            llm_response_text_for_sandbox_error = f"print(f'''{error_prefix}\\nRaw LLM Response (may be truncated):\\n{safe_llm_response[:1000]}''')"
            logging.warning(f"Problem with LLM response for sandbox: {error_prefix}")

        logging.info(f"\n--- Code to Execute (from LLM, if sandbox): ---\n{code_to_execute}\n------------------------------------------------\n")

        # Build a restricted builtins table for the sandbox.
        # (__builtins__ is a dict when this module is imported, a module when run directly.)
        if isinstance(__builtins__, dict):
            safe_builtins = {name: obj for name, obj in __builtins__.items() if not name.startswith('_')}
        else:
            safe_builtins = {name: obj for name, obj in __builtins__.__dict__.items() if not name.startswith('_')}
        unsafe_builtins = ['eval', 'exec', 'open', 'compile', 'input', 'memoryview', 'vars', 'globals', 'locals', '__import__']
        for ub in unsafe_builtins:
            safe_builtins.pop(ub, None)

        exec_globals = {'pd': pd, 'np': np, '__builtins__': safe_builtins}
        for name, df_instance in dataframes_dict.items():
            exec_globals[f"df_{name}"] = df_instance

        # Capture stdout so the sandboxed code's print() output becomes the answer.
        old_stdout = sys.stdout
        sys.stdout = captured_output = StringIO()
        final_output_str = ""
        try:
            if code_to_execute:
                # Execute with a single globals dict (no separate locals) so functions
                # defined by the generated code can see its top-level assignments.
                exec(code_to_execute, exec_globals)
                output_val = captured_output.getvalue()
                final_output_str = output_val if output_val else "# Code executed successfully, but no explicit print() output was generated by the code."
            else:
                exec(llm_response_text_for_sandbox_error, exec_globals)
                final_output_str = captured_output.getvalue()
        except Exception as e:
            error_msg = f"# Error executing LLM-generated code:\n# {type(e).__name__}: {str(e)}\n# --- Code that caused error: ---\n{textwrap.indent(code_to_execute, '# ')}"
            final_output_str = error_msg
            logging.error(error_msg, exc_info=False)
        finally:
            sys.stdout = old_stdout
        return final_output_str
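
# Illustrative sketch of direct PandasLLM usage (commented out). Normally the class is driven
# by EmployerBrandingAgent below; the prompt string and the "demo" DataFrame are hypothetical:
#
#   llm = PandasLLM(LLM_MODEL_NAME, GENERATION_CONFIG_PARAMS, DEFAULT_SAFETY_SETTINGS)
#   result = asyncio.run(llm.query("Generate Python code that prints df_demo.shape",
#                                  {"demo": pd.DataFrame({"a": [1, 2]})}))
#   print(result)
#
# With force_sandbox=True (the default), the returned string is the captured print() output of
# the generated code, not the raw LLM text.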
# --- Employer Branding Agent ---
class EmployerBrandingAgent:
    def __init__(self, llm_model_name: str, generation_config_params: dict, safety_settings: dict,
                 all_dataframes: dict, rag_documents_df: pd.DataFrame, embedding_model_name: str,
                 data_privacy=True, force_sandbox=True):
        self.pandas_llm = PandasLLM(llm_model_name, generation_config_params, safety_settings, data_privacy, force_sandbox)
        self.rag_system = AdvancedRAGSystem(rag_documents_df, embedding_model_name)
        self.all_dataframes = all_dataframes
        self.schemas_representation = get_all_schemas_representation(self.all_dataframes)
        self.chat_history = []
        logging.info("EmployerBrandingAgent initialized.")

    def _build_prompt(self, user_query: str, role="Employer Branding Analyst", task_decomposition_hint=None, cot_hint=True) -> str:
        prompt = f"You are a helpful and expert '{role}'...\n"  # truncated for brevity
        # ... (rest of the prompt-building logic remains the same)
        prompt += "Your main task is to GENERATE PYTHON CODE using the Pandas library...\n"
        prompt += "\n--- AVAILABLE DATA AND SCHEMAS ---\n"
        prompt += self.schemas_representation
        rag_context = self.rag_system.retrieve_relevant_info(user_query)
        if rag_context and "[RAG Context]" in rag_context and "No specific pre-defined context found" not in rag_context and "No highly relevant passages found" not in rag_context:
            prompt += f"\n--- ADDITIONAL CONTEXT (from internal knowledge base, consider this information) ---\n{rag_context}\n"
        prompt += f"\n--- USER QUERY ---\n{user_query}\n"
        if self.pandas_llm.force_sandbox:
            prompt += "\n--- INSTRUCTIONS FOR PYTHON CODE GENERATION (Chain of Thought) ---\n"
            prompt += "1. Understand the query...\n"
            prompt += "7. Generate ONLY the Python code block starting with ```python and ending with ```...\n"
        return prompt

    async def process_query(self, user_query: str, role="Employer Branding Analyst", task_decomposition_hint=None, cot_hint=True) -> str:
        logging.info(f"\n=== Processing Query for Role: {role}, Query: {user_query} ===")
        self.chat_history.append({"role": "user", "content": user_query})
        full_prompt = self._build_prompt(user_query, role, task_decomposition_hint, cot_hint)
        # Pass history minus the just-appended user turn, which is already inside the prompt.
        response_text = await self.pandas_llm.query(full_prompt, self.all_dataframes, history=self.chat_history[:-1])
        self.chat_history.append({"role": "assistant", "content": response_text})
        # Keep only the most recent turns to bound prompt size.
        MAX_HISTORY_TURNS = 5
        if len(self.chat_history) > MAX_HISTORY_TURNS * 2:
            self.chat_history = self.chat_history[-(MAX_HISTORY_TURNS * 2):]
        return response_text

    def update_dataframes(self, new_dataframes: dict):
        self.all_dataframes = new_dataframes
        self.schemas_representation = get_all_schemas_representation(self.all_dataframes)
        logging.info("EmployerBrandingAgent DataFrames updated.")

    def clear_chat_history(self):
        self.chat_history = []
        logging.info("EmployerBrandingAgent chat history cleared.")