GuglielmoTor committed
Commit d350f74 · verified · 1 Parent(s): caf3e08

Update eb_agent_module.py

Files changed (1)
  1. eb_agent_module.py +135 -90
eb_agent_module.py CHANGED
@@ -15,6 +15,7 @@ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(
 try:
     from google import genai
     from google.genai import types # Assuming this provides necessary types like SafetySetting, HarmCategory etc.
+    from google.genai import errors
     # If GenerationConfig or EmbedContentConfig are from a different submodule, adjust imports.
     # For google-generativeai, GenerationConfig is often passed as a dict or genai.types.GenerationConfig
     # and EmbedContentConfig might be implicit or part of task_type.
@@ -424,97 +425,141 @@ class EmployerBrandingAgent:
         If a query is ambiguous or requires data not present, ask for clarification or state the limitation.
         """).strip()

-    async def _generate_response(self, current_user_query: str) -> str:
-        """
-        Generates a response from the LLM based on the current query, system prompts,
-        data summaries, RAG context, and the agent's chat history.
-        Assumes self.chat_history is already populated by app.py and includes the current_user_query as the last entry.
-        """
-        if not self.is_ready:
-            return "Agent is not ready. Please initialize."
-        if not client and not self.llm_model_instance:
-            return "Error: AI service is not available. Check API configuration."
-
-        try:
-            system_prompt_text = self._build_system_prompt()
-            data_summary_text = self._get_dataframes_summary()
-            rag_context_text = await self.rag_system.retrieve_relevant_info(current_user_query, top_k=2, min_similarity=0.25) # Fine-tuned RAG params
-
-            # Construct the messages for the LLM API call
-            # The history (self.chat_history) is set by app.py and includes the current user query.
-            llm_messages = []
-
-            # 1. System-level instructions and context (as a first "user" turn)
-            initial_context_prompt = (
-                f"{system_prompt_text}\n\n"
-                f"## Available Data Overview:\n{data_summary_text}\n\n"
-                f"## Relevant Background Information (if any):\n{rag_context_text if rag_context_text else 'No specific background information retrieved for this query.'}\n\n"
-                f"Given this context, please respond to the user queries that follow in the chat history."
-            )
-            llm_messages.append({"role": "user", "parts": [{"text": initial_context_prompt}]})
-            # 2. Priming assistant message
-            llm_messages.append({"role": "model", "parts": [{"text": "Understood. I have reviewed the context and data overview. I am ready to assist with your Employer Branding analysis based on our conversation."}]})
-
-            # 3. Append the actual conversation history (already includes the current user query)
-            for entry in self.chat_history: # self.chat_history is set by app.py
-                llm_messages.append({"role": entry["role"], "parts": [{"text": entry["content"]}]})
-
-            # Prepare generation config and safety settings for the API
-            gen_config_payload = self.generation_config_dict
-            safety_settings_payload = self.safety_settings_list # Already formatted if types.SafetySetting used
-
-            if GENAI_AVAILABLE and hasattr(types, 'GenerationConfig') and not isinstance(self.generation_config_dict, types.GenerationConfig):
-                try:
-                    gen_config_payload = types.GenerationConfig(**self.generation_config_dict)
-                except Exception as e:
-                    logging.warning(f"Could not convert gen_config_dict to types.GenerationConfig: {e}")
-
-
-            # --- Make the API call ---
-            response_text = ""
-            if self.llm_model_instance: # Standard google-generativeai usage
-                logging.debug(f"Using google-generativeai.GenerativeModel.generate_content_async for LLM call. History length: {len(llm_messages)}")
-                api_response = await self.llm_model_instance.generate_content_async(
-                    contents=llm_messages,
-                    generation_config=gen_config_payload,
-                    safety_settings=safety_settings_payload
-                )
-                response_text = api_response.text # Simplification, assumes single part text response
-            elif client: # User's original client.models.generate_content structure
-                logging.debug(f"Using client.models.generate_content for LLM call. History length: {len(llm_messages)}")
-                # This call needs to be async or wrapped, asyncio.to_thread is used as in original
-                model_path = f"models/{self.llm_model_name}" if not self.llm_model_name.startswith("models/") else self.llm_model_name
-                api_response = await asyncio.to_thread(
-                    client.models.generate_content,
-                    model=model_path,
-                    contents=llm_messages,
-                    generation_config=gen_config_payload, # Ensure this is the correct type for client.models
-                    safety_settings=safety_settings_payload # Ensure this is the correct type
-                )
-                # Parse response from client.models structure
-                if api_response.candidates and api_response.candidates[0].content and api_response.candidates[0].content.parts:
-                    response_text_parts = [part.text for part in api_response.candidates[0].content.parts if hasattr(part, 'text')]
-                    response_text = "".join(response_text_parts).strip()
-                else: # Handle blocked or empty responses from client.models
-                    if hasattr(api_response, 'prompt_feedback') and api_response.prompt_feedback and api_response.prompt_feedback.block_reason:
-                        logging.warning(f"Prompt blocked by client.models: {api_response.prompt_feedback.block_reason}")
-                        return f"I'm sorry, your request was blocked. Reason: {api_response.prompt_feedback.block_reason_message or api_response.prompt_feedback.block_reason}"
-                    if api_response.candidates and api_response.candidates[0].finish_reason != types.Candidate.FinishReason.STOP: # Assuming types.Candidate.FinishReason.STOP is valid
-                        logging.warning(f"Content generation stopped by client.models due to: {api_response.candidates[0].finish_reason}. Safety: {api_response.candidates[0].safety_ratings if hasattr(api_response.candidates[0], 'safety_ratings') else 'N/A'}")
-                        return f"I couldn't complete the response. Reason: {api_response.candidates[0].finish_reason}. Please try rephrasing."
-                    return "I apologize, but I couldn't generate a response from client.models."
-
-            else:
-                raise ConnectionError("No valid LLM client or model instance available.")
-
-            return response_text.strip()
-
-        except types.BlockedPromptException as bpe: # <--- CORRECTED LINE
-            logging.error(f"BlockedPromptException from LLM: {bpe}", exc_info=True) # Keep exc_info=True for full traceback in logs
-            return f"I'm sorry, your request was blocked by the safety filter. Please rephrase your query. Details: {bpe}"
-        except Exception as e:
-            logging.error(f"Error in _generate_response: {e}", exc_info=True) # Keep exc_info=True
-            return f"I encountered an error while processing your request: {type(e).__name__} - {str(e)}"
+    async def _generate_response(self, current_user_query: str) -> str:
+        """
+        Generates a response from the LLM based on the current query, system prompts,
+        data summaries, RAG context, and the agent's chat history.
+        Assumes self.chat_history is already populated by app.py and includes the current_user_query as the last entry.
+        """
+        if not self.is_ready:
+            return "Agent is not ready. Please initialize."
+        if not client and not self.llm_model_instance:
+            return "Error: AI service is not available. Check API configuration."
+
+        try:
+            system_prompt_text = self._build_system_prompt()
+            data_summary_text = self._get_dataframes_summary()
+            rag_context_text = await self.rag_system.retrieve_relevant_info(current_user_query, top_k=2, min_similarity=0.25)
+
+            # Construct the messages for the LLM API call
+            llm_messages = []
+
+            # 1. System-level instructions and context (as a first "user" turn)
+            initial_context_prompt = (
+                f"{system_prompt_text}\n\n"
+                f"## Available Data Overview:\n{data_summary_text}\n\n"
+                f"## Relevant Background Information (if any):\n{rag_context_text if rag_context_text else 'No specific background information retrieved for this query.'}\n\n"
+                f"Given this context, please respond to the user queries that follow in the chat history."
+            )
+            llm_messages.append({"role": "user", "parts": [{"text": initial_context_prompt}]})
+            # 2. Priming assistant message
+            llm_messages.append({"role": "model", "parts": [{"text": "Understood. I have reviewed the context and data overview. I am ready to assist with your Employer Branding analysis based on our conversation."}]})
+
+            # 3. Append the actual conversation history (already includes the current user query)
+            for entry in self.chat_history:
+                llm_messages.append({"role": entry["role"], "parts": [{"text": entry["content"]}]})
+
+            # --- Make the API call ---
+            response_text = ""
+            if self.llm_model_instance: # Standard google-generativeai usage
+                logging.debug(f"Using google-generativeai.GenerativeModel.generate_content_async for LLM call. History length: {len(llm_messages)}")
+
+                # Prepare generation config and safety settings for google-generativeai
+                gen_config_payload = self.generation_config_dict
+                safety_settings_payload = self.safety_settings_list
+
+                if GENAI_AVAILABLE and hasattr(types, 'GenerationConfig') and not isinstance(self.generation_config_dict, types.GenerationConfig):
+                    try:
+                        gen_config_payload = types.GenerationConfig(**self.generation_config_dict)
+                    except Exception as e:
+                        logging.warning(f"Could not convert gen_config_dict to types.GenerationConfig: {e}")
+
+                api_response = await self.llm_model_instance.generate_content_async(
+                    contents=llm_messages,
+                    generation_config=gen_config_payload,
+                    safety_settings=safety_settings_payload
+                )
+                response_text = api_response.text
+
+            elif client: # google.genai client usage
+                logging.debug(f"Using client.models.generate_content for LLM call. History length: {len(llm_messages)}")
+
+                # Convert messages to the format expected by google.genai client
+                # The client expects a simpler contents format
+                contents = []
+                for msg in llm_messages:
+                    if msg["role"] == "user":
+                        contents.append(msg["parts"][0]["text"])
+                    elif msg["role"] == "model":
+                        # For model responses, we might need to handle differently
+                        # but for now, let's include them as context
+                        contents.append(f"Assistant: {msg['parts'][0]['text']}")
+
+                # Create the config object with both generation config and safety settings
+                config_dict = {}
+
+                # Add generation config parameters
+                if self.generation_config_dict:
+                    for key, value in self.generation_config_dict.items():
+                        config_dict[key] = value
+
+                # Add safety settings
+                if self.safety_settings_list:
+                    # Convert safety settings to the correct format if needed
+                    safety_settings = []
+                    for ss in self.safety_settings_list:
+                        if isinstance(ss, dict):
+                            # Convert dict to types.SafetySetting
+                            safety_settings.append(types.SafetySetting(
+                                category=ss.get('category'),
+                                threshold=ss.get('threshold')
+                            ))
+                        else:
+                            safety_settings.append(ss)
+                    config_dict['safety_settings'] = safety_settings
+
+                # Create the config object
+                config = types.GenerateContentConfig(**config_dict)
+
+                model_path = f"models/{self.llm_model_name}" if not self.llm_model_name.startswith("models/") else self.llm_model_name
+
+                api_response = await asyncio.to_thread(
+                    client.models.generate_content,
+                    model=model_path,
+                    contents=contents, # Simplified contents format
+                    config=config # Using config parameter instead of separate generation_config and safety_settings
+                )
+
+                # Parse response from client.models structure
+                if api_response.candidates and api_response.candidates[0].content and api_response.candidates[0].content.parts:
+                    response_text_parts = [part.text for part in api_response.candidates[0].content.parts if hasattr(part, 'text')]
+                    response_text = "".join(response_text_parts).strip()
+                else:
+                    # Handle blocked or empty responses
+                    if hasattr(api_response, 'prompt_feedback') and api_response.prompt_feedback and api_response.prompt_feedback.block_reason:
+                        logging.warning(f"Prompt blocked by client.models: {api_response.prompt_feedback.block_reason}")
+                        return f"I'm sorry, your request was blocked. Reason: {api_response.prompt_feedback.block_reason_message or api_response.prompt_feedback.block_reason}"
+                    if api_response.candidates and hasattr(api_response.candidates[0], 'finish_reason'):
+                        finish_reason = api_response.candidates[0].finish_reason
+                        if hasattr(types.Candidate, 'FinishReason') and finish_reason != types.Candidate.FinishReason.STOP:
+                            logging.warning(f"Content generation stopped by client.models due to: {finish_reason}. Safety: {getattr(api_response.candidates[0], 'safety_ratings', 'N/A')}")
+                            return f"I couldn't complete the response. Reason: {finish_reason}. Please try rephrasing."
+                    return "I apologize, but I couldn't generate a response from client.models."
+            else:
+                raise ConnectionError("No valid LLM client or model instance available.")
+
+            return response_text.strip()
+
+        except Exception as e:
+            error_message = str(e).lower()
+
+            # Check if it's a blocked prompt error by examining the error message
+            if any(keyword in error_message for keyword in ['blocked', 'safety', 'filter', 'prohibited']):
+                logging.error(f"Blocked prompt from LLM: {e}", exc_info=True)
+                return f"I'm sorry, your request was blocked by the safety filter. Please rephrase your query. Details: {e}"
+            else:
+                logging.error(f"Error in _generate_response: {e}", exc_info=True)
+                return f"I encountered an error while processing your request: {type(e).__name__} - {str(e)}"


     def _validate_query(self, query: str) -> bool:
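Below is a minimal, self-contained sketch of the call pattern the rewritten `elif client:` branch moves to: generation parameters and safety settings bundled into a single `types.GenerateContentConfig` passed as `config`, with the synchronous `client.models.generate_content` pushed onto a worker thread. The API key, model name, temperature, and prompt are placeholders, and catching `errors.APIError` is only one plausible use of the `from google.genai import errors` line this commit adds; the committed code itself still matches on error-message keywords.

import asyncio

from google import genai
from google.genai import types, errors

# Placeholder key for illustration only; the real module reads its configuration elsewhere.
client = genai.Client(api_key="YOUR_API_KEY")

async def ask(prompt: str) -> str:
    # One config object carries both generation parameters and safety settings,
    # mirroring the types.GenerateContentConfig(**config_dict) construction in the commit.
    config = types.GenerateContentConfig(
        temperature=0.7,
        max_output_tokens=1024,
        safety_settings=[
            types.SafetySetting(
                category=types.HarmCategory.HARM_CATEGORY_HARASSMENT,
                threshold=types.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
            )
        ],
    )
    try:
        # client.models.generate_content is synchronous, so run it in a thread,
        # as the committed code does with asyncio.to_thread.
        response = await asyncio.to_thread(
            client.models.generate_content,
            model="models/gemini-2.0-flash",  # assumed model name
            contents=prompt,
            config=config,
        )
        return (response.text or "").strip()
    except errors.APIError as exc:
        # errors.APIError is the SDK's base exception for API failures; a typed except
        # clause is one possible follow-up to the newly imported errors module.
        return f"API error {exc.code}: {exc.message}"

if __name__ == "__main__":
    print(asyncio.run(ask("Summarize our employer branding follower trends.")))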