Spaces:
Running
Running
Update eb_agent_module.py
Browse files- eb_agent_module.py +215 -558
eb_agent_module.py
CHANGED
@@ -9,10 +9,8 @@ import textwrap
|
|
9 |
|
10 |
# Attempt to import Google Generative AI and related types
|
11 |
try:
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
from google.genai import types as genai_types
|
16 |
# from google.generativeai import GenerationConfig # For direct use if needed
|
17 |
# from google.generativeai.types import HarmCategory, HarmBlockThreshold, SafetySetting # For direct use
|
18 |
|
@@ -26,14 +24,13 @@ except ImportError:
|
|
26 |
|
27 |
# Dummy Client and related structures
|
28 |
class Client:
|
29 |
-
def __init__(self, api_key=None):
|
30 |
self.api_key = api_key
|
31 |
-
self.models = self._Models()
|
32 |
-
print(f"Dummy genai.Client initialized {'with' if api_key else '
|
33 |
|
34 |
-
class _Models:
|
35 |
-
|
36 |
-
async def generate_content_async(model=None, contents=None, generation_config=None, safety_settings=None, stream=False): # Matched real signature better
|
37 |
print(f"Dummy genai.Client.models.generate_content_async called for model: {model} with config: {generation_config}, safety_settings: {safety_settings}, stream: {stream}")
|
38 |
class DummyPart:
|
39 |
def __init__(self, text): self.text = text
|
@@ -44,20 +41,20 @@ except ImportError:
|
|
44 |
self.content = DummyContent()
|
45 |
self.finish_reason = genai_types.FinishReason.STOP # Use dummy FinishReason
|
46 |
self.safety_ratings = []
|
47 |
-
self.token_count = 0
|
48 |
-
self.index = 0
|
49 |
class DummyResponse:
|
50 |
def __init__(self):
|
51 |
self.candidates = [DummyCandidate()]
|
52 |
-
self.prompt_feedback = self._PromptFeedback()
|
53 |
-
self.text = "# Dummy response text from dummy client's async generate_content"
|
54 |
-
class _PromptFeedback:
|
55 |
def __init__(self):
|
56 |
self.block_reason = None
|
57 |
self.safety_ratings = []
|
58 |
return DummyResponse()
|
59 |
|
60 |
-
def generate_content(self, model=None, contents=None, generation_config=None, safety_settings=None, stream=False): # Matched real signature better
|
61 |
print(f"Dummy genai.Client.models.generate_content called for model: {model} with config: {generation_config}, safety_settings: {safety_settings}, stream: {stream}")
|
62 |
# Re-using the async dummy structure for simplicity
|
63 |
class DummyPart:
|
@@ -67,14 +64,14 @@ except ImportError:
|
|
67 |
class DummyCandidate:
|
68 |
def __init__(self):
|
69 |
self.content = DummyContent()
|
70 |
-
self.finish_reason = genai_types.FinishReason.STOP
|
71 |
self.safety_ratings = []
|
72 |
self.token_count = 0
|
73 |
self.index = 0
|
74 |
class DummyResponse:
|
75 |
def __init__(self):
|
76 |
self.candidates = [DummyCandidate()]
|
77 |
-
self.prompt_feedback = self._PromptFeedback()
|
78 |
self.text = "# Dummy response text from dummy client's generate_content"
|
79 |
class _PromptFeedback:
|
80 |
def __init__(self):
|
@@ -83,134 +80,78 @@ except ImportError:
|
|
83 |
return DummyResponse()
|
84 |
|
85 |
@staticmethod
|
86 |
-
def GenerativeModel(model_name, generation_config=None, safety_settings=None, system_instruction=None): #
|
87 |
-
print(f"Dummy genai.GenerativeModel called for model: {model_name}
|
|
|
88 |
class DummyGenerativeModel:
|
89 |
def __init__(self, model_name_in, generation_config_in, safety_settings_in, system_instruction_in):
|
90 |
self.model_name = model_name_in
|
91 |
-
|
92 |
-
self.safety_settings = safety_settings_in
|
93 |
-
self.system_instruction = system_instruction_in
|
94 |
-
async def generate_content_async(self, contents, stream=False): # Matched real signature
|
95 |
-
print(f"Dummy GenerativeModel.generate_content_async called for {self.model_name}")
|
96 |
-
# Simplified response, similar to Client's dummy
|
97 |
-
class DummyPart:
|
98 |
-
def __init__(self, text): self.text = text
|
99 |
-
class DummyContent:
|
100 |
-
def __init__(self): self.parts = [DummyPart(f"# Dummy response from dummy GenerativeModel ({self.model_name})")]
|
101 |
-
class DummyCandidate:
|
102 |
-
def __init__(self):
|
103 |
-
self.content = DummyContent()
|
104 |
-
self.finish_reason = genai_types.FinishReason.STOP
|
105 |
-
self.safety_ratings = []
|
106 |
-
class DummyResponse:
|
107 |
-
def __init__(self):
|
108 |
-
self.candidates = [DummyCandidate()]
|
109 |
-
self.prompt_feedback = None
|
110 |
-
self.text = f"# Dummy response text from dummy GenerativeModel ({self.model_name})"
|
111 |
-
return DummyResponse()
|
112 |
-
|
113 |
-
def generate_content(self, contents, stream=False): # Matched real signature
|
114 |
-
print(f"Dummy GenerativeModel.generate_content called for {self.model_name}")
|
115 |
-
# Simplified response, similar to Client's dummy
|
116 |
class DummyPart:
|
117 |
def __init__(self, text): self.text = text
|
118 |
class DummyContent:
|
119 |
def __init__(self): self.parts = [DummyPart(f"# Dummy response from dummy GenerativeModel ({self.model_name})")]
|
120 |
class DummyCandidate:
|
121 |
def __init__(self):
|
122 |
-
self.content = DummyContent()
|
123 |
-
self.finish_reason = genai_types.FinishReason.STOP
|
124 |
-
self.safety_ratings = []
|
125 |
class DummyResponse:
|
126 |
def __init__(self):
|
127 |
-
self.candidates = [DummyCandidate()]
|
128 |
-
self.prompt_feedback = None
|
129 |
-
self.text = f"# Dummy response text from dummy GenerativeModel ({self.model_name})"
|
130 |
return DummyResponse()
|
131 |
-
|
132 |
return DummyGenerativeModel(model_name, generation_config, safety_settings, system_instruction)
|
133 |
|
|
|
134 |
@staticmethod
|
135 |
def embed_content(model, content, task_type, title=None):
|
136 |
print(f"Dummy genai.embed_content called for model: {model}, task_type: {task_type}, title: {title}")
|
137 |
-
# Ensure the dummy embedding matches typical dimensions (e.g., 768 for many models)
|
138 |
return {"embedding": [0.1] * 768}
|
139 |
|
140 |
class genai_types: # type: ignore
|
141 |
-
# Using dicts for dummy GenerationConfig and SafetySetting for simplicity
|
142 |
@staticmethod
|
143 |
-
def GenerationConfig(**kwargs):
|
144 |
print(f"Dummy genai_types.GenerationConfig created with: {kwargs}")
|
145 |
return dict(kwargs)
|
146 |
|
147 |
@staticmethod
|
148 |
def SafetySetting(category, threshold):
|
149 |
print(f"Dummy SafetySetting created: category={category}, threshold={threshold}")
|
150 |
-
return {"category": category, "threshold": threshold}
|
151 |
|
152 |
-
# Dummy Enums (can be simple string attributes)
|
153 |
class HarmCategory:
|
154 |
-
HARM_CATEGORY_UNSPECIFIED = "HARM_CATEGORY_UNSPECIFIED"
|
155 |
-
HARM_CATEGORY_HARASSMENT = "HARM_CATEGORY_HARASSMENT"
|
156 |
-
HARM_CATEGORY_HATE_SPEECH = "HARM_CATEGORY_HATE_SPEECH"
|
157 |
-
HARM_CATEGORY_SEXUALLY_EXPLICIT = "HARM_CATEGORY_SEXUALLY_EXPLICIT"
|
158 |
-
HARM_CATEGORY_DANGEROUS_CONTENT = "HARM_CATEGORY_DANGEROUS_CONTENT"
|
159 |
-
|
160 |
class HarmBlockThreshold:
|
161 |
-
BLOCK_NONE = "BLOCK_NONE"
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
class
|
167 |
-
|
168 |
-
STOP = "STOP"
|
169 |
-
MAX_TOKENS = "MAX_TOKENS"
|
170 |
SAFETY = "SAFETY"
|
171 |
-
RECITATION = "RECITATION"
|
172 |
OTHER = "OTHER"
|
173 |
|
174 |
-
# Placeholder for other types if needed by the script
|
175 |
-
# class BlockReason:
|
176 |
-
# SAFETY = "SAFETY"
|
177 |
-
|
178 |
-
|
179 |
# --- Configuration ---
|
180 |
GEMINI_API_KEY = os.getenv('GEMINI_API_KEY', "")
|
181 |
-
#
|
182 |
# LLM_MODEL_NAME = "gemini-2.0-flash" # Original
|
183 |
LLM_MODEL_NAME = "gemini-2.0-flash"
|
184 |
GEMINI_EMBEDDING_MODEL_NAME = "gemini-embedding-exp-03-07"
|
185 |
|
186 |
# Base generation configuration for the LLM
|
187 |
GENERATION_CONFIG_PARAMS = {
|
188 |
-
"temperature": 0.3,
|
189 |
"top_p": 1.0,
|
190 |
"top_k": 32,
|
191 |
-
"max_output_tokens": 8192,
|
192 |
-
# "candidate_count": 1, # Default is 1, explicitly setting it
|
193 |
}
|
194 |
|
195 |
# Default safety settings list for Gemini
|
196 |
try:
|
197 |
DEFAULT_SAFETY_SETTINGS = [
|
198 |
-
genai_types.SafetySetting(
|
199 |
-
|
200 |
-
|
201 |
-
),
|
202 |
-
genai_types.SafetySetting(
|
203 |
-
category=genai_types.HarmCategory.HARM_CATEGORY_HARASSMENT,
|
204 |
-
threshold=genai_types.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, # Adjusted slightly
|
205 |
-
),
|
206 |
-
genai_types.SafetySetting(
|
207 |
-
category=genai_types.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
|
208 |
-
threshold=genai_types.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, # Adjusted slightly
|
209 |
-
),
|
210 |
-
genai_types.SafetySetting(
|
211 |
-
category=genai_types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
|
212 |
-
threshold=genai_types.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, # Adjusted slightly
|
213 |
-
),
|
214 |
]
|
215 |
except AttributeError as e:
|
216 |
logging.warning(f"Could not define DEFAULT_SAFETY_SETTINGS using real genai_types: {e}. Using placeholder list of dicts.")
|
@@ -221,7 +162,6 @@ except AttributeError as e:
|
|
221 |
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
|
222 |
]
|
223 |
|
224 |
-
|
225 |
# Logging setup
|
226 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(module)s - %(filename)s:%(lineno)d - %(message)s')
|
227 |
|
@@ -237,268 +177,166 @@ else:
|
|
237 |
|
238 |
# --- RAG Documents Definition (Example) ---
|
239 |
rag_documents_data = {
|
240 |
-
'Title': [
|
241 |
-
|
242 |
-
"Attracting Tech Talent in Competitive Markets",
|
243 |
-
"The Power of Employee Advocacy",
|
244 |
-
"Understanding Gen Z Workforce Expectations"
|
245 |
-
],
|
246 |
-
'Text': [
|
247 |
-
"Focus on authentic employee stories and showcase company culture. Highlight diversity and inclusion initiatives. Use video content for higher engagement. Clearly articulate your Employee Value Proposition (EVP).",
|
248 |
-
"Tech candidates value challenging projects, continuous learning opportunities, and a flexible work environment. Competitive compensation and modern tech stacks are crucial. Highlight your company's innovation and impact.",
|
249 |
-
"Encourage employees to share their positive experiences on social media. Provide them with shareable content and guidelines. Employee-generated content is often perceived as more trustworthy than corporate messaging.",
|
250 |
-
"Gen Z values purpose-driven work, transparency, mental health support, and opportunities for growth. They are digital natives and expect seamless online application processes. They also care deeply about social responsibility."
|
251 |
-
]
|
252 |
}
|
253 |
df_rag_documents = pd.DataFrame(rag_documents_data)
|
254 |
|
255 |
# --- Schema Representation ---
|
256 |
def get_schema_representation(df_name: str, df: pd.DataFrame) -> str:
|
257 |
-
if not isinstance(df, pd.DataFrame):
|
258 |
-
|
259 |
-
|
260 |
-
|
261 |
-
|
262 |
-
schema_str = f"DataFrame 'df_{df_name}':\n"
|
263 |
-
schema_str += f" Columns: {df.columns.tolist()}\n"
|
264 |
-
schema_str += f" Shape: {df.shape}\n"
|
265 |
-
# Add dtypes for more clarity
|
266 |
-
# schema_str += " Data Types:\n"
|
267 |
-
# for col in df.columns:
|
268 |
-
# schema_str += f" {col}: {df[col].dtype}\n"
|
269 |
-
|
270 |
-
# Sample data (first 2 rows)
|
271 |
-
if not df.empty:
|
272 |
-
sample_str = df.head(2).to_string()
|
273 |
-
# Indent sample string for better readability in the prompt
|
274 |
-
indented_sample = "\n".join([" " + line for line in sample_str.splitlines()])
|
275 |
-
schema_str += f" Sample Data (first 2 rows):\n{indented_sample}\n"
|
276 |
-
else:
|
277 |
-
schema_str += " Sample Data: DataFrame is empty.\n"
|
278 |
return schema_str
|
279 |
|
280 |
def get_all_schemas_representation(dataframes_dict: dict) -> str:
|
281 |
-
if not dataframes_dict:
|
282 |
-
return "No DataFrames provided.\n"
|
283 |
return "".join(get_schema_representation(name, df) for name, df in dataframes_dict.items())
|
284 |
|
285 |
-
|
286 |
# --- Advanced RAG System ---
|
287 |
class AdvancedRAGSystem:
|
288 |
def __init__(self, documents_df: pd.DataFrame, embedding_model_name: str):
|
289 |
self.embedding_model_name = embedding_model_name
|
290 |
self.documents_df = documents_df.copy()
|
291 |
self.embeddings_generated = False
|
292 |
-
|
|
|
293 |
|
294 |
if GEMINI_API_KEY and self.client_available:
|
295 |
try:
|
296 |
self._precompute_embeddings()
|
297 |
self.embeddings_generated = True
|
298 |
logging.info(f"RAG embeddings precomputed using '{self.embedding_model_name}'.")
|
299 |
-
except Exception as e:
|
300 |
-
logging.error(f"RAG precomputation error: {e}", exc_info=True)
|
301 |
else:
|
302 |
-
logging.warning(f"RAG embeddings not precomputed.
|
303 |
|
304 |
def _embed_fn(self, title: str, text: str) -> list[float]:
|
305 |
-
if not self.
|
306 |
-
# logging.debug(f"Skipping embedding for '{title}' as embeddings are not active.")
|
307 |
-
return [0.0] * 768 # Default dimension, adjust if your model differs
|
308 |
try:
|
309 |
-
# logging.debug(f"Embedding '{title}' with model '{self.embedding_model_name}'")
|
310 |
-
# Ensure content is not empty
|
311 |
content_to_embed = text if text else title
|
312 |
-
if not content_to_embed:
|
313 |
-
|
314 |
-
return [0.0] * 768
|
315 |
-
|
316 |
-
embedding_result = genai.embed_content(
|
317 |
-
model=self.embedding_model_name,
|
318 |
-
content=content_to_embed,
|
319 |
-
task_type="retrieval_document",
|
320 |
-
title=title if title else None # Pass title only if it exists
|
321 |
-
)
|
322 |
-
return embedding_result["embedding"]
|
323 |
except Exception as e:
|
324 |
logging.error(f"Error in _embed_fn for '{title}': {e}", exc_info=True)
|
325 |
return [0.0] * 768
|
326 |
|
327 |
def _precompute_embeddings(self):
|
328 |
-
if 'Embeddings' not in self.documents_df.columns:
|
329 |
-
|
330 |
-
|
331 |
-
|
332 |
-
|
333 |
-
|
334 |
-
|
335 |
-
if not
|
336 |
-
logging.warning("No content found in 'Text' or 'Title' columns to generate embeddings.")
|
337 |
-
return
|
338 |
-
|
339 |
-
self.documents_df.loc[mask, 'Embeddings'] = self.documents_df[mask].apply(
|
340 |
-
lambda row: self._embed_fn(row.get('Title', ''), row.get('Text', '')), axis=1
|
341 |
-
)
|
342 |
-
logging.info(f"Applied embedding function to {mask.sum()} rows.")
|
343 |
-
|
344 |
-
|
345 |
-
def retrieve_relevant_info(self, query_text: str, top_k: int = 2) -> str: # Increased top_k for more context
|
346 |
-
if not self.client_available:
|
347 |
-
return "\n[RAG Context]\nEmbedding client not available. Cannot retrieve RAG context.\n"
|
348 |
if not self.embeddings_generated or 'Embeddings' not in self.documents_df.columns or self.documents_df['Embeddings'].isnull().all():
|
349 |
-
return "\n[RAG Context]\nEmbeddings not
|
350 |
-
|
351 |
try:
|
352 |
-
|
353 |
-
|
354 |
-
|
355 |
-
|
356 |
-
|
357 |
-
|
358 |
-
|
359 |
-
|
360 |
-
|
361 |
-
|
362 |
-
if
|
363 |
-
|
364 |
-
|
365 |
-
|
366 |
-
|
367 |
-
if query_embedding.shape[0] != document_embeddings.shape[1]:
|
368 |
-
logging.error(f"Embedding dimension mismatch. Query: {query_embedding.shape[0]}, Docs: {document_embeddings.shape[1]}")
|
369 |
-
return "\n[RAG Context]\nEmbedding dimension mismatch. Cannot calculate similarity.\n"
|
370 |
-
|
371 |
-
dot_products = np.dot(document_embeddings, query_embedding)
|
372 |
-
# Get indices of top_k largest dot products
|
373 |
-
# If fewer valid documents than top_k, take all of them
|
374 |
-
num_to_retrieve = min(top_k, len(valid_embeddings_df))
|
375 |
-
if num_to_retrieve == 0: # Should be caught by valid_embeddings_df.empty earlier
|
376 |
-
return "\n[RAG Context]\nNo relevant passages found (num_to_retrieve is 0).\n"
|
377 |
-
|
378 |
-
# Ensure indices are within bounds
|
379 |
-
idx = np.argsort(dot_products)[-num_to_retrieve:][::-1] # Top N, descending order
|
380 |
-
|
381 |
-
relevant_passages = ""
|
382 |
-
for i in idx:
|
383 |
-
if i < len(valid_embeddings_df): # Defensive check
|
384 |
-
doc = valid_embeddings_df.iloc[i]
|
385 |
-
relevant_passages += f"\n[RAG Context from: '{doc['Title']}']\n{doc['Text']}\n"
|
386 |
-
else:
|
387 |
-
logging.warning(f"Index {i} out of bounds for valid_embeddings_df (len {len(valid_embeddings_df)})")
|
388 |
-
|
389 |
-
|
390 |
-
return relevant_passages if relevant_passages else "\n[RAG Context]\nNo relevant passages found after similarity search.\n"
|
391 |
except Exception as e:
|
392 |
logging.error(f"Error in RAG retrieve_relevant_info: {e}", exc_info=True)
|
393 |
return f"\n[RAG Context]\nError during RAG retrieval: {type(e).__name__} - {e}\n"
|
394 |
|
395 |
-
|
396 |
-
# --- PandasLLM Class (Gemini-Powered) ---
|
397 |
class PandasLLM:
|
398 |
def __init__(self, llm_model_name: str,
|
399 |
generation_config_dict: dict,
|
400 |
safety_settings_list: list,
|
401 |
data_privacy=True, force_sandbox=True):
|
402 |
self.llm_model_name = llm_model_name
|
403 |
-
self.generation_config_dict = generation_config_dict
|
404 |
-
self.safety_settings_list = safety_settings_list
|
405 |
self.data_privacy = data_privacy
|
406 |
self.force_sandbox = force_sandbox
|
407 |
-
self.
|
408 |
-
|
409 |
-
if not GEMINI_API_KEY:
|
410 |
-
logging.warning(f"PandasLLM: GEMINI_API_KEY not set. Using dummy model if real 'genai' is not fully mocked.")
|
411 |
-
# Even if API key is not set, we might be using a dummy genai
|
412 |
-
# So, initialize the dummy model if genai.GenerativeModel is the dummy one
|
413 |
-
if hasattr(genai, 'GenerativeModel') and hasattr(genai.GenerativeModel, '__func__') and genai.GenerativeModel.__func__.__qualname__.startswith('genai.GenerativeModel'): # Heuristic for dummy
|
414 |
-
self.generative_model = genai.GenerativeModel(
|
415 |
-
model_name=self.llm_model_name,
|
416 |
-
generation_config=genai_types.GenerationConfig(**self.generation_config_dict) if self.generation_config_dict else None,
|
417 |
-
safety_settings=self.safety_settings_list
|
418 |
-
)
|
419 |
-
logging.info(f"PandasLLM: Initialized with DUMMY genai.GenerativeModel for '{self.llm_model_name}'.")
|
420 |
-
|
421 |
-
else: # GEMINI_API_KEY is set
|
422 |
-
try:
|
423 |
-
# Use genai_types.GenerationConfig for real API
|
424 |
-
config_for_model = genai_types.GenerationConfig(**self.generation_config_dict) if self.generation_config_dict else None
|
425 |
-
|
426 |
-
self.generative_model = genai.GenerativeModel(
|
427 |
-
model_name=self.llm_model_name, # The SDK handles the "models/" prefix
|
428 |
-
generation_config=config_for_model,
|
429 |
-
safety_settings=self.safety_settings_list
|
430 |
-
# system_instruction can be added here if needed globally for this model
|
431 |
-
)
|
432 |
-
logging.info(f"PandasLLM: Initialized with REAL genai.GenerativeModel for '{self.llm_model_name}'.")
|
433 |
-
except Exception as e:
|
434 |
-
logging.error(f"Failed to initialize PandasLLM with genai.GenerativeModel: {e}", exc_info=True)
|
435 |
-
# Fallback to dummy if real initialization fails, to prevent crashes
|
436 |
-
if hasattr(genai, 'GenerativeModel') and hasattr(genai.GenerativeModel, '__func__') and genai.GenerativeModel.__func__.__qualname__.startswith('genai.GenerativeModel'):
|
437 |
-
self.generative_model = genai.GenerativeModel(model_name=self.llm_model_name) # Basic dummy
|
438 |
-
logging.warning("PandasLLM: Falling back to DUMMY genai.GenerativeModel due to real initialization error.")
|
439 |
|
|
|
|
|
440 |
|
441 |
-
|
442 |
-
|
443 |
-
|
444 |
-
|
445 |
-
|
446 |
-
# Gemini API expects chat history in a specific format
|
447 |
-
# The 'contents' parameter should be a list of Content objects (dicts)
|
448 |
-
# For chat, this list often alternates between 'user' and 'model' roles.
|
449 |
-
# The final part of 'contents' should be the current user prompt.
|
450 |
|
451 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
452 |
gemini_history = []
|
453 |
if history:
|
454 |
for entry in history:
|
455 |
role = "model" if entry.get("role") == "assistant" else entry.get("role", "user")
|
456 |
gemini_history.append({"role": role, "parts": [{"text": entry.get("content", "")}]})
|
457 |
|
458 |
-
# Add current prompt as the last user message
|
459 |
current_content = [{"role": "user", "parts": [{"text": prompt_text}]}]
|
|
|
460 |
|
461 |
-
#
|
462 |
-
|
463 |
-
|
464 |
-
|
465 |
-
|
466 |
-
#
|
467 |
-
|
468 |
-
|
469 |
-
|
470 |
-
|
471 |
-
|
472 |
-
|
473 |
-
|
474 |
-
|
475 |
-
|
476 |
-
|
477 |
-
contents_for_api = gemini_history + current_content # This forms the conversation
|
478 |
-
|
479 |
-
logging.info(f"\n--- Calling Gemini API (model: {self.llm_model_name}) ---\nContent (last part): {contents_for_api[-1]['parts'][0]['text'][:200]}...\n")
|
480 |
|
481 |
try:
|
482 |
-
|
483 |
-
|
484 |
-
response = await self.generative_model.generate_content_async(
|
485 |
contents=contents_for_api,
|
486 |
-
|
|
|
487 |
)
|
488 |
|
|
|
489 |
if hasattr(response, 'prompt_feedback') and response.prompt_feedback and \
|
490 |
hasattr(response.prompt_feedback, 'block_reason') and response.prompt_feedback.block_reason:
|
|
|
491 |
block_reason_val = response.prompt_feedback.block_reason
|
492 |
-
|
493 |
-
|
494 |
-
block_reason_str = genai_types.BlockedReason(block_reason_val).name
|
495 |
-
except:
|
496 |
-
block_reason_str = str(block_reason_val)
|
497 |
-
logging.warning(f"Prompt blocked by API. Reason: {block_reason_str}. Ratings: {response.prompt_feedback.safety_ratings}")
|
498 |
return f"# Error: Prompt blocked by API. Reason: {block_reason_str}."
|
499 |
|
500 |
llm_output = ""
|
501 |
-
# Standard way to get text from Gemini response
|
502 |
if hasattr(response, 'text') and isinstance(response.text, str):
|
503 |
llm_output = response.text
|
504 |
elif response.candidates:
|
@@ -507,29 +345,14 @@ class PandasLLM:
|
|
507 |
llm_output = "".join(part.text for part in candidate.content.parts if hasattr(part, 'text'))
|
508 |
|
509 |
if not llm_output and candidate.finish_reason:
|
|
|
510 |
finish_reason_val = candidate.finish_reason
|
511 |
-
|
512 |
-
|
513 |
-
finish_reason_str = str(finish_reason_val) # Safer for now
|
514 |
-
# For real API, finish_reason is an enum member, so .name would work.
|
515 |
-
# For dummy, it might be a string already.
|
516 |
-
if hasattr(genai_types.FinishReason, '_enum_map_') and finish_reason_val in genai_types.FinishReason._enum_map_: # Check if it's a valid enum value
|
517 |
-
finish_reason_str = genai_types.FinishReason(finish_reason_val).name
|
518 |
-
|
519 |
-
except Exception as fre:
|
520 |
-
logging.debug(f"Could not get FinishReason name: {fre}")
|
521 |
-
finish_reason_str = str(finish_reason_val)
|
522 |
-
|
523 |
-
# Check if blocked due to safety
|
524 |
if finish_reason_str == "SAFETY": # or candidate.finish_reason == genai_types.FinishReason.SAFETY:
|
525 |
-
|
526 |
-
|
527 |
-
|
528 |
-
cat_name = rating.category.name if hasattr(rating.category, 'name') else str(rating.category)
|
529 |
-
prob_name = rating.probability.name if hasattr(rating.probability, 'name') else str(rating.probability)
|
530 |
-
safety_messages.append(f"Category: {cat_name}, Probability: {prob_name}")
|
531 |
-
logging.warning(f"Content generation stopped due to safety. Finish reason: {finish_reason_str}. Details: {'; '.join(safety_messages)}")
|
532 |
-
return f"# Error: Content generation stopped by API due to safety. Finish Reason: {finish_reason_str}. Details: {'; '.join(safety_messages)}"
|
533 |
|
534 |
logging.warning(f"Empty response from LLM. Finish reason: {finish_reason_str}.")
|
535 |
return f"# Error: LLM returned an empty response. Finish reason: {finish_reason_str}."
|
@@ -537,17 +360,16 @@ class PandasLLM:
|
|
537 |
logging.error(f"Unexpected API response structure: {str(response)[:500]}")
|
538 |
return f"# Error: Unexpected API response structure: {str(response)[:200]}"
|
539 |
|
540 |
-
# logging.debug(f"LLM Raw Output:\n{llm_output}")
|
541 |
return llm_output
|
542 |
|
543 |
-
except genai_types.BlockedPromptException as bpe:
|
544 |
-
logging.error(f"Prompt
|
545 |
-
return f"# Error:
|
546 |
-
except genai_types.StopCandidateException as sce:
|
547 |
-
logging.error(f"Candidate
|
548 |
-
return f"# Error: Content generation
|
549 |
except Exception as e:
|
550 |
-
logging.error(f"Error calling Gemini API: {e}", exc_info=True)
|
551 |
return f"# Error during API call: {type(e).__name__} - {str(e)[:100]}."
|
552 |
|
553 |
|
@@ -556,75 +378,47 @@ class PandasLLM:
|
|
556 |
|
557 |
if self.force_sandbox:
|
558 |
code_to_execute = ""
|
559 |
-
# Robust code extraction
|
560 |
if "```python" in llm_response_text:
|
561 |
try:
|
562 |
-
# Standard ```python\nCODE\n```
|
563 |
code_block_match = llm_response_text.split("```python\n", 1)
|
564 |
-
if len(code_block_match) > 1:
|
565 |
-
|
566 |
-
else: # Try without newline after ```python
|
567 |
code_block_match = llm_response_text.split("```python", 1)
|
568 |
if len(code_block_match) > 1:
|
569 |
code_to_execute = code_block_match[1].split("```", 1)[0]
|
570 |
-
if code_to_execute.startswith("\n"):
|
571 |
-
|
572 |
-
|
573 |
-
except IndexError:
|
574 |
-
code_to_execute = "" # Should not happen with proper split logic
|
575 |
|
576 |
if llm_response_text.startswith("# Error:") or not code_to_execute.strip():
|
577 |
-
logging.warning(f"LLM
|
578 |
-
# If LLM returns an error or no code, pass that through directly.
|
579 |
-
# Or if it's a polite non-code refusal (e.g. "# Hello there! ...")
|
580 |
if not code_to_execute.strip() and not llm_response_text.startswith("# Error:"):
|
581 |
-
|
582 |
-
|
583 |
-
|
584 |
-
|
585 |
-
|
586 |
-
if "```" not in llm_response_text and len(llm_response_text.strip()) > 0: # Heuristic for non-code text
|
587 |
-
logging.info(f"LLM produced text output instead of Python code in sandbox mode. Passing through: {llm_response_text}")
|
588 |
-
# This might be desired if the LLM is explaining why it can't generate code.
|
589 |
-
return llm_response_text # Pass through LLM's direct response
|
590 |
-
return llm_response_text # Pass through LLM's error or its non-code (comment-only) response
|
591 |
-
|
592 |
-
logging.info(f"\n--- Code to Execute (extracted from LLM response): ---\n{code_to_execute}\n----------------------\n")
|
593 |
-
|
594 |
from io import StringIO
|
595 |
import sys
|
596 |
-
old_stdout = sys.stdout
|
597 |
-
sys.stdout = captured_output = StringIO()
|
598 |
-
|
599 |
-
# Prepare globals for exec. Prefix DataFrames with 'df_' as per prompt.
|
600 |
exec_globals = {'pd': pd, 'np': np}
|
601 |
if dataframes_dict:
|
602 |
for name, df_instance in dataframes_dict.items():
|
603 |
-
if isinstance(df_instance, pd.DataFrame):
|
604 |
-
|
605 |
-
else:
|
606 |
-
logging.warning(f"Item '{name}' in dataframes_dict is not a DataFrame. Skipping for exec_globals.")
|
607 |
-
|
608 |
try:
|
609 |
-
exec(code_to_execute, exec_globals, {})
|
610 |
-
final_output_str =
|
611 |
-
|
612 |
if not final_output_str.strip():
|
613 |
-
|
614 |
-
|
615 |
-
|
616 |
-
return "# LLM generated only comments or an empty code block. No output produced."
|
617 |
-
return "# Code executed successfully, but it did not produce any printed output. Please ensure the LLM's Python code includes print() statements for the desired results, insights, or answers."
|
618 |
return final_output_str
|
619 |
except Exception as e:
|
620 |
-
logging.error(f"Sandbox
|
621 |
-
|
622 |
-
|
623 |
-
|
624 |
-
|
625 |
-
sys.stdout = old_stdout
|
626 |
-
else: # Not force_sandbox
|
627 |
-
return llm_response_text
|
628 |
|
629 |
# --- Employer Branding Agent ---
|
630 |
class EmployerBrandingAgent:
|
@@ -636,233 +430,96 @@ class EmployerBrandingAgent:
|
|
636 |
embedding_model_name: str,
|
637 |
data_privacy=True, force_sandbox=True):
|
638 |
|
639 |
-
self.pandas_llm = PandasLLM(
|
640 |
-
llm_model_name,
|
641 |
-
generation_config_dict,
|
642 |
-
safety_settings_list,
|
643 |
-
data_privacy,
|
644 |
-
force_sandbox
|
645 |
-
)
|
646 |
self.rag_system = AdvancedRAGSystem(rag_documents_df, embedding_model_name)
|
647 |
self.all_dataframes = all_dataframes if all_dataframes else {}
|
648 |
self.schemas_representation = get_all_schemas_representation(self.all_dataframes)
|
649 |
self.chat_history = []
|
650 |
-
logging.info("EmployerBrandingAgent Initialized.")
|
651 |
|
652 |
def _build_prompt(self, user_query: str, role="Employer Branding Analyst & Strategist", task_decomposition_hint=None, cot_hint=True) -> str:
|
653 |
-
|
654 |
-
|
655 |
-
prompt = f"You are a highly skilled '{role}'. Your primary goal is to provide actionable employer branding insights and strategic recommendations by analyzing provided data (Pandas DataFrames) and contextual information (RAG documents).\n"
|
656 |
-
prompt += "You will be provided with schemas for available Pandas DataFrames and a user query.\n"
|
657 |
-
|
658 |
-
if self.pandas_llm.data_privacy:
|
659 |
-
prompt += "IMPORTANT: Adhere to data privacy. Do not output raw Personally Identifiable Information (PII) like individual names or specific user contact details. Summarize, aggregate, or anonymize data in your insights.\n"
|
660 |
|
661 |
if self.pandas_llm.force_sandbox:
|
662 |
prompt += "\n--- TASK: PYTHON CODE GENERATION FOR INSIGHTS ---\n"
|
663 |
-
prompt += "
|
664 |
-
prompt += "Output ONLY the Python code block
|
665 |
-
prompt += "
|
666 |
-
prompt += "Example of accessing a DataFrame: `df_follower_stats['country']`.\n"
|
667 |
-
|
668 |
prompt += "\n--- CRITICAL INSTRUCTIONS FOR PYTHON CODE OUTPUT ---\n"
|
669 |
-
prompt += "1. **Print Insights, Not Just Data:**
|
670 |
-
prompt += "
|
671 |
-
prompt += "
|
672 |
-
prompt += "2. **Synthesize with RAG
|
673 |
-
prompt += "3. **
|
674 |
-
prompt += "4. **Handle
|
675 |
-
prompt += "
|
676 |
-
|
677 |
-
prompt += " - For non-analytical queries (e.g., 'hello'), respond politely with a `print()` statement. Example: `print('Hello! How can I assist with your employer branding data analysis today?')`\n"
|
678 |
-
prompt += "5. **Function Usage:** If you define functions, ENSURE they are called and their results (or insights derived from them) are `print()`ed.\n"
|
679 |
-
prompt += "6. **DataFrame Naming:** Remember to use the `df_` prefix for DataFrame names in your code (e.g., `df_your_data`).\n"
|
680 |
-
|
681 |
-
else: # Not force_sandbox - LLM provides direct textual answer
|
682 |
prompt += "\n--- TASK: DIRECT TEXTUAL INSIGHT GENERATION ---\n"
|
683 |
-
prompt += "
|
684 |
|
685 |
prompt += "\n--- AVAILABLE DATA AND SCHEMAS ---\n"
|
686 |
-
if self.schemas_representation.strip()
|
687 |
-
prompt += "No specific DataFrames are currently loaded. Please rely on general knowledge and any provided RAG context for your response, or ask for data to be loaded.\n"
|
688 |
-
else:
|
689 |
-
prompt += self.schemas_representation
|
690 |
|
691 |
rag_context = self.rag_system.retrieve_relevant_info(user_query)
|
692 |
-
# Check if RAG context is meaningful before appending
|
693 |
meaningful_rag_keywords = ["Error", "No valid", "No relevant", "Cannot retrieve", "not available", "not generated"]
|
694 |
is_meaningful_rag = bool(rag_context.strip()) and not any(keyword in rag_context for keyword in meaningful_rag_keywords)
|
695 |
-
|
696 |
-
|
697 |
-
prompt += f"\n--- ADDITIONAL CONTEXT (from Employer Branding Knowledge Base - consider this for your insights) ---\n{rag_context}\n"
|
698 |
-
else:
|
699 |
-
prompt += "\n--- ADDITIONAL CONTEXT (from Employer Branding Knowledge Base) ---\nNo specific pre-defined context found highly relevant to this query, or RAG system encountered an issue. Rely on general knowledge and DataFrame analysis.\n"
|
700 |
-
|
701 |
|
702 |
prompt += f"\n--- USER QUERY ---\n{user_query}\n"
|
703 |
-
|
704 |
-
if task_decomposition_hint:
|
705 |
-
prompt += f"\n--- GUIDANCE FOR ANALYSIS (Task Decomposition) ---\n{task_decomposition_hint}\n"
|
706 |
|
707 |
if cot_hint:
|
708 |
if self.pandas_llm.force_sandbox:
|
709 |
-
prompt += "\n---
|
710 |
-
prompt += "1.
|
711 |
-
prompt += "2. **Identify Data Sources:** Which DataFrame(s) and column(s) are relevant? Is there RAG context to incorporate?\n"
|
712 |
-
prompt += "3. **Plan Analysis (Mental Outline / Code Comments):**\n"
|
713 |
-
prompt += " a. What calculations, aggregations, or transformations are needed?\n"
|
714 |
-
prompt += " b. How will RAG context be integrated into the final printed insight?\n"
|
715 |
-
prompt += " c. What is the exact textual insight/answer to be `print()`ed?\n"
|
716 |
-
prompt += "4. **Write Python Code:** Implement the plan. Use `df_name_of_dataframe`.\n"
|
717 |
-
prompt += "5. **CRITICAL - Formulate and `print()` Insights:** Construct the final textual insight(s) as strings and use `print()` statements for them. These prints are the agent's entire response. Ensure they are clear, actionable, and directly address the user's query, incorporating RAG if applicable.\n"
|
718 |
-
prompt += "6. **Review Code:** Check for correctness, clarity, and adherence to ALL instructions, especially the `print()` requirements for insightful text.\n"
|
719 |
-
prompt += "7. **Final Output:** Ensure ONLY the Python code block (```python...```) is generated.\n"
|
720 |
else: # Not force_sandbox
|
721 |
-
prompt += "\n---
|
722 |
-
prompt += "1.
|
723 |
-
prompt += "2. **Identify Data Sources:** Analyze the DataFrame schemas. Consider relevant RAG context.\n"
|
724 |
-
prompt += "3. **Formulate Insights:** Synthesize information from data and RAG to derive key insights and recommendations.\n"
|
725 |
-
prompt += "4. **Structure Response:** Provide a step-by-step explanation of your analysis, followed by the clear, actionable insights and strategic advice.\n"
|
726 |
-
|
727 |
return prompt
|
728 |
|
729 |
async def process_query(self, user_query: str, role="Employer Branding Analyst & Strategist", task_decomposition_hint=None, cot_hint=True) -> str:
|
730 |
-
|
731 |
-
|
732 |
-
|
733 |
-
current_turn_history_for_llm = self.chat_history[:] # History *before* this turn
|
734 |
-
|
735 |
-
self.chat_history.append({"role": "user", "parts": [{"text": user_query}]}) # Use 'parts' for Gemini
|
736 |
-
|
737 |
full_prompt = self._build_prompt(user_query, role, task_decomposition_hint, cot_hint)
|
738 |
-
|
739 |
-
# Log only a part of the prompt to avoid overly verbose logs
|
740 |
-
# logging.info(f"Full prompt to LLM (showing first 300 and last 300 chars for brevity):\n{full_prompt[:300]}...\n...\n{full_prompt[-300:]}")
|
741 |
-
logging.info(f"Built prompt for user query: {user_query[:100]}...")
|
742 |
-
|
743 |
-
|
744 |
-
# Pass the history *before* the current user query to the LLM
|
745 |
response_text = await self.pandas_llm.query(full_prompt, self.all_dataframes, history=current_turn_history_for_llm)
|
746 |
-
|
747 |
-
|
748 |
-
|
749 |
-
MAX_HISTORY_TURNS = 5 # Each turn has a user and a model message
|
750 |
if len(self.chat_history) > MAX_HISTORY_TURNS * 2:
|
751 |
-
# Keep the most recent turns. The history is [user1, model1, user2, model2,...]
|
752 |
self.chat_history = self.chat_history[-(MAX_HISTORY_TURNS * 2):]
|
753 |
-
logging.info(f"Chat history truncated
|
754 |
-
|
755 |
return response_text
|
756 |
|
757 |
def update_dataframes(self, new_dataframes: dict):
|
758 |
self.all_dataframes = new_dataframes if new_dataframes else {}
|
759 |
self.schemas_representation = get_all_schemas_representation(self.all_dataframes)
|
760 |
-
logging.info(f"
|
761 |
-
# Potentially clear RAG embeddings if they depend on the old dataframes, or recompute.
|
762 |
-
# For now, RAG is independent of these dataframes.
|
763 |
|
764 |
-
def clear_chat_history(self):
|
765 |
-
self.chat_history = []
|
766 |
-
logging.info("EmployerBrandingAgent chat history cleared.")
|
767 |
|
768 |
-
# --- Example Usage (Conceptual
|
769 |
async def main_test():
|
770 |
logging.info("Starting main_test for EmployerBrandingAgent...")
|
771 |
-
|
772 |
-
|
773 |
-
|
774 |
-
'date': pd.to_datetime(['2023-01-01', '2023-01-02', '2023-01-01', '2023-01-03']),
|
775 |
-
'country': ['USA', 'USA', 'Canada', 'UK'],
|
776 |
-
'new_followers': [10, 12, 5, 8]
|
777 |
-
}
|
778 |
-
df_follower_stats = pd.DataFrame(followers_data)
|
779 |
-
|
780 |
-
posts_data = {
|
781 |
-
'post_id': [1, 2, 3, 4],
|
782 |
-
'post_date': pd.to_datetime(['2023-01-01', '2023-01-01', '2023-01-02', '2023-01-03']),
|
783 |
-
'theme': ['Culture', 'Tech', 'Culture', 'Jobs'],
|
784 |
-
'impressions': [1000, 1500, 1200, 2000],
|
785 |
-
'engagements': [50, 90, 60, 120]
|
786 |
-
}
|
787 |
-
df_posts = pd.DataFrame(posts_data)
|
788 |
-
df_posts['engagement_rate'] = df_posts['engagements'] / df_posts['impressions']
|
789 |
-
|
790 |
-
test_dataframes = {
|
791 |
-
"follower_stats": df_follower_stats,
|
792 |
-
"posts": df_posts,
|
793 |
-
"empty_df": pd.DataFrame(), # Test empty df representation
|
794 |
-
"non_df_item": "This is not a dataframe" # Test non-df item
|
795 |
-
}
|
796 |
-
|
797 |
-
# Initialize the agent
|
798 |
-
# Ensure GEMINI_API_KEY is set in your environment for real calls
|
799 |
-
if not GEMINI_API_KEY:
|
800 |
-
logging.warning("GEMINI_API_KEY not found in environment. Testing with dummy/mocked functionality.")
|
801 |
-
|
802 |
-
agent = EmployerBrandingAgent(
|
803 |
-
llm_model_name=LLM_MODEL_NAME,
|
804 |
-
generation_config_dict=GENERATION_CONFIG_PARAMS,
|
805 |
-
safety_settings_list=DEFAULT_SAFETY_SETTINGS,
|
806 |
-
all_dataframes=test_dataframes,
|
807 |
-
rag_documents_df=df_rag_documents, # Using the example RAG data
|
808 |
-
embedding_model_name=GEMINI_EMBEDDING_MODEL_NAME,
|
809 |
-
force_sandbox=True # Set to True to test code generation, False for direct LLM text
|
810 |
-
)
|
811 |
-
|
812 |
-
logging.info(f"Schema representation:\n{agent.schemas_representation}")
|
813 |
-
|
814 |
-
queries = [
|
815 |
-
"What are the key trends in follower growth by country based on the first few days of January 2023?",
|
816 |
-
"Which post theme has the highest average engagement rate? Provide an insight.",
|
817 |
-
"Hello there!",
|
818 |
-
"Can you tell me the average salary for software engineers? (This should state data is not available)",
|
819 |
-
"Summarize the best practices for attracting tech talent and combine it with an analysis of our top performing post themes."
|
820 |
-
]
|
821 |
|
|
|
|
|
|
|
|
|
822 |
for query in queries:
|
823 |
-
logging.info(f"\n\n---
|
824 |
response = await agent.process_query(user_query=query)
|
825 |
-
logging.info(f"---
|
826 |
-
|
827 |
-
if GEMINI_API_KEY: await asyncio.sleep(1)
|
828 |
-
|
829 |
-
# Test updating dataframes
|
830 |
-
new_posts_data = {
|
831 |
-
'post_id': [5, 6], 'post_date': pd.to_datetime(['2023-01-04', '2023-01-05']),
|
832 |
-
'theme': ['Innovation', 'Team'], 'impressions': [2500, 1800], 'engagements': [150, 100]
|
833 |
-
}
|
834 |
-
df_new_posts = pd.DataFrame(new_posts_data)
|
835 |
-
df_new_posts['engagement_rate'] = df_new_posts['engagements'] / df_new_posts['impressions']
|
836 |
-
|
837 |
-
updated_dataframes = {
|
838 |
-
"follower_stats": df_follower_stats, # unchanged
|
839 |
-
"posts": pd.concat([df_posts, df_new_posts]), # updated
|
840 |
-
"company_values": pd.DataFrame({'value': ['Innovation', 'Collaboration'], 'description': ['...', '...']}) # new df
|
841 |
-
}
|
842 |
-
agent.update_dataframes(updated_dataframes)
|
843 |
-
logging.info(f"\n--- Processing Query after DataFrame Update ---")
|
844 |
-
response_after_update = await agent.process_query("What's the latest top performing post theme now?")
|
845 |
-
logging.info(f"--- Agent Response for 'What's the latest top performing post theme now?': ---\n{response_after_update}\n---------------------------\n")
|
846 |
-
|
847 |
|
848 |
if __name__ == "__main__":
|
849 |
-
|
850 |
-
|
851 |
-
# Example: export GEMINI_API_KEY="your_api_key_here"
|
852 |
-
|
853 |
-
# To run the async main_test:
|
854 |
-
# asyncio.run(main_test())
|
855 |
-
# Or, if you're in a Jupyter environment that has its own loop:
|
856 |
-
# await main_test()
|
857 |
-
|
858 |
-
# For simplicity in a standard Python script:
|
859 |
-
if GEMINI_API_KEY: # Only run full async test if API key likely present
|
860 |
-
try:
|
861 |
-
asyncio.run(main_test())
|
862 |
except RuntimeError as e:
|
863 |
-
if "
|
864 |
-
|
865 |
-
|
866 |
-
raise
|
867 |
-
else:
|
868 |
-
print("GEMINI_API_KEY not set. Skipping main_test() which might make real API calls. The module can be imported and used elsewhere.")
|
|
|
9 |
|
10 |
# Attempt to import Google Generative AI and related types
|
11 |
try:
|
12 |
+
from google import generativeai as genai # Renamed for clarity to avoid conflict
|
13 |
+
from google.generativeai import types as genai_types
|
|
|
|
|
14 |
# from google.generativeai import GenerationConfig # For direct use if needed
|
15 |
# from google.generativeai.types import HarmCategory, HarmBlockThreshold, SafetySetting # For direct use
|
16 |
|
|
|
24 |
|
25 |
# Dummy Client and related structures
|
26 |
class Client:
|
27 |
+
def __init__(self, api_key=None): # api_key is optional for Client constructor
|
28 |
self.api_key = api_key
|
29 |
+
self.models = self._Models() # This is the service client for models
|
30 |
+
print(f"Dummy genai.Client initialized {'with api_key' if api_key else '(global API key expected)'}.")
|
31 |
|
32 |
+
class _Models: # Represents the model service client
|
33 |
+
async def generate_content_async(self, model=None, contents=None, generation_config=None, safety_settings=None, stream=False, tools=None, tool_config=None): # Matched real signature better
|
|
|
34 |
print(f"Dummy genai.Client.models.generate_content_async called for model: {model} with config: {generation_config}, safety_settings: {safety_settings}, stream: {stream}")
|
35 |
class DummyPart:
|
36 |
def __init__(self, text): self.text = text
|
|
|
41 |
self.content = DummyContent()
|
42 |
self.finish_reason = genai_types.FinishReason.STOP # Use dummy FinishReason
|
43 |
self.safety_ratings = []
|
44 |
+
self.token_count = 0
|
45 |
+
self.index = 0
|
46 |
class DummyResponse:
|
47 |
def __init__(self):
|
48 |
self.candidates = [DummyCandidate()]
|
49 |
+
self.prompt_feedback = self._PromptFeedback()
|
50 |
+
self.text = "# Dummy response text from dummy client's async generate_content"
|
51 |
+
class _PromptFeedback:
|
52 |
def __init__(self):
|
53 |
self.block_reason = None
|
54 |
self.safety_ratings = []
|
55 |
return DummyResponse()
|
56 |
|
57 |
+
def generate_content(self, model=None, contents=None, generation_config=None, safety_settings=None, stream=False, tools=None, tool_config=None): # Matched real signature better
|
58 |
print(f"Dummy genai.Client.models.generate_content called for model: {model} with config: {generation_config}, safety_settings: {safety_settings}, stream: {stream}")
|
59 |
# Re-using the async dummy structure for simplicity
|
60 |
class DummyPart:
|
|
|
64 |
class DummyCandidate:
|
65 |
def __init__(self):
|
66 |
self.content = DummyContent()
|
67 |
+
self.finish_reason = genai_types.FinishReason.STOP
|
68 |
self.safety_ratings = []
|
69 |
self.token_count = 0
|
70 |
self.index = 0
|
71 |
class DummyResponse:
|
72 |
def __init__(self):
|
73 |
self.candidates = [DummyCandidate()]
|
74 |
+
self.prompt_feedback = self._PromptFeedback()
|
75 |
self.text = "# Dummy response text from dummy client's generate_content"
|
76 |
class _PromptFeedback:
|
77 |
def __init__(self):
|
|
|
80 |
return DummyResponse()
|
81 |
|
82 |
@staticmethod
|
83 |
+
def GenerativeModel(model_name, generation_config=None, safety_settings=None, system_instruction=None): # Kept for AdvancedRAGSystem if it uses it, or if user switches back
|
84 |
+
print(f"Dummy genai.GenerativeModel called for model: {model_name} (This might be unused if Client approach is preferred)")
|
85 |
+
# ... (rest of DummyGenerativeModel as before, for completeness) ...
|
86 |
class DummyGenerativeModel:
|
87 |
def __init__(self, model_name_in, generation_config_in, safety_settings_in, system_instruction_in):
|
88 |
self.model_name = model_name_in
|
89 |
+
async def generate_content_async(self, contents, stream=False):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
90 |
class DummyPart:
|
91 |
def __init__(self, text): self.text = text
|
92 |
class DummyContent:
|
93 |
def __init__(self): self.parts = [DummyPart(f"# Dummy response from dummy GenerativeModel ({self.model_name})")]
|
94 |
class DummyCandidate:
|
95 |
def __init__(self):
|
96 |
+
self.content = DummyContent(); self.finish_reason = genai_types.FinishReason.STOP; self.safety_ratings = []
|
|
|
|
|
97 |
class DummyResponse:
|
98 |
def __init__(self):
|
99 |
+
self.candidates = [DummyCandidate()]; self.prompt_feedback = None; self.text = f"# Dummy GM response"
|
|
|
|
|
100 |
return DummyResponse()
|
|
|
101 |
return DummyGenerativeModel(model_name, generation_config, safety_settings, system_instruction)
|
102 |
|
103 |
+
|
104 |
@staticmethod
|
105 |
def embed_content(model, content, task_type, title=None):
|
106 |
print(f"Dummy genai.embed_content called for model: {model}, task_type: {task_type}, title: {title}")
|
|
|
107 |
return {"embedding": [0.1] * 768}
|
108 |
|
109 |
class genai_types: # type: ignore
|
|
|
110 |
@staticmethod
|
111 |
+
def GenerationConfig(**kwargs):
|
112 |
print(f"Dummy genai_types.GenerationConfig created with: {kwargs}")
|
113 |
return dict(kwargs)
|
114 |
|
115 |
@staticmethod
|
116 |
def SafetySetting(category, threshold):
|
117 |
print(f"Dummy SafetySetting created: category={category}, threshold={threshold}")
|
118 |
+
return {"category": category, "threshold": threshold}
|
119 |
|
|
|
120 |
class HarmCategory:
|
121 |
+
HARM_CATEGORY_UNSPECIFIED = "HARM_CATEGORY_UNSPECIFIED"; HARM_CATEGORY_HARASSMENT = "HARM_CATEGORY_HARASSMENT"; HARM_CATEGORY_HATE_SPEECH = "HARM_CATEGORY_HATE_SPEECH"; HARM_CATEGORY_SEXUALLY_EXPLICIT = "HARM_CATEGORY_SEXUALLY_EXPLICIT"; HARM_CATEGORY_DANGEROUS_CONTENT = "HARM_CATEGORY_DANGEROUS_CONTENT"
|
|
|
|
|
|
|
|
|
|
|
122 |
class HarmBlockThreshold:
|
123 |
+
BLOCK_NONE = "BLOCK_NONE"; BLOCK_LOW_AND_ABOVE = "BLOCK_LOW_AND_ABOVE"; BLOCK_MEDIUM_AND_ABOVE = "BLOCK_MEDIUM_AND_ABOVE"; BLOCK_ONLY_HIGH = "BLOCK_ONLY_HIGH"
|
124 |
+
class FinishReason:
|
125 |
+
FINISH_REASON_UNSPECIFIED = "UNSPECIFIED"; STOP = "STOP"; MAX_TOKENS = "MAX_TOKENS"; SAFETY = "SAFETY"; RECITATION = "RECITATION"; OTHER = "OTHER"
|
126 |
+
|
127 |
+
# Dummy for BlockedReason if needed by response parsing
|
128 |
+
class BlockedReason:
|
129 |
+
BLOCKED_REASON_UNSPECIFIED = "BLOCKED_REASON_UNSPECIFIED"
|
|
|
|
|
130 |
SAFETY = "SAFETY"
|
|
|
131 |
OTHER = "OTHER"
|
132 |
|
|
|
|
|
|
|
|
|
|
|
133 |
# --- Configuration ---
|
134 |
GEMINI_API_KEY = os.getenv('GEMINI_API_KEY', "")
|
135 |
+
# User-specified model names:
|
136 |
# LLM_MODEL_NAME = "gemini-2.0-flash" # Original
|
137 |
LLM_MODEL_NAME = "gemini-2.0-flash"
|
138 |
GEMINI_EMBEDDING_MODEL_NAME = "gemini-embedding-exp-03-07"
|
139 |
|
140 |
# Base generation configuration for the LLM
|
141 |
GENERATION_CONFIG_PARAMS = {
|
142 |
+
"temperature": 0.3,
|
143 |
"top_p": 1.0,
|
144 |
"top_k": 32,
|
145 |
+
"max_output_tokens": 8192,
|
|
|
146 |
}
|
147 |
|
148 |
# Default safety settings list for Gemini
|
149 |
try:
|
150 |
DEFAULT_SAFETY_SETTINGS = [
|
151 |
+
genai_types.SafetySetting(category=genai_types.HarmCategory.HARM_CATEGORY_HATE_SPEECH, threshold=genai_types.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE),
|
152 |
+
genai_types.SafetySetting(category=genai_types.HarmCategory.HARM_CATEGORY_HARASSMENT, threshold=genai_types.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE),
|
153 |
+
genai_types.SafetySetting(category=genai_types.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, threshold=genai_types.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE),
|
154 |
+
genai_types.SafetySetting(category=genai_types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, threshold=genai_types.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
155 |
]
|
156 |
except AttributeError as e:
|
157 |
logging.warning(f"Could not define DEFAULT_SAFETY_SETTINGS using real genai_types: {e}. Using placeholder list of dicts.")
|
|
|
162 |
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
|
163 |
]
|
164 |
|
|
|
165 |
# Logging setup
|
166 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(module)s - %(filename)s:%(lineno)d - %(message)s')
|
167 |
|
|
|
177 |
|
178 |
# --- RAG Documents Definition (Example) ---
|
179 |
rag_documents_data = {
|
180 |
+
'Title': ["Employer Branding Best Practices 2024", "Attracting Tech Talent", "Employee Advocacy", "Gen Z Expectations"],
|
181 |
+
'Text': ["Focus on authentic employee stories...", "Tech candidates value challenging projects...", "Encourage employees to share experiences...", "Gen Z values purpose-driven work..."]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
182 |
}
|
183 |
df_rag_documents = pd.DataFrame(rag_documents_data)
|
184 |
|
185 |
# --- Schema Representation ---
|
186 |
def get_schema_representation(df_name: str, df: pd.DataFrame) -> str:
|
187 |
+
if not isinstance(df, pd.DataFrame): return f"Schema for item '{df_name}': Not a DataFrame.\n"
|
188 |
+
if df.empty: return f"Schema for DataFrame 'df_{df_name}': Empty.\n"
|
189 |
+
schema_str = f"DataFrame 'df_{df_name}':\n Columns: {df.columns.tolist()}\n Shape: {df.shape}\n"
|
190 |
+
if not df.empty: schema_str += f" Sample Data (first 2 rows):\n{textwrap.indent(df.head(2).to_string(), ' ')}\n"
|
191 |
+
else: schema_str += " Sample Data: DataFrame is empty.\n"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
192 |
return schema_str
|
193 |
|
194 |
def get_all_schemas_representation(dataframes_dict: dict) -> str:
|
195 |
+
if not dataframes_dict: return "No DataFrames provided.\n"
|
|
|
196 |
return "".join(get_schema_representation(name, df) for name, df in dataframes_dict.items())
|
197 |
|
|
|
198 |
# --- Advanced RAG System ---
|
199 |
class AdvancedRAGSystem:
|
200 |
def __init__(self, documents_df: pd.DataFrame, embedding_model_name: str):
|
201 |
self.embedding_model_name = embedding_model_name
|
202 |
self.documents_df = documents_df.copy()
|
203 |
self.embeddings_generated = False
|
204 |
+
# Check if genai.embed_content is the real one or our dummy
|
205 |
+
self.client_available = hasattr(genai, 'embed_content') and not (hasattr(genai.embed_content, '__func__') and genai.embed_content.__func__.__qualname__.startswith('genai.embed_content'))
|
206 |
|
207 |
if GEMINI_API_KEY and self.client_available:
|
208 |
try:
|
209 |
self._precompute_embeddings()
|
210 |
self.embeddings_generated = True
|
211 |
logging.info(f"RAG embeddings precomputed using '{self.embedding_model_name}'.")
|
212 |
+
except Exception as e: logging.error(f"RAG precomputation error: {e}", exc_info=True)
|
|
|
213 |
else:
|
214 |
+
logging.warning(f"RAG embeddings not precomputed. Key: {bool(GEMINI_API_KEY)}, embed_content_ok: {self.client_available}.")
|
215 |
|
216 |
def _embed_fn(self, title: str, text: str) -> list[float]:
|
217 |
+
if not self.client_available: return [0.0] * 768
|
|
|
|
|
218 |
try:
|
|
|
|
|
219 |
content_to_embed = text if text else title
|
220 |
+
if not content_to_embed: return [0.0] * 768
|
221 |
+
return genai.embed_content(model=self.embedding_model_name, content=content_to_embed, task_type="retrieval_document", title=title if title else None)["embedding"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
222 |
except Exception as e:
|
223 |
logging.error(f"Error in _embed_fn for '{title}': {e}", exc_info=True)
|
224 |
return [0.0] * 768
|
225 |
|
226 |
def _precompute_embeddings(self):
|
227 |
+
if 'Embeddings' not in self.documents_df.columns: self.documents_df['Embeddings'] = pd.Series(dtype='object')
|
228 |
+
mask = (self.documents_df['Text'].notna() & (self.documents_df['Text'] != '')) | (self.documents_df['Title'].notna() & (self.documents_df['Title'] != ''))
|
229 |
+
if not mask.any(): logging.warning("No content for RAG embeddings."); return
|
230 |
+
self.documents_df.loc[mask, 'Embeddings'] = self.documents_df[mask].apply(lambda row: self._embed_fn(row.get('Title', ''), row.get('Text', '')), axis=1)
|
231 |
+
logging.info(f"Applied RAG embedding function to {mask.sum()} rows.")
|
232 |
+
|
233 |
+
def retrieve_relevant_info(self, query_text: str, top_k: int = 2) -> str:
|
234 |
+
if not self.client_available: return "\n[RAG Context]\nEmbedding client not available.\n"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
235 |
if not self.embeddings_generated or 'Embeddings' not in self.documents_df.columns or self.documents_df['Embeddings'].isnull().all():
|
236 |
+
return "\n[RAG Context]\nEmbeddings not ready for RAG.\n"
|
|
|
237 |
try:
|
238 |
+
query_embedding = np.array(genai.embed_content(model=self.embedding_model_name, content=query_text, task_type="retrieval_query")["embedding"])
|
239 |
+
valid_df = self.documents_df.dropna(subset=['Embeddings'])
|
240 |
+
valid_df = valid_df[valid_df['Embeddings'].apply(lambda x: isinstance(x, (list, np.ndarray)) and len(x) > 0)]
|
241 |
+
if valid_df.empty: return "\n[RAG Context]\nNo valid document embeddings.\n"
|
242 |
+
|
243 |
+
doc_embeddings = np.stack(valid_df['Embeddings'].apply(np.array).values)
|
244 |
+
if query_embedding.shape[0] != doc_embeddings.shape[1]: return "\n[RAG Context]\nEmbedding dimension mismatch.\n"
|
245 |
+
|
246 |
+
dot_products = np.dot(doc_embeddings, query_embedding)
|
247 |
+
num_to_retrieve = min(top_k, len(valid_df))
|
248 |
+
if num_to_retrieve == 0: return "\n[RAG Context]\nNo relevant passages found (num_to_retrieve is 0).\n"
|
249 |
+
|
250 |
+
idx = np.argsort(dot_products)[-num_to_retrieve:][::-1]
|
251 |
+
passages = "".join([f"\n[RAG Context from: '{valid_df.iloc[i]['Title']}']\n{valid_df.iloc[i]['Text']}\n" for i in idx if i < len(valid_df)])
|
252 |
+
return passages if passages else "\n[RAG Context]\nNo relevant passages found after search.\n"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
253 |
except Exception as e:
|
254 |
logging.error(f"Error in RAG retrieve_relevant_info: {e}", exc_info=True)
|
255 |
return f"\n[RAG Context]\nError during RAG retrieval: {type(e).__name__} - {e}\n"
|
256 |
|
257 |
+
# --- PandasLLM Class (Gemini-Powered using genai.Client) ---
|
|
|
258 |
class PandasLLM:
|
259 |
def __init__(self, llm_model_name: str,
|
260 |
generation_config_dict: dict,
|
261 |
safety_settings_list: list,
|
262 |
data_privacy=True, force_sandbox=True):
|
263 |
self.llm_model_name = llm_model_name
|
264 |
+
self.generation_config_dict = generation_config_dict # Will be passed to API call
|
265 |
+
self.safety_settings_list = safety_settings_list # Will be passed to API call
|
266 |
self.data_privacy = data_privacy
|
267 |
self.force_sandbox = force_sandbox
|
268 |
+
self.client = None
|
269 |
+
self.model_service = None # This will be client.models
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
270 |
|
271 |
+
# Check if genai.Client is the real one or our dummy
|
272 |
+
is_real_genai_client = hasattr(genai, 'Client') and not (hasattr(genai.Client, '__func__') and genai.Client.__func__.__qualname__.startswith('genai.Client'))
|
273 |
|
274 |
+
if not GEMINI_API_KEY and is_real_genai_client: # Real client but no API key
|
275 |
+
logging.warning(f"PandasLLM: GEMINI_API_KEY not set, but real 'genai.Client' seems available. API calls may fail if global config is not sufficient.")
|
276 |
+
# Proceed to initialize client; it might work if genai.configure() was successful without explicit key here
|
277 |
+
# or if the environment provides credentials in another way.
|
|
|
|
|
|
|
|
|
|
|
278 |
|
279 |
+
try:
|
280 |
+
self.client = genai.Client() # API key is usually set via genai.configure or environment
|
281 |
+
self.model_service = self.client.models
|
282 |
+
logging.info(f"PandasLLM: Initialized with genai.Client().models for '{self.llm_model_name}'.")
|
283 |
+
except Exception as e:
|
284 |
+
logging.error(f"Failed to initialize PandasLLM with genai.Client: {e}", exc_info=True)
|
285 |
+
# Fallback to dummy if real initialization fails, to prevent crashes
|
286 |
+
if not is_real_genai_client: # If it was already the dummy, re-initialize dummy
|
287 |
+
self.client = genai.Client()
|
288 |
+
self.model_service = self.client.models
|
289 |
+
logging.warning("PandasLLM: Falling back to DUMMY genai.Client due to real initialization error or it was already dummy.")
|
290 |
+
|
291 |
+
|
292 |
+
async def _call_gemini_api_async(self, prompt_text: str, history: list = None) -> str:
|
293 |
+
if not self.model_service:
|
294 |
+
logging.error("PandasLLM: Model service (client.models) not available. Cannot call API.")
|
295 |
+
return "# Error: Gemini model service not available for API call."
|
296 |
+
|
297 |
gemini_history = []
|
298 |
if history:
|
299 |
for entry in history:
|
300 |
role = "model" if entry.get("role") == "assistant" else entry.get("role", "user")
|
301 |
gemini_history.append({"role": role, "parts": [{"text": entry.get("content", "")}]})
|
302 |
|
|
|
303 |
current_content = [{"role": "user", "parts": [{"text": prompt_text}]}]
|
304 |
+
contents_for_api = gemini_history + current_content
|
305 |
|
306 |
+
# Prepare model ID (e.g., "models/gemini-2.0-flash")
|
307 |
+
model_id_for_api = self.llm_model_name
|
308 |
+
if not model_id_for_api.startswith("models/"):
|
309 |
+
model_id_for_api = f"models/{model_id_for_api}"
|
310 |
+
|
311 |
+
# Prepare generation config object
|
312 |
+
api_generation_config = None
|
313 |
+
if self.generation_config_dict:
|
314 |
+
try:
|
315 |
+
api_generation_config = genai_types.GenerationConfig(**self.generation_config_dict)
|
316 |
+
except Exception as e_cfg:
|
317 |
+
logging.error(f"Error creating GenerationConfig object: {e_cfg}. Using dict as fallback.")
|
318 |
+
api_generation_config = self.generation_config_dict # Fallback to dict
|
319 |
+
|
320 |
+
logging.info(f"\n--- Calling Gemini API via Client (model: {model_id_for_api}) ---\nConfig: {api_generation_config}\nSafety: {bool(self.safety_settings_list)}\nContent (last part text): {contents_for_api[-1]['parts'][0]['text'][:100]}...\n")
|
|
|
|
|
|
|
|
|
321 |
|
322 |
try:
|
323 |
+
response = await self.model_service.generate_content_async(
|
324 |
+
model=model_id_for_api,
|
|
|
325 |
contents=contents_for_api,
|
326 |
+
generation_config=api_generation_config,
|
327 |
+
safety_settings=self.safety_settings_list
|
328 |
)
|
329 |
|
330 |
+
# ... (Response parsing logic remains largely the same as before) ...
|
331 |
if hasattr(response, 'prompt_feedback') and response.prompt_feedback and \
|
332 |
hasattr(response.prompt_feedback, 'block_reason') and response.prompt_feedback.block_reason:
|
333 |
+
# ... block reason handling ...
|
334 |
block_reason_val = response.prompt_feedback.block_reason
|
335 |
+
block_reason_str = str(block_reason_val.name if hasattr(block_reason_val, 'name') else block_reason_val)
|
336 |
+
logging.warning(f"Prompt blocked by API. Reason: {block_reason_str}.")
|
|
|
|
|
|
|
|
|
337 |
return f"# Error: Prompt blocked by API. Reason: {block_reason_str}."
|
338 |
|
339 |
llm_output = ""
|
|
|
340 |
if hasattr(response, 'text') and isinstance(response.text, str):
|
341 |
llm_output = response.text
|
342 |
elif response.candidates:
|
|
|
345 |
llm_output = "".join(part.text for part in candidate.content.parts if hasattr(part, 'text'))
|
346 |
|
347 |
if not llm_output and candidate.finish_reason:
|
348 |
+
# ... finish reason handling ...
|
349 |
finish_reason_val = candidate.finish_reason
|
350 |
+
finish_reason_str = str(finish_reason_val.name if hasattr(finish_reason_val, 'name') else finish_reason_val)
|
351 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
352 |
if finish_reason_str == "SAFETY": # or candidate.finish_reason == genai_types.FinishReason.SAFETY:
|
353 |
+
# ... safety message handling ...
|
354 |
+
logging.warning(f"Content generation stopped due to safety. Finish reason: {finish_reason_str}.")
|
355 |
+
return f"# Error: Content generation stopped by API due to safety. Finish Reason: {finish_reason_str}."
|
|
|
|
|
|
|
|
|
|
|
356 |
|
357 |
logging.warning(f"Empty response from LLM. Finish reason: {finish_reason_str}.")
|
358 |
return f"# Error: LLM returned an empty response. Finish reason: {finish_reason_str}."
|
|
|
360 |
logging.error(f"Unexpected API response structure: {str(response)[:500]}")
|
361 |
return f"# Error: Unexpected API response structure: {str(response)[:200]}"
|
362 |
|
|
|
363 |
return llm_output
|
364 |
|
365 |
+
except genai_types.BlockedPromptException as bpe:
|
366 |
+
logging.error(f"Prompt blocked (BlockedPromptException): {bpe}", exc_info=True)
|
367 |
+
return f"# Error: Prompt blocked. Details: {bpe}"
|
368 |
+
except genai_types.StopCandidateException as sce:
|
369 |
+
logging.error(f"Candidate stopped (StopCandidateException): {sce}", exc_info=True)
|
370 |
+
return f"# Error: Content generation stopped. Details: {sce}"
|
371 |
except Exception as e:
|
372 |
+
logging.error(f"Error calling Gemini API via Client: {e}", exc_info=True)
|
373 |
return f"# Error during API call: {type(e).__name__} - {str(e)[:100]}."
|
374 |
|
375 |
|
|
|
378 |
|
379 |
if self.force_sandbox:
|
380 |
code_to_execute = ""
|
|
|
381 |
if "```python" in llm_response_text:
|
382 |
try:
|
|
|
383 |
code_block_match = llm_response_text.split("```python\n", 1)
|
384 |
+
if len(code_block_match) > 1: code_to_execute = code_block_match[1].split("\n```", 1)[0]
|
385 |
+
else:
|
|
|
386 |
code_block_match = llm_response_text.split("```python", 1)
|
387 |
if len(code_block_match) > 1:
|
388 |
code_to_execute = code_block_match[1].split("```", 1)[0]
|
389 |
+
if code_to_execute.startswith("\n"): code_to_execute = code_to_execute[1:]
|
390 |
+
except IndexError: code_to_execute = ""
|
|
|
|
|
|
|
391 |
|
392 |
if llm_response_text.startswith("# Error:") or not code_to_execute.strip():
|
393 |
+
logging.warning(f"LLM error or no code: {llm_response_text[:200]}")
|
|
|
|
|
394 |
if not code_to_execute.strip() and not llm_response_text.startswith("# Error:"):
|
395 |
+
if "```" not in llm_response_text and len(llm_response_text.strip()) > 0:
|
396 |
+
logging.info(f"LLM text output in sandbox mode: {llm_response_text[:200]}")
|
397 |
+
return llm_response_text
|
398 |
+
|
399 |
+
logging.info(f"\n--- Code to Execute: ---\n{code_to_execute}\n----------------------\n")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
400 |
from io import StringIO
|
401 |
import sys
|
402 |
+
old_stdout, sys.stdout = sys.stdout, StringIO()
|
|
|
|
|
|
|
403 |
exec_globals = {'pd': pd, 'np': np}
|
404 |
if dataframes_dict:
|
405 |
for name, df_instance in dataframes_dict.items():
|
406 |
+
if isinstance(df_instance, pd.DataFrame): exec_globals[f"df_{name}"] = df_instance
|
407 |
+
else: logging.warning(f"Item '{name}' not a DataFrame.")
|
|
|
|
|
|
|
408 |
try:
|
409 |
+
exec(code_to_execute, exec_globals, {})
|
410 |
+
final_output_str = sys.stdout.getvalue()
|
|
|
411 |
if not final_output_str.strip():
|
412 |
+
if not any(ln.strip() and not ln.strip().startswith("#") for ln in code_to_execute.splitlines()):
|
413 |
+
return "# LLM generated only comments or empty code. No output."
|
414 |
+
return "# Code executed, but no print() output. Ensure print() for results."
|
|
|
|
|
415 |
return final_output_str
|
416 |
except Exception as e:
|
417 |
+
logging.error(f"Sandbox Exec Error: {e}\nCode:\n{code_to_execute}", exc_info=True)
|
418 |
+
indented_code = textwrap.indent(code_to_execute, '# ')
|
419 |
+
return f"# Sandbox Exec Error: {type(e).__name__}: {e}\n# Code:\n{indented_code}"
|
420 |
+
finally: sys.stdout = old_stdout
|
421 |
+
else: return llm_response_text
|
|
|
|
|
|
|
422 |
|
423 |
# --- Employer Branding Agent ---
|
424 |
class EmployerBrandingAgent:
|
|
|
430 |
embedding_model_name: str,
|
431 |
data_privacy=True, force_sandbox=True):
|
432 |
|
433 |
+
self.pandas_llm = PandasLLM(llm_model_name, generation_config_dict, safety_settings_list, data_privacy, force_sandbox)
|
|
|
|
|
|
|
|
|
|
|
|
|
434 |
self.rag_system = AdvancedRAGSystem(rag_documents_df, embedding_model_name)
|
435 |
self.all_dataframes = all_dataframes if all_dataframes else {}
|
436 |
self.schemas_representation = get_all_schemas_representation(self.all_dataframes)
|
437 |
self.chat_history = []
|
438 |
+
logging.info("EmployerBrandingAgent Initialized (using Client API approach).")
|
439 |
|
440 |
def _build_prompt(self, user_query: str, role="Employer Branding Analyst & Strategist", task_decomposition_hint=None, cot_hint=True) -> str:
|
441 |
+
prompt = f"You are a highly skilled '{role}'. Your goal is to provide actionable employer branding insights by analyzing Pandas DataFrames and RAG documents.\n"
|
442 |
+
if self.pandas_llm.data_privacy: prompt += "IMPORTANT: Adhere to data privacy. Summarize/aggregate PII.\n"
|
|
|
|
|
|
|
|
|
|
|
443 |
|
444 |
if self.pandas_llm.force_sandbox:
|
445 |
prompt += "\n--- TASK: PYTHON CODE GENERATION FOR INSIGHTS ---\n"
|
446 |
+
prompt += "GENERATE PYTHON CODE using Pandas. The code's `print()` statements MUST output final textual insights/answers.\n"
|
447 |
+
prompt += "Output ONLY the Python code block (```python ... ```).\n"
|
448 |
+
prompt += "Access DataFrames as 'df_name' (e.g., `df_follower_stats`).\n"
|
|
|
|
|
449 |
prompt += "\n--- CRITICAL INSTRUCTIONS FOR PYTHON CODE OUTPUT ---\n"
|
450 |
+
prompt += "1. **Print Insights, Not Just Data:** `print()` clear, actionable insights. NOT raw DataFrames unless specifically asked for a table.\n"
|
451 |
+
prompt += " Good: `print(f'Insight: Theme {top_theme} has {engagement_increase}% higher engagement.')`\n"
|
452 |
+
prompt += " Avoid: `print(df_result)` (for insight queries).\n"
|
453 |
+
prompt += "2. **Synthesize with RAG:** Weave RAG takeaways into printed insights. Ex: `print(f'Data shows X. RAG says Y. Recommend Z.')`\n"
|
454 |
+
prompt += "3. **Comments & Clarity:** Write clean, commented code.\n"
|
455 |
+
prompt += "4. **Handle Issues in Code:** If ambiguous, `print()` a question. If data unavailable, `print()` explanation. For non-analytical queries, `print()` polite reply.\n"
|
456 |
+
prompt += "5. **Function Usage:** Call functions and `print()` their (insightful) results.\n"
|
457 |
+
else: # Not force_sandbox
|
|
|
|
|
|
|
|
|
|
|
458 |
prompt += "\n--- TASK: DIRECT TEXTUAL INSIGHT GENERATION ---\n"
|
459 |
+
prompt += "Analyze data and RAG, then provide a comprehensive textual answer with insights. Explain step-by-step.\n"
|
460 |
|
461 |
prompt += "\n--- AVAILABLE DATA AND SCHEMAS ---\n"
|
462 |
+
prompt += self.schemas_representation if self.schemas_representation.strip() != "No DataFrames provided." else "No DataFrames loaded.\n"
|
|
|
|
|
|
|
463 |
|
464 |
rag_context = self.rag_system.retrieve_relevant_info(user_query)
|
|
|
465 |
meaningful_rag_keywords = ["Error", "No valid", "No relevant", "Cannot retrieve", "not available", "not generated"]
|
466 |
is_meaningful_rag = bool(rag_context.strip()) and not any(keyword in rag_context for keyword in meaningful_rag_keywords)
|
467 |
+
if is_meaningful_rag: prompt += f"\n--- RAG CONTEXT ---\n{rag_context}\n"
|
468 |
+
else: prompt += "\n--- RAG CONTEXT ---\nNo specific RAG context found or RAG error.\n"
|
|
|
|
|
|
|
|
|
469 |
|
470 |
prompt += f"\n--- USER QUERY ---\n{user_query}\n"
|
471 |
+
if task_decomposition_hint: prompt += f"\n--- GUIDANCE ---\n{task_decomposition_hint}\n"
|
|
|
|
|
472 |
|
473 |
if cot_hint:
|
474 |
if self.pandas_llm.force_sandbox:
|
475 |
+
prompt += "\n--- PYTHON CODE GENERATION THOUGHT PROCESS ---\n"
|
476 |
+
prompt += "1. Goal? 2. Data sources (DFs, RAG)? 3. Analysis plan (comments)? 4. Write Python code. 5. CRITICAL: Formulate & `print()` textual insights. 6. Review. 7. Output ONLY ```python ... ```.\n"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
477 |
else: # Not force_sandbox
|
478 |
+
prompt += "\n--- TEXTUAL RESPONSE THOUGHT PROCESS ---\n"
|
479 |
+
prompt += "1. Goal? 2. Data sources? 3. Formulate insights (data + RAG). 4. Structure: explanation, then insights.\n"
|
|
|
|
|
|
|
|
|
480 |
return prompt
|
481 |
|
482 |
async def process_query(self, user_query: str, role="Employer Branding Analyst & Strategist", task_decomposition_hint=None, cot_hint=True) -> str:
|
483 |
+
current_turn_history_for_llm = self.chat_history[:]
|
484 |
+
self.chat_history.append({"role": "user", "parts": [{"text": user_query}]})
|
|
|
|
|
|
|
|
|
|
|
485 |
full_prompt = self._build_prompt(user_query, role, task_decomposition_hint, cot_hint)
|
486 |
+
logging.info(f"Built prompt for query: {user_query[:100]}...")
|
|
|
|
|
|
|
|
|
|
|
|
|
487 |
response_text = await self.pandas_llm.query(full_prompt, self.all_dataframes, history=current_turn_history_for_llm)
|
488 |
+
self.chat_history.append({"role": "model", "parts": [{"text": response_text}]})
|
489 |
+
MAX_HISTORY_TURNS = 5
|
|
|
|
|
490 |
if len(self.chat_history) > MAX_HISTORY_TURNS * 2:
|
|
|
491 |
self.chat_history = self.chat_history[-(MAX_HISTORY_TURNS * 2):]
|
492 |
+
logging.info(f"Chat history truncated.")
|
|
|
493 |
return response_text
|
494 |
|
495 |
def update_dataframes(self, new_dataframes: dict):
|
496 |
self.all_dataframes = new_dataframes if new_dataframes else {}
|
497 |
self.schemas_representation = get_all_schemas_representation(self.all_dataframes)
|
498 |
+
logging.info(f"Agent DataFrames updated. Schemas: {self.schemas_representation[:100]}...")
|
|
|
|
|
499 |
|
500 |
+
def clear_chat_history(self): self.chat_history = []; logging.info("Agent chat history cleared.")
|
|
|
|
|
501 |
|
502 |
+
# --- Example Usage (Conceptual) ---
|
503 |
async def main_test():
|
504 |
logging.info("Starting main_test for EmployerBrandingAgent...")
|
505 |
+
df_follower_stats = pd.DataFrame({'date': pd.to_datetime(['2023-01-01']), 'country': ['USA'], 'new_followers': [10]})
|
506 |
+
df_posts = pd.DataFrame({'post_id': [1], 'theme': ['Culture'], 'engagement_rate': [0.05]})
|
507 |
+
test_dataframes = {"follower_stats": df_follower_stats, "posts": df_posts}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
508 |
|
509 |
+
if not GEMINI_API_KEY: logging.warning("GEMINI_API_KEY not set. Testing with dummy functionality.")
|
510 |
+
agent = EmployerBrandingAgent(LLM_MODEL_NAME, GENERATION_CONFIG_PARAMS, DEFAULT_SAFETY_SETTINGS, test_dataframes, df_rag_documents, GEMINI_EMBEDDING_MODEL_NAME, force_sandbox=True)
|
511 |
+
|
512 |
+
queries = ["Which post theme has the highest average engagement rate? Provide an insight.", "Hello!"]
|
513 |
for query in queries:
|
514 |
+
logging.info(f"\n\n--- Query: {query} ---")
|
515 |
response = await agent.process_query(user_query=query)
|
516 |
+
logging.info(f"--- Response for '{query}': ---\n{response}\n---------------------------\n")
|
517 |
+
if GEMINI_API_KEY: await asyncio.sleep(1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
518 |
|
519 |
if __name__ == "__main__":
|
520 |
+
if GEMINI_API_KEY:
|
521 |
+
try: asyncio.run(main_test())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
522 |
except RuntimeError as e:
|
523 |
+
if "asyncio.run() cannot be called from a running event loop" in str(e): print("Skip asyncio.run in existing loop.")
|
524 |
+
else: raise
|
525 |
+
else: print("GEMINI_API_KEY not set. Skipping main_test().")
|
|
|
|
|
|