Update eb_agent_module.py
eb_agent_module.py (CHANGED: +131 -131)
@@ -425,141 +425,141 @@ class EmployerBrandingAgent:
Context (lines 425-427, unchanged):

    If a query is ambiguous or requires data not present, ask for clarification or state the limitation.
    """).strip()

Removed (old lines 428-562): the previous version of this block. Most of the removed lines are not recoverable from the page rendering; the fragments that do survive (the priming "model" message, the config_dict and safety-settings construction, the client.models.generate_content call via asyncio.to_thread, the blocked/empty-response handling, and the final exception handler) are identical to the re-added lines below.
Added (new lines 428-562):

    async def _generate_response(self, current_user_query: str) -> str:
        """
        Generates a response from the LLM based on the current query, system prompts,
        data summaries, RAG context, and the agent's chat history.
        Assumes self.chat_history is already populated by app.py and includes the current_user_query as the last entry.
        """
        if not self.is_ready:
            return "Agent is not ready. Please initialize."
        if not client and not self.llm_model_instance:
            return "Error: AI service is not available. Check API configuration."

        try:
            system_prompt_text = self._build_system_prompt()
            data_summary_text = self._get_dataframes_summary()
            rag_context_text = await self.rag_system.retrieve_relevant_info(current_user_query, top_k=2, min_similarity=0.25)

            # Construct the messages for the LLM API call
            llm_messages = []

            # 1. System-level instructions and context (as a first "user" turn)
            initial_context_prompt = (
                f"{system_prompt_text}\n\n"
                f"## Available Data Overview:\n{data_summary_text}\n\n"
                f"## Relevant Background Information (if any):\n{rag_context_text if rag_context_text else 'No specific background information retrieved for this query.'}\n\n"
                f"Given this context, please respond to the user queries that follow in the chat history."
            )
            llm_messages.append({"role": "user", "parts": [{"text": initial_context_prompt}]})
            # 2. Priming assistant message
            llm_messages.append({"role": "model", "parts": [{"text": "Understood. I have reviewed the context and data overview. I am ready to assist with your Employer Branding analysis based on our conversation."}]})

            # 3. Append the actual conversation history (already includes the current user query)
            for entry in self.chat_history:
                llm_messages.append({"role": entry["role"], "parts": [{"text": entry["content"]}]})

            # --- Make the API call ---
            response_text = ""
            if self.llm_model_instance:  # Standard google-generativeai usage
                logging.debug(f"Using google-generativeai GenerativeModel.generate_content_async for LLM call. History length: {len(llm_messages)}")

                # Prepare generation config and safety settings for google-generativeai
                gen_config_payload = self.generation_config_dict
                safety_settings_payload = self.safety_settings_list

                if GENAI_AVAILABLE and hasattr(types, 'GenerationConfig') and not isinstance(self.generation_config_dict, types.GenerationConfig):
                    try:
                        gen_config_payload = types.GenerationConfig(**self.generation_config_dict)
                    except Exception as e:
                        logging.warning(f"Could not convert gen_config_dict to types.GenerationConfig: {e}")

                api_response = await self.llm_model_instance.generate_content_async(
                    contents=llm_messages,
                    generation_config=gen_config_payload,
                    safety_settings=safety_settings_payload
                )
                response_text = api_response.text

            elif client:  # google.genai client usage
                logging.debug(f"Using client.models.generate_content for LLM call. History length: {len(llm_messages)}")

                # Convert messages to the format expected by the google.genai client
                # (the client expects a simpler contents format)
                contents = []
                for msg in llm_messages:
                    if msg["role"] == "user":
                        contents.append(msg["parts"][0]["text"])
                    elif msg["role"] == "model":
                        # For model responses, we might need to handle differently,
                        # but for now, let's include them as context
                        contents.append(f"Assistant: {msg['parts'][0]['text']}")

                # Create the config object with both generation config and safety settings
                config_dict = {}

                # Add generation config parameters
                if self.generation_config_dict:
                    for key, value in self.generation_config_dict.items():
                        config_dict[key] = value

                # Add safety settings
                if self.safety_settings_list:
                    # Convert safety settings to the correct format if needed
                    safety_settings = []
                    for ss in self.safety_settings_list:
                        if isinstance(ss, dict):
                            # Convert dict to types.SafetySetting
                            safety_settings.append(types.SafetySetting(
                                category=ss.get('category'),
                                threshold=ss.get('threshold')
                            ))
                        else:
                            safety_settings.append(ss)
                    config_dict['safety_settings'] = safety_settings

                # Create the config object
                config = types.GenerateContentConfig(**config_dict)

                model_path = f"models/{self.llm_model_name}" if not self.llm_model_name.startswith("models/") else self.llm_model_name

                api_response = await asyncio.to_thread(
                    client.models.generate_content,
                    model=model_path,
                    contents=contents,  # simplified contents format
                    config=config  # single config parameter instead of separate generation_config and safety_settings
                )

                # Parse response from the client.models structure
                if api_response.candidates and api_response.candidates[0].content and api_response.candidates[0].content.parts:
                    response_text_parts = [part.text for part in api_response.candidates[0].content.parts if hasattr(part, 'text')]
                    response_text = "".join(response_text_parts).strip()
                else:
                    # Handle blocked or empty responses
                    if hasattr(api_response, 'prompt_feedback') and api_response.prompt_feedback and api_response.prompt_feedback.block_reason:
                        logging.warning(f"Prompt blocked by client.models: {api_response.prompt_feedback.block_reason}")
                        return f"I'm sorry, your request was blocked. Reason: {api_response.prompt_feedback.block_reason_message or api_response.prompt_feedback.block_reason}"
                    if api_response.candidates and hasattr(api_response.candidates[0], 'finish_reason'):
                        finish_reason = api_response.candidates[0].finish_reason
                        if hasattr(types.Candidate, 'FinishReason') and finish_reason != types.Candidate.FinishReason.STOP:
                            logging.warning(f"Content generation stopped by client.models due to: {finish_reason}. Safety: {getattr(api_response.candidates[0], 'safety_ratings', 'N/A')}")
                            return f"I couldn't complete the response. Reason: {finish_reason}. Please try rephrasing."
                    return "I apologize, but I couldn't generate a response from client.models."
            else:
                raise ConnectionError("No valid LLM client or model instance available.")

            return response_text.strip()

        except Exception as e:
            error_message = str(e).lower()

            # Check if it's a blocked-prompt error by examining the error message
            if any(keyword in error_message for keyword in ['blocked', 'safety', 'filter', 'prohibited']):
                logging.error(f"Blocked prompt from LLM: {e}", exc_info=True)
                return f"I'm sorry, your request was blocked by the safety filter. Please rephrase your query. Details: {e}"
            else:
                logging.error(f"Error in _generate_response: {e}", exc_info=True)
                return f"I encountered an error while processing your request: {type(e).__name__} - {str(e)}"

Context (lines 563-565, unchanged):

    def _validate_query(self, query: str) -> bool:
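
For readers unfamiliar with the single-config call pattern the added elif branch targets, a minimal standalone sketch follows. It assumes the google-genai SDK is installed and an API key is available in the environment; the model name, temperature value, and safety category/threshold strings are illustrative placeholders, not values taken from this module.

    # Minimal sketch of the client.models.generate_content pattern used above
    # (assumptions noted in the lead-in: google-genai SDK installed, API key in the environment).
    from google import genai
    from google.genai import types

    client = genai.Client()  # picks up the API key from the environment

    # Generation parameters and safety settings share a single GenerateContentConfig object,
    # instead of being passed as separate generation_config / safety_settings arguments.
    config = types.GenerateContentConfig(
        temperature=0.7,  # illustrative value
        safety_settings=[
            types.SafetySetting(
                category="HARM_CATEGORY_HARASSMENT",  # illustrative category
                threshold="BLOCK_MEDIUM_AND_ABOVE",   # illustrative threshold
            )
        ],
    )

    response = client.models.generate_content(
        model="gemini-2.0-flash",  # illustrative model id
        contents=["Summarize the engagement trend in one sentence."],
        config=config,
    )
    print(response.text)

This call is synchronous, which is why the diff wraps it in asyncio.to_thread so it can be awaited from the async _generate_response method.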