GuglielmoTor committed
Commit c9f7ea0 · verified · 1 Parent(s): f1d603c

Update eb_agent_module.py

Files changed (1)
  1. eb_agent_module.py +37 -68
eb_agent_module.py CHANGED
@@ -274,21 +274,12 @@ class PandasLLM:
             logging.error("PandasLLM: Gemini model not initialized. Cannot call API.")
             return "# Error: Gemini model not available. Check API key and configuration."
 
-        # Construct content for Gemini API
-        # The new API expects a list of Content objects, or a list of dicts
-        # For chat-like interaction, history should be managed.
-        # For single-turn code generation, a simple user prompt might suffice.
-
-        # For now, let's assume single turn for code generation for simplicity in PandasLLM context
-        # If this were a conversational agent, history would be crucial.
         contents_for_api = [{"role": "user", "parts": [{"text": prompt_text}]}]
-        if history: # If history is provided, prepend it
-            # Ensure history is in the correct format [{'role':'user/model', 'parts':[{'text':...}]}]
-            # This part might need adjustment based on how history is structured by the calling agent
+        if history:
             formatted_history = []
             for entry in history:
-                role = entry.get("role", "user") # Default to user if role not specified
-                if role == "assistant": role = "model" # Gemini uses "model" for assistant
+                role = entry.get("role", "user")
+                if role == "assistant": role = "model"
                 formatted_history.append({"role": role, "parts": [{"text": entry.get("content", "")}]})
             contents_for_api = formatted_history + contents_for_api
 
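The role remap in this hunk exists because Gemini's content format accepts only "user" and "model" roles, while OpenAI-style histories use "assistant". A minimal standalone sketch of the same conversion (the function and variable names here are illustrative, not part of the module):

```python
# Sketch: build a Gemini-style "contents" list from an OpenAI-style history.
def to_gemini_contents(history: list, prompt_text: str) -> list:
    contents = []
    for entry in history:
        role = entry.get("role", "user")
        if role == "assistant":
            role = "model"  # Gemini's name for the assistant side
        contents.append({"role": role, "parts": [{"text": entry.get("content", "")}]})
    # The current user prompt always goes last.
    contents.append({"role": "user", "parts": [{"text": prompt_text}]})
    return contents

print(to_gemini_contents([{"role": "assistant", "content": "hi"}], "follow-up"))
```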
@@ -302,21 +293,18 @@ class PandasLLM:
         logging.info(f"\n--- Calling Gemini API with prompt (first 500 chars of last message): ---\n{contents_for_api[-1]['parts'][0]['text'][:500]}...\n-------------------------------------------------------\n")
 
         try:
-            # Using asyncio.to_thread for the blocking SDK call
             response = await asyncio.to_thread(
                 self.model.generate_content,
-                contents=contents_for_api, # Pass the constructed content
+                contents=contents_for_api,
                 generation_config=gen_config_obj,
-                # stream=False # Ensure non-streaming for this setup
             )
 
             if response.prompt_feedback and response.prompt_feedback.block_reason:
                 reason = response.prompt_feedback.block_reason
-                reason_name = getattr(reason, 'name', str(reason)) # Handle if reason is enum or string
+                reason_name = getattr(reason, 'name', str(reason))
                 logging.warning(f"Gemini API call blocked by prompt feedback: {reason_name}")
                 return f"# Error: Prompt blocked due to content policy: {reason_name}."
 
-            # Try to extract text, accounting for different response structures
             llm_output = ""
             if hasattr(response, 'text') and response.text:
                 llm_output = response.text
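`generate_content` is a synchronous SDK call, so the hunk wraps it in `asyncio.to_thread` to keep the event loop responsive. A self-contained sketch of that pattern, with a stand-in function in place of the real SDK call:

```python
import asyncio
import time

def blocking_sdk_call(payload: str) -> str:
    # Stand-in for a synchronous SDK method such as generate_content.
    time.sleep(0.1)
    return f"response to {payload!r}"

async def main():
    # asyncio.to_thread (Python 3.9+) runs the blocking function in a
    # worker thread, so the event loop can keep serving other tasks.
    result = await asyncio.to_thread(blocking_sdk_call, "hello")
    print(result)

asyncio.run(main())
```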
@@ -325,12 +313,11 @@ class PandasLLM:
                 if candidate.content and candidate.content.parts:
                     llm_output = "".join(part.text for part in candidate.content.parts if hasattr(part, 'text'))
 
-                # Check finish reason if output is empty
                 if not llm_output:
                     finish_reason_val = candidate.finish_reason
-                    finish_reason = getattr(finish_reason_val, 'name', str(finish_reason_val)) # Handle enum or string
+                    finish_reason = getattr(finish_reason_val, 'name', str(finish_reason_val))
                     logging.warning(f"No text content in response candidate. Finish reason: {finish_reason}")
-                    if finish_reason == "SAFETY": # Check against genai_types.FinishReason.SAFETY if available
+                    if finish_reason == "SAFETY":
                         return f"# Error: Response generation stopped due to safety reasons ({finish_reason})."
                     elif finish_reason == "RECITATION":
                         return f"# Error: Response generation stopped due to recitation policy ({finish_reason})."
@@ -342,12 +329,11 @@ class PandasLLM:
             logging.info(f"--- Gemini API Response (first 300 chars): ---\n{llm_output[:300]}...\n--------------------------------------------------\n")
             return llm_output
 
-        except AttributeError as ae: # Catch issues with dummy genai objects if API key missing
+        except AttributeError as ae:
             logging.error(f"AttributeError during Gemini call (likely due to missing API key/dummy objects): {ae}", exc_info=True)
             return f"# Error (Attribute): {type(ae).__name__} - {ae}. Check if GEMINI_API_KEY is set and google.genai library is correctly installed."
         except Exception as e:
             logging.error(f"Error calling Gemini API: {e}", exc_info=True)
-            # More specific error messages
             if "API_KEY_INVALID" in str(e) or "API key not valid" in str(e):
                 return "# Error: Gemini API key is not valid. Please check your GEMINI_API_KEY environment variable."
             if "400" in str(e) and "model" in str(e).lower() and ("not found" in str(e).lower() or "does not exist" in str(e).lower()):
@@ -360,59 +346,51 @@ class PandasLLM:
 
 
     async def query(self, prompt_with_query_and_context: str, dataframes_dict: dict, history: list = None) -> str:
-        """
-        Sends a prompt to the LLM and optionally executes the returned Python code in a sandbox.
-        dataframes_dict: Keys are 'base_name' (e.g., 'profiles'), values are pd.DataFrame.
-        In exec, they are available as 'df_base_name'.
-        history: Optional chat history for conversational context.
-        """
         llm_response_text = await self._call_gemini_api_async(prompt_with_query_and_context, history)
 
         if self.force_sandbox:
-            # Attempt to extract Python code block
            code_to_execute = ""
            if "```python" in llm_response_text:
                try:
                    code_to_execute = llm_response_text.split("```python\n", 1)[1].split("\n```", 1)[0]
                except IndexError:
-                    # This might happen if the format is slightly off, e.g. no newline after ```python
                    try:
                        code_to_execute = llm_response_text.split("```python", 1)[1].split("```", 1)[0]
-                        if code_to_execute.startswith("\n"): code_to_execute = code_to_execute[1:] # remove leading newline
-                        if code_to_execute.endswith("\n"): code_to_execute = code_to_execute[:-1] # remove trailing newline
-
+                        if code_to_execute.startswith("\n"): code_to_execute = code_to_execute[1:]
+                        if code_to_execute.endswith("\n"): code_to_execute = code_to_execute[:-1]
                    except IndexError:
-                        code_to_execute = "" # Fallback, code not extracted
+                        code_to_execute = ""
                        logging.warning("Could not extract Python code using primary or secondary split method.")
 
+            llm_response_text_for_sandbox_error = "" # Initialize this variable
            if llm_response_text.startswith("# Error:") or not code_to_execute:
                error_prefix = "LLM did not return a valid Python code block or an error occurred."
                if llm_response_text.startswith("# Error:"): error_prefix = "An error occurred during LLM call."
                elif not code_to_execute: error_prefix = "Could not extract Python code from LLM response."
 
-                # Sanitize llm_response_text before printing to avoid breaking f-string or print
                safe_llm_response = str(llm_response_text).replace("'''", "'").replace('"""', '"')
-                code_for_error_display = f"print(f'''{error_prefix}\\nRaw LLM Response (may be truncated):\\n{safe_llm_response[:1000]}''')"
+                llm_response_text_for_sandbox_error = f"print(f'''{error_prefix}\\nRaw LLM Response (may be truncated):\\n{safe_llm_response[:1000]}''')"
                logging.warning(f"Problem with LLM response for sandbox: {error_prefix}")
-                # Fallback to printing the raw response or error
-                llm_response_text_for_sandbox_error = code_for_error_display
-
-
+
            logging.info(f"\n--- Code to Execute (from LLM, if sandbox): ---\n{code_to_execute}\n------------------------------------------------\n")
 
-            # Define a restricted set of built-ins
-            safe_builtins = {name: obj for name, obj in __builtins__.__dict__.items() if not name.startswith('_')}
-            # More aggressive removal (example, adjust as needed for security)
-            # For a web app, this sandboxing is CRITICAL.
-            # Consider using a dedicated sandboxing library if security is paramount.
+            # --- THIS IS THE CORRECTED SECTION ---
+            # In the exec environment, __builtins__ is a dict.
+            # We iterate over its items directly.
+            safe_builtins = {}
+            if isinstance(__builtins__, dict):
+                safe_builtins = {name: obj for name, obj in __builtins__.items() if not name.startswith('_')}
+            else: # Fallback if __builtins__ is a module (e.g. in standard Python interpreter)
+                safe_builtins = {name: obj for name, obj in __builtins__.__dict__.items() if not name.startswith('_')}
+            # --- END OF CORRECTION ---
+
            unsafe_builtins = ['eval', 'exec', 'open', 'compile', 'input', 'memoryview', 'vars', 'globals', 'locals', '__import__']
            for ub in unsafe_builtins:
                safe_builtins.pop(ub, None)
 
-            # Prepare globals for exec: pandas, numpy, dataframes, and restricted builtins
            exec_globals = {'pd': pd, 'np': np, '__builtins__': safe_builtins}
            for name, df_instance in dataframes_dict.items():
-                exec_globals[f"df_{name}"] = df_instance # e.g. df_follower_stats, df_posts
+                exec_globals[f"df_{name}"] = df_instance
 
            from io import StringIO
            import sys
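The corrected section works around a real asymmetry: `__builtins__` is the `builtins` module in a `__main__` script but a plain dict inside imported modules. One alternative sketch that avoids branching on that shape is to build the allowlist from the `builtins` module directly (names below are illustrative; as the removed comments themselves warned, this is not a hardened sandbox):

```python
import builtins

# Build an allowlist of "safe" builtins from the builtins module itself,
# which has the same shape regardless of how the file is executed.
safe_builtins = {name: getattr(builtins, name)
                 for name in dir(builtins) if not name.startswith("_")}
for blocked in ("eval", "exec", "open", "compile", "input", "__import__"):
    safe_builtins.pop(blocked, None)

exec_globals = {"__builtins__": safe_builtins}
exec("print(sum([1, 2, 3]))", exec_globals, {})   # prints 6
try:
    exec("open('x')", exec_globals, {})
except NameError as e:
    print("blocked:", e)  # 'open' is not defined inside the sandbox
```

Stripping names from `__builtins__` stops casual misuse but is not a security boundary; for untrusted code a dedicated sandboxing layer is still advisable.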
@@ -421,12 +399,12 @@ class PandasLLM:
 
            final_output_str = ""
            try:
-                if code_to_execute: # Only execute if code was extracted
-                    exec(code_to_execute, exec_globals, {}) # Empty locals
+                if code_to_execute:
+                    exec(code_to_execute, exec_globals, {})
                    output_val = captured_output.getvalue()
                    final_output_str = output_val if output_val else "# Code executed successfully, but no explicit print() output was generated by the code."
                    logging.info(f"--- Sandbox Execution Output: ---\n{final_output_str}\n-------------------------\n")
-                else: # No code to execute, use the prepared error message
+                else:
                    exec(llm_response_text_for_sandbox_error, exec_globals, {})
                    final_output_str = captured_output.getvalue()
                    logging.warning(f"--- Sandbox Fallback Output (No Code Executed): ---\n{final_output_str}\n-------------------------\n")
@@ -434,11 +412,11 @@ class PandasLLM:
            except Exception as e:
                error_msg = f"# Error executing LLM-generated code:\n# {type(e).__name__}: {str(e)}\n# --- Code that caused error: ---\n{textwrap.indent(code_to_execute, '# ')}"
                final_output_str = error_msg
-                logging.error(error_msg, exc_info=False) # exc_info=False to avoid huge traceback in Gradio UI
+                logging.error(error_msg, exc_info=False)
            finally:
-                sys.stdout = old_stdout # Reset stdout
+                sys.stdout = old_stdout
            return final_output_str
-        else: # Not force_sandbox, return LLM text directly
+        else:
            return llm_response_text
 
 
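The error report in this hunk leans on `textwrap.indent` to prefix every line of the offending code with `# `, so the whole report stays valid as a Python comment block. A small runnable illustration (the failing snippet is made up):

```python
import textwrap

failing_code = "df = undefined_name\nprint(df)"
try:
    exec(failing_code, {}, {})
except Exception as e:
    # Prefix each offending line so the report reads as commented-out code.
    report = (f"# Error executing LLM-generated code:\n"
              f"# {type(e).__name__}: {e}\n"
              f"# --- Code that caused error: ---\n"
              f"{textwrap.indent(failing_code, '# ')}")
    print(report)
```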
@@ -449,13 +427,12 @@ class EmployerBrandingAgent:
                 data_privacy=True, force_sandbox=True):
        self.pandas_llm = PandasLLM(llm_model_name, generation_config_params, safety_settings, data_privacy, force_sandbox)
        self.rag_system = AdvancedRAGSystem(rag_documents_df, embedding_model_name)
-        self.all_dataframes = all_dataframes # Keys are 'base_name', values are pd.DataFrame
+        self.all_dataframes = all_dataframes
        self.schemas_representation = get_all_schemas_representation(self.all_dataframes)
-        self.chat_history = [] # Stores conversation history for this agent instance
+        self.chat_history = []
        logging.info("EmployerBrandingAgent Initialized.")
 
    def _build_prompt(self, user_query: str, role="Employer Branding Analyst", task_decomposition_hint=None, cot_hint=True) -> str:
-        # Base prompt
        prompt = f"You are a helpful and expert '{role}'. Your primary goal is to assist with analyzing LinkedIn-related data using Pandas DataFrames.\n"
        prompt += "You will be provided with schemas for available Pandas DataFrames and a user query.\n"
 
@@ -471,13 +448,12 @@ class EmployerBrandingAgent:
            prompt += "If the query is ambiguous or requires clarification, ask for it instead of making assumptions. If the query cannot be answered with the given data, state that clearly.\n"
            prompt += "If the query is not about data analysis or code generation (e.g. 'hello', 'how are you?'), respond politely and briefly, do not attempt to generate code.\n"
            prompt += "Structure your code clearly. Add comments (#) to explain each step of your logic.\n"
-        else: # Textual response mode
+        else:
            prompt += "Your task is to analyze the data and provide a comprehensive textual answer to the user query. You can explain your reasoning step-by-step.\n"
 
        prompt += "\n--- AVAILABLE DATA AND SCHEMAS ---\n"
        prompt += self.schemas_representation
 
-        # RAG Context (only add if relevant context is found)
        rag_context = self.rag_system.retrieve_relevant_info(user_query)
        if rag_context and "[RAG Context]" in rag_context and "No specific pre-defined context found" not in rag_context and "No highly relevant passages found" not in rag_context:
            prompt += f"\n--- ADDITIONAL CONTEXT (from internal knowledge base, consider this information) ---\n{rag_context}\n"
@@ -497,7 +473,7 @@ class EmployerBrandingAgent:
            prompt += "5. Ensure output: Use `print()` for all results that should be displayed. For DataFrames, you can print the DataFrame directly, or `df.to_string()` if it's large.\n"
            prompt += "6. Review: Check for correctness, efficiency, and adherence to the prompt (especially the `print()` requirement).\n"
            prompt += "7. Generate ONLY the Python code block starting with ```python and ending with ```. No explanations outside the code block's comments.\n"
-        else: # Textual CoT
+        else:
            prompt += "\n--- INSTRUCTIONS FOR RESPONSE (Chain of Thought) ---\n"
            prompt += "1. Understand the query fully.\n"
            prompt += "2. Identify the relevant data sources (DataFrames and columns).\n"
@@ -512,22 +488,15 @@ class EmployerBrandingAgent:
        logging.info(f"\n=== Processing Query for Role: {role} ===")
        logging.info(f"User Query: {user_query}")
 
-        # Add user query to chat history
        self.chat_history.append({"role": "user", "content": user_query})
 
        full_prompt = self._build_prompt(user_query, role, task_decomposition_hint, cot_hint)
 
-        # Pass relevant parts of chat history to pandas_llm.query if needed for context
-        # For now, PandasLLM's _call_gemini_api_async is set up for single turn for code gen,
-        # but can be adapted if conversational context for code gen becomes important.
-        # The full_prompt itself is rebuilt each time, incorporating the latest user_query.
-        response_text = await self.pandas_llm.query(full_prompt, self.all_dataframes, history=self.chat_history[:-1]) # Pass history excluding current query
+        response_text = await self.pandas_llm.query(full_prompt, self.all_dataframes, history=self.chat_history[:-1])
 
-        # Add assistant response to chat history
        self.chat_history.append({"role": "assistant", "content": response_text})
 
-        # Limit history size to avoid overly long prompts in future turns (e.g., last 10 messages)
-        MAX_HISTORY_TURNS = 5 # 5 pairs of user/assistant messages
+        MAX_HISTORY_TURNS = 5
        if len(self.chat_history) > MAX_HISTORY_TURNS * 2:
            self.chat_history = self.chat_history[-(MAX_HISTORY_TURNS * 2):]
 
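The trimming at the end keeps the prompt from growing without bound: with `MAX_HISTORY_TURNS = 5`, only the last five user/assistant pairs (ten messages) survive. A minimal sketch of the same bound (names are illustrative):

```python
MAX_HISTORY_TURNS = 5  # five user/assistant pairs, i.e. ten messages

def trim_history(chat_history: list) -> list:
    limit = MAX_HISTORY_TURNS * 2
    return chat_history[-limit:] if len(chat_history) > limit else chat_history

history = [{"role": "user", "content": f"q{i}"} for i in range(14)]
print(len(trim_history(history)))  # 10
```

A `collections.deque(maxlen=MAX_HISTORY_TURNS * 2)` would enforce the same bound without explicit slicing.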