Spaces:
Running
Running
Update eb_agent_module.py
Browse files- eb_agent_module.py +30 -16
eb_agent_module.py
CHANGED
@@ -295,13 +295,16 @@ class PandasLLM:
|
|
295 |
return "# Error: Gemini model service not available for API call."
|
296 |
|
297 |
gemini_history = []
|
298 |
-
if history:
|
299 |
for entry in history:
|
300 |
-
role
|
301 |
-
|
|
|
|
|
|
|
302 |
|
303 |
-
|
304 |
-
contents_for_api = gemini_history +
|
305 |
|
306 |
# Prepare model ID (e.g., "models/gemini-2.0-flash")
|
307 |
model_id_for_api = self.llm_model_name
|
@@ -327,10 +330,8 @@ class PandasLLM:
|
|
327 |
safety_settings=self.safety_settings_list
|
328 |
)
|
329 |
|
330 |
-
# ... (Response parsing logic remains largely the same as before) ...
|
331 |
if hasattr(response, 'prompt_feedback') and response.prompt_feedback and \
|
332 |
hasattr(response.prompt_feedback, 'block_reason') and response.prompt_feedback.block_reason:
|
333 |
-
# ... block reason handling ...
|
334 |
block_reason_val = response.prompt_feedback.block_reason
|
335 |
block_reason_str = str(block_reason_val.name if hasattr(block_reason_val, 'name') else block_reason_val)
|
336 |
logging.warning(f"Prompt blocked by API. Reason: {block_reason_str}.")
|
@@ -345,14 +346,18 @@ class PandasLLM:
|
|
345 |
llm_output = "".join(part.text for part in candidate.content.parts if hasattr(part, 'text'))
|
346 |
|
347 |
if not llm_output and candidate.finish_reason:
|
348 |
-
# ... finish reason handling ...
|
349 |
finish_reason_val = candidate.finish_reason
|
350 |
finish_reason_str = str(finish_reason_val.name if hasattr(finish_reason_val, 'name') else finish_reason_val)
|
351 |
|
352 |
-
if finish_reason_str == "SAFETY":
|
353 |
-
|
354 |
-
|
355 |
-
|
|
|
|
|
|
|
|
|
|
|
356 |
|
357 |
logging.warning(f"Empty response from LLM. Finish reason: {finish_reason_str}.")
|
358 |
return f"# Error: LLM returned an empty response. Finish reason: {finish_reason_str}."
|
@@ -434,7 +439,7 @@ class EmployerBrandingAgent:
|
|
434 |
self.rag_system = AdvancedRAGSystem(rag_documents_df, embedding_model_name)
|
435 |
self.all_dataframes = all_dataframes if all_dataframes else {}
|
436 |
self.schemas_representation = get_all_schemas_representation(self.all_dataframes)
|
437 |
-
self.chat_history = []
|
438 |
logging.info("EmployerBrandingAgent Initialized (using Client API approach).")
|
439 |
|
440 |
def _build_prompt(self, user_query: str, role="Employer Branding Analyst & Strategist", task_decomposition_hint=None, cot_hint=True) -> str:
|
@@ -480,12 +485,21 @@ class EmployerBrandingAgent:
|
|
480 |
return prompt
|
481 |
|
482 |
async def process_query(self, user_query: str, role="Employer Branding Analyst & Strategist", task_decomposition_hint=None, cot_hint=True) -> str:
|
483 |
-
current_turn_history_for_llm = self.chat_history[:]
|
484 |
-
|
|
|
|
|
|
|
485 |
full_prompt = self._build_prompt(user_query, role, task_decomposition_hint, cot_hint)
|
486 |
logging.info(f"Built prompt for query: {user_query[:100]}...")
|
|
|
|
|
487 |
response_text = await self.pandas_llm.query(full_prompt, self.all_dataframes, history=current_turn_history_for_llm)
|
488 |
-
|
|
|
|
|
|
|
|
|
489 |
MAX_HISTORY_TURNS = 5
|
490 |
if len(self.chat_history) > MAX_HISTORY_TURNS * 2:
|
491 |
self.chat_history = self.chat_history[-(MAX_HISTORY_TURNS * 2):]
|
|
|
295 |
return "# Error: Gemini model service not available for API call."
|
296 |
|
297 |
gemini_history = []
|
298 |
+
if history: # history is agent.chat_history, which will now have 'content'
|
299 |
for entry in history:
|
300 |
+
# Standardize role for API: 'model' for assistant responses
|
301 |
+
role_for_api = "model" if entry.get("role") == "assistant" else entry.get("role", "user")
|
302 |
+
# Get text from 'content' key
|
303 |
+
text_content = entry.get("content", "")
|
304 |
+
gemini_history.append({"role": role_for_api, "parts": [{"text": text_content}]})
|
305 |
|
306 |
+
current_prompt_content = [{"role": "user", "parts": [{"text": prompt_text}]}]
|
307 |
+
contents_for_api = gemini_history + current_prompt_content
|
308 |
|
309 |
# Prepare model ID (e.g., "models/gemini-2.0-flash")
|
310 |
model_id_for_api = self.llm_model_name
|
|
|
330 |
safety_settings=self.safety_settings_list
|
331 |
)
|
332 |
|
|
|
333 |
if hasattr(response, 'prompt_feedback') and response.prompt_feedback and \
|
334 |
hasattr(response.prompt_feedback, 'block_reason') and response.prompt_feedback.block_reason:
|
|
|
335 |
block_reason_val = response.prompt_feedback.block_reason
|
336 |
block_reason_str = str(block_reason_val.name if hasattr(block_reason_val, 'name') else block_reason_val)
|
337 |
logging.warning(f"Prompt blocked by API. Reason: {block_reason_str}.")
|
|
|
346 |
llm_output = "".join(part.text for part in candidate.content.parts if hasattr(part, 'text'))
|
347 |
|
348 |
if not llm_output and candidate.finish_reason:
|
|
|
349 |
finish_reason_val = candidate.finish_reason
|
350 |
finish_reason_str = str(finish_reason_val.name if hasattr(finish_reason_val, 'name') else finish_reason_val)
|
351 |
|
352 |
+
if finish_reason_str == "SAFETY":
|
353 |
+
safety_messages = []
|
354 |
+
if hasattr(candidate, 'safety_ratings') and candidate.safety_ratings:
|
355 |
+
for rating in candidate.safety_ratings:
|
356 |
+
cat_name = rating.category.name if hasattr(rating.category, 'name') else str(rating.category)
|
357 |
+
prob_name = rating.probability.name if hasattr(rating.probability, 'name') else str(rating.probability)
|
358 |
+
safety_messages.append(f"Category: {cat_name}, Probability: {prob_name}")
|
359 |
+
logging.warning(f"Content generation stopped due to safety. Finish reason: {finish_reason_str}. Details: {'; '.join(safety_messages)}")
|
360 |
+
return f"# Error: Content generation stopped by API due to safety. Finish Reason: {finish_reason_str}. Details: {'; '.join(safety_messages)}"
|
361 |
|
362 |
logging.warning(f"Empty response from LLM. Finish reason: {finish_reason_str}.")
|
363 |
return f"# Error: LLM returned an empty response. Finish reason: {finish_reason_str}."
|
|
|
439 |
self.rag_system = AdvancedRAGSystem(rag_documents_df, embedding_model_name)
|
440 |
self.all_dataframes = all_dataframes if all_dataframes else {}
|
441 |
self.schemas_representation = get_all_schemas_representation(self.all_dataframes)
|
442 |
+
self.chat_history = [] # This will store entries like {"role": "user/assistant", "content": "text"}
|
443 |
logging.info("EmployerBrandingAgent Initialized (using Client API approach).")
|
444 |
|
445 |
def _build_prompt(self, user_query: str, role="Employer Branding Analyst & Strategist", task_decomposition_hint=None, cot_hint=True) -> str:
|
|
|
485 |
return prompt
|
486 |
|
487 |
async def process_query(self, user_query: str, role="Employer Branding Analyst & Strategist", task_decomposition_hint=None, cot_hint=True) -> str:
|
488 |
+
current_turn_history_for_llm = self.chat_history[:] # History before this turn
|
489 |
+
|
490 |
+
# Append new user message to chat_history using 'content' key
|
491 |
+
self.chat_history.append({"role": "user", "content": user_query})
|
492 |
+
|
493 |
full_prompt = self._build_prompt(user_query, role, task_decomposition_hint, cot_hint)
|
494 |
logging.info(f"Built prompt for query: {user_query[:100]}...")
|
495 |
+
|
496 |
+
# Pass history (which now uses 'content') to pandas_llm.query
|
497 |
response_text = await self.pandas_llm.query(full_prompt, self.all_dataframes, history=current_turn_history_for_llm)
|
498 |
+
|
499 |
+
# Append new assistant message to chat_history using 'content' key
|
500 |
+
# Standardize role to 'assistant' for model responses in history
|
501 |
+
self.chat_history.append({"role": "assistant", "content": response_text})
|
502 |
+
|
503 |
MAX_HISTORY_TURNS = 5
|
504 |
if len(self.chat_history) > MAX_HISTORY_TURNS * 2:
|
505 |
self.chat_history = self.chat_history[-(MAX_HISTORY_TURNS * 2):]
|