GuglielmoTor commited on
Commit
5cec9b7
·
verified ·
1 Parent(s): 01ce8fd

Update eb_agent_module.py

Browse files
Files changed (1) hide show
  1. eb_agent_module.py +33 -9
eb_agent_module.py CHANGED
@@ -433,15 +433,23 @@ class PandasLLM:
433
 
434
  # --- Employer Branding Agent ---
435
  class EmployerBrandingAgent:
436
- def __init__(self, llm_model_name: str, gc_dict: dict, ss_list: list, all_dfs: dict, rag_df: pd.DataFrame, emb_m_name: str, dp=True, fs=True):
437
- self.pandas_llm = PandasLLM(llm_model_name, gc_dict, ss_list, dp, fs)
438
- self.rag_system = AdvancedRAGSystem(rag_df, emb_m_name)
439
- self.all_dataframes = all_dfs if all_dfs else {}
 
 
 
 
 
 
 
 
440
  self.schemas_representation = get_all_schemas_representation(self.all_dataframes)
441
  self.chat_history = []
442
  logging.info(f"EmployerBrandingAgent Initialized (Real GenAI Loaded: {_REAL_GENAI_LOADED}).")
443
 
444
- def _build_prompt(self, user_query: str, role="EB Analyst", task_hint=None, cot=True) -> str:
445
  prompt = f"You are '{role}'. Goal: insights from DataFrames & RAG.\n"
446
  if self.pandas_llm.data_privacy: prompt += "PRIVACY: Summarize/aggregate PII.\n"
447
  if self.pandas_llm.force_sandbox:
@@ -461,7 +469,7 @@ class EmployerBrandingAgent:
461
  else: prompt += "--- TEXT RESPONSE THOUGHT PROCESS ---\n1.Goal? 2.Data? 3.Insights (DFs+RAG). 4.Structure response.\n"
462
  return prompt
463
 
464
- async def process_query(self, user_query: str, role="EB Analyst", task_hint=None, cot=True) -> str:
465
  hist_for_llm = self.chat_history[:]
466
  self.chat_history.append({"role": "user", "content": user_query})
467
  prompt = self._build_prompt(user_query, role, task_hint, cot)
@@ -471,13 +479,29 @@ class EmployerBrandingAgent:
471
  if len(self.chat_history) > 10: self.chat_history = self.chat_history[-10:]; logging.info("Chat history truncated.")
472
  return response
473
 
474
- def update_dataframes(self, new_dfs: dict): self.all_dataframes = new_dfs if new_dfs else {}; self.schemas_representation = get_all_schemas_representation(self.all_dataframes); logging.info("Agent DFs updated.")
475
- def clear_chat_history(self): self.chat_history = []; logging.info("Agent chat history cleared.")
 
 
 
 
 
 
476
 
477
  # --- Example Usage (Conceptual) ---
478
  async def main_test():
479
  logging.info(f"Test (Real GenAI: {_REAL_GENAI_LOADED}, API Key: {bool(GEMINI_API_KEY)})")
480
- agent = EmployerBrandingAgent(LLM_MODEL_NAME, GENERATION_CONFIG_PARAMS, DEFAULT_SAFETY_SETTINGS, {}, df_rag_documents, GEMINI_EMBEDDING_MODEL_NAME)
 
 
 
 
 
 
 
 
 
 
481
  for q in ["What are EB best practices?", "Hello Agent!"]:
482
  logging.info(f"\nQuery: {q}")
483
  resp = await agent.process_query(q)
 
433
 
434
  # --- Employer Branding Agent ---
435
  class EmployerBrandingAgent:
436
+ def __init__(self, llm_model_name: str,
437
+ generation_config_dict: dict, # Changed from gc_dict
438
+ safety_settings_list: list, # Changed from ss_list
439
+ all_dataframes: dict, # Changed from all_dfs
440
+ rag_documents_df: pd.DataFrame, # Changed from rag_df
441
+ embedding_model_name: str, # Changed from emb_m_name
442
+ data_privacy=True, # Changed from dp
443
+ force_sandbox=True): # Changed from fs
444
+
445
+ self.pandas_llm = PandasLLM(llm_model_name, generation_config_dict, safety_settings_list, data_privacy, force_sandbox)
446
+ self.rag_system = AdvancedRAGSystem(rag_documents_df, embedding_model_name)
447
+ self.all_dataframes = all_dataframes if all_dataframes else {}
448
  self.schemas_representation = get_all_schemas_representation(self.all_dataframes)
449
  self.chat_history = []
450
  logging.info(f"EmployerBrandingAgent Initialized (Real GenAI Loaded: {_REAL_GENAI_LOADED}).")
451
 
452
+ def _build_prompt(self, user_query: str, role="EB Analyst", task_hint=None, cot=True) -> str: # Simplified role for brevity in example
453
  prompt = f"You are '{role}'. Goal: insights from DataFrames & RAG.\n"
454
  if self.pandas_llm.data_privacy: prompt += "PRIVACY: Summarize/aggregate PII.\n"
455
  if self.pandas_llm.force_sandbox:
 
469
  else: prompt += "--- TEXT RESPONSE THOUGHT PROCESS ---\n1.Goal? 2.Data? 3.Insights (DFs+RAG). 4.Structure response.\n"
470
  return prompt
471
 
472
+ async def process_query(self, user_query: str, role="EB Analyst", task_hint=None, cot=True) -> str: # Simplified role
473
  hist_for_llm = self.chat_history[:]
474
  self.chat_history.append({"role": "user", "content": user_query})
475
  prompt = self._build_prompt(user_query, role, task_hint, cot)
 
479
  if len(self.chat_history) > 10: self.chat_history = self.chat_history[-10:]; logging.info("Chat history truncated.")
480
  return response
481
 
482
+ def update_dataframes(self, new_dataframes: dict): # Changed from new_dfs
483
+ self.all_dataframes = new_dataframes if new_dataframes else {}
484
+ self.schemas_representation = get_all_schemas_representation(self.all_dataframes)
485
+ logging.info("Agent DFs updated.")
486
+
487
+ def clear_chat_history(self):
488
+ self.chat_history = []
489
+ logging.info("Agent chat history cleared.")
490
 
491
  # --- Example Usage (Conceptual) ---
492
  async def main_test():
493
  logging.info(f"Test (Real GenAI: {_REAL_GENAI_LOADED}, API Key: {bool(GEMINI_API_KEY)})")
494
+ # Example: Provide all parameters to EmployerBrandingAgent constructor
495
+ agent = EmployerBrandingAgent(
496
+ llm_model_name=LLM_MODEL_NAME,
497
+ generation_config_dict=GENERATION_CONFIG_PARAMS,
498
+ safety_settings_list=DEFAULT_SAFETY_SETTINGS,
499
+ all_dataframes={}, # Empty dict for example
500
+ rag_documents_df=df_rag_documents,
501
+ embedding_model_name=GEMINI_EMBEDDING_MODEL_NAME,
502
+ data_privacy=True,
503
+ force_sandbox=True
504
+ )
505
  for q in ["What are EB best practices?", "Hello Agent!"]:
506
  logging.info(f"\nQuery: {q}")
507
  resp = await agent.process_query(q)