GuglielmoTor committed on
Commit ba2acc2 · verified · 1 Parent(s): d350f74

Update eb_agent_module.py

Files changed (1): eb_agent_module.py (+131, -131)
eb_agent_module.py CHANGED
@@ -425,141 +425,141 @@ class EmployerBrandingAgent:
         If a query is ambiguous or requires data not present, ask for clarification or state the limitation.
         """).strip()
 
-    async def _generate_response(self, current_user_query: str) -> str:
-        """
-        Generates a response from the LLM based on the current query, system prompts,
-        data summaries, RAG context, and the agent's chat history.
-        Assumes self.chat_history is already populated by app.py and includes the current_user_query as the last entry.
-        """
-        if not self.is_ready:
-            return "Agent is not ready. Please initialize."
-        if not client and not self.llm_model_instance:
-            return "Error: AI service is not available. Check API configuration."
-
-        try:
-            system_prompt_text = self._build_system_prompt()
-            data_summary_text = self._get_dataframes_summary()
-            rag_context_text = await self.rag_system.retrieve_relevant_info(current_user_query, top_k=2, min_similarity=0.25)
-
-            # Construct the messages for the LLM API call
-            llm_messages = []
-
-            # 1. System-level instructions and context (as a first "user" turn)
-            initial_context_prompt = (
-                f"{system_prompt_text}\n\n"
-                f"## Available Data Overview:\n{data_summary_text}\n\n"
-                f"## Relevant Background Information (if any):\n{rag_context_text if rag_context_text else 'No specific background information retrieved for this query.'}\n\n"
-                f"Given this context, please respond to the user queries that follow in the chat history."
-            )
-            llm_messages.append({"role": "user", "parts": [{"text": initial_context_prompt}]})
-            # 2. Priming assistant message
-            llm_messages.append({"role": "model", "parts": [{"text": "Understood. I have reviewed the context and data overview. I am ready to assist with your Employer Branding analysis based on our conversation."}]})
-
-            # 3. Append the actual conversation history (already includes the current user query)
-            for entry in self.chat_history:
-                llm_messages.append({"role": entry["role"], "parts": [{"text": entry["content"]}]})
-
-            # --- Make the API call ---
-            response_text = ""
-            if self.llm_model_instance: # Standard google-generativeai usage
-                logging.debug(f"Using google-generativeai.GenerativeModel.generate_content_async for LLM call. History length: {len(llm_messages)}")
-
-                # Prepare generation config and safety settings for google-generativeai
-                gen_config_payload = self.generation_config_dict
-                safety_settings_payload = self.safety_settings_list
-
-                if GENAI_AVAILABLE and hasattr(types, 'GenerationConfig') and not isinstance(self.generation_config_dict, types.GenerationConfig):
-                    try:
-                        gen_config_payload = types.GenerationConfig(**self.generation_config_dict)
-                    except Exception as e:
-                        logging.warning(f"Could not convert gen_config_dict to types.GenerationConfig: {e}")
-
-                api_response = await self.llm_model_instance.generate_content_async(
-                    contents=llm_messages,
-                    generation_config=gen_config_payload,
-                    safety_settings=safety_settings_payload
-                )
-                response_text = api_response.text
-
-            elif client: # google.genai client usage
-                logging.debug(f"Using client.models.generate_content for LLM call. History length: {len(llm_messages)}")
-
-                # Convert messages to the format expected by google.genai client
-                # The client expects a simpler contents format
-                contents = []
-                for msg in llm_messages:
-                    if msg["role"] == "user":
-                        contents.append(msg["parts"][0]["text"])
-                    elif msg["role"] == "model":
-                        # For model responses, we might need to handle differently
-                        # but for now, let's include them as context
-                        contents.append(f"Assistant: {msg['parts'][0]['text']}")
-
-                # Create the config object with both generation config and safety settings
-                config_dict = {}
-
-                # Add generation config parameters
-                if self.generation_config_dict:
-                    for key, value in self.generation_config_dict.items():
-                        config_dict[key] = value
-
-                # Add safety settings
-                if self.safety_settings_list:
-                    # Convert safety settings to the correct format if needed
-                    safety_settings = []
-                    for ss in self.safety_settings_list:
-                        if isinstance(ss, dict):
-                            # Convert dict to types.SafetySetting
-                            safety_settings.append(types.SafetySetting(
-                                category=ss.get('category'),
-                                threshold=ss.get('threshold')
-                            ))
-                        else:
-                            safety_settings.append(ss)
-                    config_dict['safety_settings'] = safety_settings
-
-                # Create the config object
-                config = types.GenerateContentConfig(**config_dict)
-
-                model_path = f"models/{self.llm_model_name}" if not self.llm_model_name.startswith("models/") else self.llm_model_name
-
-                api_response = await asyncio.to_thread(
-                    client.models.generate_content,
-                    model=model_path,
-                    contents=contents, # Simplified contents format
-                    config=config # Using config parameter instead of separate generation_config and safety_settings
-                )
-
-                # Parse response from client.models structure
-                if api_response.candidates and api_response.candidates[0].content and api_response.candidates[0].content.parts:
-                    response_text_parts = [part.text for part in api_response.candidates[0].content.parts if hasattr(part, 'text')]
-                    response_text = "".join(response_text_parts).strip()
-                else:
-                    # Handle blocked or empty responses
-                    if hasattr(api_response, 'prompt_feedback') and api_response.prompt_feedback and api_response.prompt_feedback.block_reason:
-                        logging.warning(f"Prompt blocked by client.models: {api_response.prompt_feedback.block_reason}")
-                        return f"I'm sorry, your request was blocked. Reason: {api_response.prompt_feedback.block_reason_message or api_response.prompt_feedback.block_reason}"
-                    if api_response.candidates and hasattr(api_response.candidates[0], 'finish_reason'):
-                        finish_reason = api_response.candidates[0].finish_reason
-                        if hasattr(types.Candidate, 'FinishReason') and finish_reason != types.Candidate.FinishReason.STOP:
-                            logging.warning(f"Content generation stopped by client.models due to: {finish_reason}. Safety: {getattr(api_response.candidates[0], 'safety_ratings', 'N/A')}")
-                            return f"I couldn't complete the response. Reason: {finish_reason}. Please try rephrasing."
-                    return "I apologize, but I couldn't generate a response from client.models."
-            else:
-                raise ConnectionError("No valid LLM client or model instance available.")
-
-            return response_text.strip()
-
-        except Exception as e:
-            error_message = str(e).lower()
-
-            # Check if it's a blocked prompt error by examining the error message
-            if any(keyword in error_message for keyword in ['blocked', 'safety', 'filter', 'prohibited']):
-                logging.error(f"Blocked prompt from LLM: {e}", exc_info=True)
-                return f"I'm sorry, your request was blocked by the safety filter. Please rephrase your query. Details: {e}"
-            else:
-                logging.error(f"Error in _generate_response: {e}", exc_info=True)
-                return f"I encountered an error while processing your request: {type(e).__name__} - {str(e)}"
+    async def _generate_response(self, current_user_query: str) -> str:
+        """
+        Generates a response from the LLM based on the current query, system prompts,
+        data summaries, RAG context, and the agent's chat history.
+        Assumes self.chat_history is already populated by app.py and includes the current_user_query as the last entry.
+        """
+        if not self.is_ready:
+            return "Agent is not ready. Please initialize."
+        if not client and not self.llm_model_instance:
+            return "Error: AI service is not available. Check API configuration."
+
+        try:
+            system_prompt_text = self._build_system_prompt()
+            data_summary_text = self._get_dataframes_summary()
+            rag_context_text = await self.rag_system.retrieve_relevant_info(current_user_query, top_k=2, min_similarity=0.25)
+
+            # Construct the messages for the LLM API call
+            llm_messages = []
+
+            # 1. System-level instructions and context (as a first "user" turn)
+            initial_context_prompt = (
+                f"{system_prompt_text}\n\n"
+                f"## Available Data Overview:\n{data_summary_text}\n\n"
+                f"## Relevant Background Information (if any):\n{rag_context_text if rag_context_text else 'No specific background information retrieved for this query.'}\n\n"
+                f"Given this context, please respond to the user queries that follow in the chat history."
+            )
+            llm_messages.append({"role": "user", "parts": [{"text": initial_context_prompt}]})
+            # 2. Priming assistant message
+            llm_messages.append({"role": "model", "parts": [{"text": "Understood. I have reviewed the context and data overview. I am ready to assist with your Employer Branding analysis based on our conversation."}]})
+
+            # 3. Append the actual conversation history (already includes the current user query)
+            for entry in self.chat_history:
+                llm_messages.append({"role": entry["role"], "parts": [{"text": entry["content"]}]})
+
+            # --- Make the API call ---
+            response_text = ""
+            if self.llm_model_instance: # Standard google-generativeai usage
+                logging.debug(f"Using google-generativeai.GenerativeModel.generate_content_async for LLM call. History length: {len(llm_messages)}")
+
+                # Prepare generation config and safety settings for google-generativeai
+                gen_config_payload = self.generation_config_dict
+                safety_settings_payload = self.safety_settings_list
+
+                if GENAI_AVAILABLE and hasattr(types, 'GenerationConfig') and not isinstance(self.generation_config_dict, types.GenerationConfig):
+                    try:
+                        gen_config_payload = types.GenerationConfig(**self.generation_config_dict)
+                    except Exception as e:
+                        logging.warning(f"Could not convert gen_config_dict to types.GenerationConfig: {e}")
+
+                api_response = await self.llm_model_instance.generate_content_async(
+                    contents=llm_messages,
+                    generation_config=gen_config_payload,
+                    safety_settings=safety_settings_payload
+                )
+                response_text = api_response.text
+
+            elif client: # google.genai client usage
+                logging.debug(f"Using client.models.generate_content for LLM call. History length: {len(llm_messages)}")
+
+                # Convert messages to the format expected by google.genai client
+                # The client expects a simpler contents format
+                contents = []
+                for msg in llm_messages:
+                    if msg["role"] == "user":
+                        contents.append(msg["parts"][0]["text"])
+                    elif msg["role"] == "model":
+                        # For model responses, we might need to handle differently
+                        # but for now, let's include them as context
+                        contents.append(f"Assistant: {msg['parts'][0]['text']}")
+
+                # Create the config object with both generation config and safety settings
+                config_dict = {}
+
+                # Add generation config parameters
+                if self.generation_config_dict:
+                    for key, value in self.generation_config_dict.items():
+                        config_dict[key] = value
+
+                # Add safety settings
+                if self.safety_settings_list:
+                    # Convert safety settings to the correct format if needed
+                    safety_settings = []
+                    for ss in self.safety_settings_list:
+                        if isinstance(ss, dict):
+                            # Convert dict to types.SafetySetting
+                            safety_settings.append(types.SafetySetting(
+                                category=ss.get('category'),
+                                threshold=ss.get('threshold')
+                            ))
+                        else:
+                            safety_settings.append(ss)
+                    config_dict['safety_settings'] = safety_settings
+
+                # Create the config object
+                config = types.GenerateContentConfig(**config_dict)
+
+                model_path = f"models/{self.llm_model_name}" if not self.llm_model_name.startswith("models/") else self.llm_model_name
+
+                api_response = await asyncio.to_thread(
+                    client.models.generate_content,
+                    model=model_path,
+                    contents=contents, # Simplified contents format
+                    config=config # Using config parameter instead of separate generation_config and safety_settings
+                )
+
+                # Parse response from client.models structure
+                if api_response.candidates and api_response.candidates[0].content and api_response.candidates[0].content.parts:
+                    response_text_parts = [part.text for part in api_response.candidates[0].content.parts if hasattr(part, 'text')]
+                    response_text = "".join(response_text_parts).strip()
+                else:
+                    # Handle blocked or empty responses
+                    if hasattr(api_response, 'prompt_feedback') and api_response.prompt_feedback and api_response.prompt_feedback.block_reason:
+                        logging.warning(f"Prompt blocked by client.models: {api_response.prompt_feedback.block_reason}")
+                        return f"I'm sorry, your request was blocked. Reason: {api_response.prompt_feedback.block_reason_message or api_response.prompt_feedback.block_reason}"
+                    if api_response.candidates and hasattr(api_response.candidates[0], 'finish_reason'):
+                        finish_reason = api_response.candidates[0].finish_reason
+                        if hasattr(types.Candidate, 'FinishReason') and finish_reason != types.Candidate.FinishReason.STOP:
+                            logging.warning(f"Content generation stopped by client.models due to: {finish_reason}. Safety: {getattr(api_response.candidates[0], 'safety_ratings', 'N/A')}")
+                            return f"I couldn't complete the response. Reason: {finish_reason}. Please try rephrasing."
+                    return "I apologize, but I couldn't generate a response from client.models."
+            else:
+                raise ConnectionError("No valid LLM client or model instance available.")
+
+            return response_text.strip()
+
+        except Exception as e:
+            error_message = str(e).lower()
+
+            # Check if it's a blocked prompt error by examining the error message
+            if any(keyword in error_message for keyword in ['blocked', 'safety', 'filter', 'prohibited']):
+                logging.error(f"Blocked prompt from LLM: {e}", exc_info=True)
+                return f"I'm sorry, your request was blocked by the safety filter. Please rephrase your query. Details: {e}"
+            else:
+                logging.error(f"Error in _generate_response: {e}", exc_info=True)
+                return f"I encountered an error while processing your request: {type(e).__name__} - {str(e)}"
 
 
     def _validate_query(self, query: str) -> bool:
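
Note on the fallback branch in the method above: when no google-generativeai GenerativeModel instance is available, the call is routed through the newer google-genai client, which takes a single GenerateContentConfig object (generation parameters plus safety_settings) rather than separate generation_config and safety_settings arguments. The following is a minimal standalone sketch of that call pattern, assuming the google-genai package is installed and an API key is present in the environment; the model name and parameter values are illustrative and not taken from this commit.

    # Minimal sketch of the client.models.generate_content pattern used in the
    # fallback branch above. Assumes the `google-genai` package and an API key
    # in GOOGLE_API_KEY / GEMINI_API_KEY; model name and values are illustrative.
    from google import genai
    from google.genai import types

    client = genai.Client()  # picks up the API key from the environment

    config = types.GenerateContentConfig(
        temperature=0.3,
        max_output_tokens=512,
        safety_settings=[
            types.SafetySetting(
                category=types.HarmCategory.HARM_CATEGORY_HARASSMENT,
                threshold=types.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
            )
        ],
    )

    response = client.models.generate_content(
        model="models/gemini-2.0-flash",  # illustrative model name
        contents=[
            "You are an Employer Branding analyst.",
            "Summarise the follower growth trend in two sentences.",
        ],
        config=config,
    )
    print(response.text)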