GuglielmoTor committed
Commit 911f78e · verified · 1 Parent(s): 9ce5589

Update eb_agent_module.py

Files changed (1): eb_agent_module.py (+30 -16)
eb_agent_module.py CHANGED
@@ -295,13 +295,16 @@ class PandasLLM:
             return "# Error: Gemini model service not available for API call."
 
         gemini_history = []
-        if history:
+        if history: # history is agent.chat_history, which will now have 'content'
             for entry in history:
-                role = "model" if entry.get("role") == "assistant" else entry.get("role", "user")
-                gemini_history.append({"role": role, "parts": [{"text": entry.get("content", "")}]})
+                # Standardize role for API: 'model' for assistant responses
+                role_for_api = "model" if entry.get("role") == "assistant" else entry.get("role", "user")
+                # Get text from 'content' key
+                text_content = entry.get("content", "")
+                gemini_history.append({"role": role_for_api, "parts": [{"text": text_content}]})
 
-        current_content = [{"role": "user", "parts": [{"text": prompt_text}]}]
-        contents_for_api = gemini_history + current_content
+        current_prompt_content = [{"role": "user", "parts": [{"text": prompt_text}]}]
+        contents_for_api = gemini_history + current_prompt_content
 
         # Prepare model ID (e.g., "models/gemini-2.0-flash")
         model_id_for_api = self.llm_model_name
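
The role/content translation above is easy to exercise in isolation. A minimal sketch, assuming only the dict shapes shown in the diff (the helper name to_gemini_contents is hypothetical; in the commit this logic lives inline in PandasLLM.query):

def to_gemini_contents(history, prompt_text):
    # Agent history stores {"role": "user"/"assistant", "content": "..."};
    # the Gemini API expects {"role": "user"/"model", "parts": [{"text": "..."}]}.
    gemini_history = []
    for entry in history or []:
        role_for_api = "model" if entry.get("role") == "assistant" else entry.get("role", "user")
        gemini_history.append({"role": role_for_api, "parts": [{"text": entry.get("content", "")}]})
    # The current prompt is appended as the final user turn
    return gemini_history + [{"role": "user", "parts": [{"text": prompt_text}]}]

history = [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello!"}]
print(to_gemini_contents(history, "What changed?"))
# -> the 'assistant' entry is sent with role 'model'; the new prompt is the last element
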
@@ -327,10 +330,8 @@ class PandasLLM:
                 safety_settings=self.safety_settings_list
             )
 
-            # ... (Response parsing logic remains largely the same as before) ...
             if hasattr(response, 'prompt_feedback') and response.prompt_feedback and \
                hasattr(response.prompt_feedback, 'block_reason') and response.prompt_feedback.block_reason:
-                # ... block reason handling ...
                 block_reason_val = response.prompt_feedback.block_reason
                 block_reason_str = str(block_reason_val.name if hasattr(block_reason_val, 'name') else block_reason_val)
                 logging.warning(f"Prompt blocked by API. Reason: {block_reason_str}.")
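
The str(val.name if hasattr(val, 'name') else val) pattern used for block_reason tolerates both enum members and raw values. A small illustration with a stand-in enum (BlockReason here is hypothetical; the real type comes from the Gemini SDK):

import enum

class BlockReason(enum.Enum):  # stand-in for the SDK's enum type
    SAFETY = 1

def reason_str(val):
    # Enum members expose .name; ints or strings fall back to str()
    return str(val.name if hasattr(val, 'name') else val)

print(reason_str(BlockReason.SAFETY))  # SAFETY
print(reason_str(2))                   # 2
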
@@ -345,14 +346,18 @@ class PandasLLM:
                 llm_output = "".join(part.text for part in candidate.content.parts if hasattr(part, 'text'))
 
                 if not llm_output and candidate.finish_reason:
-                    # ... finish reason handling ...
                     finish_reason_val = candidate.finish_reason
                     finish_reason_str = str(finish_reason_val.name if hasattr(finish_reason_val, 'name') else finish_reason_val)
 
-                    if finish_reason_str == "SAFETY": # or candidate.finish_reason == genai_types.FinishReason.SAFETY:
-                        # ... safety message handling ...
-                        logging.warning(f"Content generation stopped due to safety. Finish reason: {finish_reason_str}.")
-                        return f"# Error: Content generation stopped by API due to safety. Finish Reason: {finish_reason_str}."
+                    if finish_reason_str == "SAFETY":
+                        safety_messages = []
+                        if hasattr(candidate, 'safety_ratings') and candidate.safety_ratings:
+                            for rating in candidate.safety_ratings:
+                                cat_name = rating.category.name if hasattr(rating.category, 'name') else str(rating.category)
+                                prob_name = rating.probability.name if hasattr(rating.probability, 'name') else str(rating.probability)
+                                safety_messages.append(f"Category: {cat_name}, Probability: {prob_name}")
+                        logging.warning(f"Content generation stopped due to safety. Finish reason: {finish_reason_str}. Details: {'; '.join(safety_messages)}")
+                        return f"# Error: Content generation stopped by API due to safety. Finish Reason: {finish_reason_str}. Details: {'; '.join(safety_messages)}"
 
                     logging.warning(f"Empty response from LLM. Finish reason: {finish_reason_str}.")
                     return f"# Error: LLM returned an empty response. Finish reason: {finish_reason_str}."
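
The new SAFETY branch flattens each entry of candidate.safety_ratings into a "Category: ..., Probability: ..." string. A self-contained sketch with stand-in rating objects (real ratings carry SDK enum members on .category and .probability):

from types import SimpleNamespace

# Hypothetical stand-ins for the SDK's safety rating objects
ratings = [
    SimpleNamespace(category=SimpleNamespace(name="HARM_CATEGORY_HARASSMENT"),
                    probability=SimpleNamespace(name="MEDIUM")),
    SimpleNamespace(category="HARM_CATEGORY_HATE_SPEECH",  # plain string also handled
                    probability="LOW"),
]

safety_messages = []
for rating in ratings:
    cat_name = rating.category.name if hasattr(rating.category, 'name') else str(rating.category)
    prob_name = rating.probability.name if hasattr(rating.probability, 'name') else str(rating.probability)
    safety_messages.append(f"Category: {cat_name}, Probability: {prob_name}")

print("; ".join(safety_messages))
# Category: HARM_CATEGORY_HARASSMENT, Probability: MEDIUM; Category: HARM_CATEGORY_HATE_SPEECH, Probability: LOW
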
@@ -434,7 +439,7 @@ class EmployerBrandingAgent:
         self.rag_system = AdvancedRAGSystem(rag_documents_df, embedding_model_name)
         self.all_dataframes = all_dataframes if all_dataframes else {}
         self.schemas_representation = get_all_schemas_representation(self.all_dataframes)
-        self.chat_history = []
+        self.chat_history = [] # This will store entries like {"role": "user/assistant", "content": "text"}
         logging.info("EmployerBrandingAgent Initialized (using Client API approach).")
 
     def _build_prompt(self, user_query: str, role="Employer Branding Analyst & Strategist", task_decomposition_hint=None, cot_hint=True) -> str:
@@ -480,12 +485,21 @@ class EmployerBrandingAgent:
         return prompt
 
     async def process_query(self, user_query: str, role="Employer Branding Analyst & Strategist", task_decomposition_hint=None, cot_hint=True) -> str:
-        current_turn_history_for_llm = self.chat_history[:]
-        self.chat_history.append({"role": "user", "parts": [{"text": user_query}]})
+        current_turn_history_for_llm = self.chat_history[:] # History before this turn
+
+        # Append new user message to chat_history using 'content' key
+        self.chat_history.append({"role": "user", "content": user_query})
+
         full_prompt = self._build_prompt(user_query, role, task_decomposition_hint, cot_hint)
         logging.info(f"Built prompt for query: {user_query[:100]}...")
+
+        # Pass history (which now uses 'content') to pandas_llm.query
         response_text = await self.pandas_llm.query(full_prompt, self.all_dataframes, history=current_turn_history_for_llm)
+
+        # Append new assistant message to chat_history using 'content' key
+        # Standardize role to 'assistant' for model responses in history
+        self.chat_history.append({"role": "assistant", "content": response_text})
+
         MAX_HISTORY_TURNS = 5
         if len(self.chat_history) > MAX_HISTORY_TURNS * 2:
             self.chat_history = self.chat_history[-(MAX_HISTORY_TURNS * 2):]
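
Note that current_turn_history_for_llm is snapshotted before the new user message is appended, so the model never sees the current query duplicated in its history. The trailing window then keeps only the last MAX_HISTORY_TURNS user/assistant pairs; the slice logic in isolation (a sketch, independent of the agent):

MAX_HISTORY_TURNS = 5

# 7 full turns = 14 messages, alternating user/assistant
chat_history = [{"role": "user" if i % 2 == 0 else "assistant", "content": f"msg {i}"}
                for i in range(14)]

if len(chat_history) > MAX_HISTORY_TURNS * 2:
    chat_history = chat_history[-(MAX_HISTORY_TURNS * 2):]

print(len(chat_history))           # 10 -> last 5 pairs kept
print(chat_history[0]["content"])  # msg 4 (the two oldest turns were dropped)
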
 