GuglielmoTor committed
Commit 69061c0 · verified · 1 Parent(s): efa9136

Update eb_agent_module.py

Files changed (1):
  eb_agent_module.py +413 -78
eb_agent_module.py CHANGED
@@ -7,102 +7,437 @@ import numpy as np
 import textwrap
 
 try:
-    from google import genai
-    from google.genai import types as genai_types
 except ImportError:
-    print("Google Generative AI library not found. Please install it: pip install google-generativeai")
-    # Dummy classes defined here for development/debugging
-    ...  # KEEP YOUR EXISTING DUMMY DEFINITIONS
 
-# Configuration
-GEMINI_API_KEY = os.getenv('GEMINI_API_KEY', "")
-LLM_MODEL_NAME = "gemini-2.0-flash"
-GEMINI_EMBEDDING_MODEL_NAME = "gemini-embedding-exp-03-07"
 
-client = genai.Client(api_key=GEMINI_API_KEY)
 
 class AdvancedRAGSystem:
     def __init__(self, documents_df: pd.DataFrame, embedding_model_name: str):
-        self.documents_df = documents_df
         self.embedding_model_name = embedding_model_name
-        self.embeddings = self._embed_documents()
-
-    def _embed_documents(self):
-        embedded_docs = []
-        for text in self.documents_df['text']:
-            response = client.models.embed_content(
-                model=self.embedding_model_name,
-                contents=text,
-                config=genai_types.EmbedContentConfig(task_type="SEMANTIC_SIMILARITY")
-            )
-            embedded_docs.append(np.array(response.embeddings.values))
-        return np.vstack(embedded_docs)
-
-    def retrieve_relevant_info(self, query: str, top_k=3) -> str:
-        query_embedding = client.models.embed_content(
-            model=self.embedding_model_name,
-            contents=query,
-            config=genai_types.EmbedContentConfig(task_type="SEMANTIC_SIMILARITY")
         )
-        query_vector = np.array(query_embedding.embeddings.values)
 
-        scores = np.dot(self.embeddings, query_vector)
-        top_indices = np.argsort(scores)[-top_k:][::-1]
-        context = "\n\n".join(self.documents_df.iloc[i]['text'] for i in top_indices)
-        return context
 
 class EmployerBrandingAgent:
-    def __init__(self, all_dataframes: dict, rag_documents_df: pd.DataFrame):
-        self.all_dataframes = all_dataframes
-        self.schemas_representation = self._get_all_schemas_representation()
-        self.chat_history = []
-        self.rag_system = AdvancedRAGSystem(rag_documents_df, GEMINI_EMBEDDING_MODEL_NAME)
-        logging.info("EmployerBrandingAgent initialized with Gemini")
 
-    def _get_all_schemas_representation(self):
-        schema_descriptions = []
        for key, df in self.all_dataframes.items():
-            schema = f"DataFrame: df_{key}\nColumns: {', '.join(df.columns)}\n"
            schema_descriptions.append(schema)
        return "\n".join(schema_descriptions)
 
-    def _build_prompt(self, user_query: str) -> str:
-        prompt = f"You are an expert Employer Branding Analyst. Analyze the query based on the following DataFrames.\n"
-        prompt += self.schemas_representation
-
-        rag_context = self.rag_system.retrieve_relevant_info(user_query)
-        if rag_context:
-            prompt += f"\n\nAdditional Context:\n{rag_context}"
-
-        prompt += f"\n\nUser Query:\n{user_query}"
-        return prompt
-
-    async def process_query(self, user_query: str) -> str:
-        self.chat_history.append({"role": "user", "content": user_query})
-        prompt = self._build_prompt(user_query)
-
-        response = client.models.generate_content(
-            model=LLM_MODEL_NAME,
-            contents=[prompt],
-            config=genai_types.GenerateContentConfig(
-                safety_settings=[
-                    genai_types.SafetySetting(
-                        category=genai_types.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
-                        threshold=genai_types.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE
-                    )
-                ]
            )
-        )
 
-        answer = response.text.strip()
-        self.chat_history.append({"role": "assistant", "content": answer})
-        return answer
 
-    def update_dataframes(self, new_dataframes: dict):
-        self.all_dataframes = new_dataframes
-        self.schemas_representation = self._get_all_schemas_representation()
-        logging.info("EmployerBrandingAgent DataFrames updated.")
 
-    def clear_chat_history(self):
        self.chat_history = []
-        logging.info("EmployerBrandingAgent chat history cleared.")
 
 import textwrap
 
 try:
+    from google import genai
+    from google.genai import types as genai_types  # For GenerateContentConfig, SafetySetting, EmbedContentConfig, etc.
+    from google.genai.types import HarmCategory, HarmBlockThreshold  # Specific enums
 except ImportError:
+    logging.error("Google GenAI library not found. Please install it: pip install google-genai", exc_info=True)
+    # Define dummy classes/variables if the import fails, so app.py can still try to run
+    # (though app.py already has an EB_AGENT_AVAILABLE check)
+    class genai: Client = None  # type: ignore
+    class genai_types:  # type: ignore
+        EmbedContentConfig = None
+        GenerateContentConfig = None
+        SafetySetting = None
+    class HarmCategory:  # type: ignore
+        HARM_CATEGORY_HATE_SPEECH = "HARM_CATEGORY_HATE_SPEECH"
+        HARM_CATEGORY_HARASSMENT = "HARM_CATEGORY_HARASSMENT"
+        HARM_CATEGORY_SEXUALLY_EXPLICIT = "HARM_CATEGORY_SEXUALLY_EXPLICIT"
+        HARM_CATEGORY_DANGEROUS_CONTENT = "HARM_CATEGORY_DANGEROUS_CONTENT"
+    class HarmBlockThreshold:  # type: ignore
+        BLOCK_LOW_AND_ABOVE = "BLOCK_LOW_AND_ABOVE"
+        BLOCK_MEDIUM_AND_ABOVE = "BLOCK_MEDIUM_AND_ABOVE"
+        BLOCK_NONE = "BLOCK_NONE"
 
+# --- Configuration Constants ---
+# These are defined here because app.py imports them.
+# Users should ensure these are appropriate for their needs.
+
+GEMINI_API_KEY = os.getenv('GEMINI_API_KEY', "")
+if not GEMINI_API_KEY:
+    logging.warning("GEMINI_API_KEY environment variable not set. EB Agent will not function.")
+
+# Model names (as used in app.py imports from this module)
+LLM_MODEL_NAME = "gemini-1.5-flash-latest"  # Changed to 1.5-flash as it's generally preferred; user had 2.0-flash. Adjust if needed.
+GEMINI_EMBEDDING_MODEL_NAME = "text-embedding-004"  # Common embedding model; user had gemini-embedding-exp-03-07. Adjust if needed.
+
+# Default generation config (app.py imports this as EB_AGENT_GEN_CONFIG)
+GENERATION_CONFIG_PARAMS = {
+    "temperature": 0.7,
+    "top_p": 0.95,
+    "top_k": 40,
+    "max_output_tokens": 8192,
+    "candidate_count": 1,  # Important for non-streaming
+    # "stop_sequences": [...]  # Optional
+}
+
+# Default safety settings (app.py imports this as EB_AGENT_SAFETY_SETTINGS)
+DEFAULT_SAFETY_SETTINGS = [
+    {"category": HarmCategory.HARM_CATEGORY_HATE_SPEECH, "threshold": HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE},
+    {"category": HarmCategory.HARM_CATEGORY_HARASSMENT, "threshold": HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE},
+    {"category": HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, "threshold": HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE},
+    {"category": HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, "threshold": HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE},
+]
+
+# Placeholder RAG documents DataFrame (app.py imports this as eb_agent_default_rag_docs).
+# In a real application, this would be loaded from a file or database.
+df_rag_documents = pd.DataFrame({
+    'text': [
+        "Employer branding focuses on how an organization is perceived as an employer by potential and current employees.",
+        "Key metrics for employer branding include employee engagement, candidate quality, and retention rates.",
+        "LinkedIn is a crucial platform for showcasing company culture and attracting talent.",
+        "Analyzing follower demographics and post engagement helps refine employer branding strategies."
+    ]
+})
+
+# --- Client Initialization ---
+# This client will be used by the agent instances.
+# It's initialized once when the module is loaded.
+client = None
+if GEMINI_API_KEY and genai.Client:  # Check that genai.Client is not None (dummy class on import error)
+    try:
+        # Note: the older google-generativeai SDK used genai.configure(api_key=...) for global configuration instead.
+        client = genai.Client(api_key=GEMINI_API_KEY)
+        logging.info("Google GenAI client initialized successfully.")
+    except Exception as e:
+        logging.error(f"Failed to initialize Google GenAI client: {e}", exc_info=True)
+else:
+    logging.warning("Google GenAI client could not be initialized (GEMINI_API_KEY missing or library import failed).")
+
 class AdvancedRAGSystem:
+    """
+    Handles Retrieval Augmented Generation by embedding documents and finding relevant context for queries.
+    """
     def __init__(self, documents_df: pd.DataFrame, embedding_model_name: str):
+        self.documents_df = documents_df.copy()  # Work on a copy
         self.embedding_model_name = embedding_model_name
+        self.embeddings: np.ndarray | None = None  # Populated by the async initialize_embeddings()
+        logging.info(f"AdvancedRAGSystem initialized with embedding model: {self.embedding_model_name}")
+
+    def _embed_single_document_sync(self, text: str) -> np.ndarray:
+        """Synchronous helper to embed a single piece of text."""
+        if not client:
+            raise ConnectionError("GenAI client not initialized for RAG embedding.")
+        if not text or not isinstance(text, str):  # Basic validation
+            logging.warning("Attempted to embed empty or non-string text.")
+            # Returning a zero vector would require knowing the model's embedding dimension up front,
+            # so invalid input is rejected here; the caller logs the error and skips the document.
+            raise ValueError("Cannot embed empty or non-string text.")
+
+        # Using client.models.embed_content as per user's provided snippets
+        response = client.models.embed_content(
+            model=self.embedding_model_name,  # e.g., "text-embedding-004" or "gemini-embedding-exp-03-07"
+            contents=text,  # the API accepts 'contents' (plural), but a single string works for one embedding
+            config=genai_types.EmbedContentConfig(task_type="SEMANTIC_SIMILARITY") if genai_types.EmbedContentConfig else None
         )
+        # response.embeddings is a list of ContentEmbedding objects; take the values of the first (only) entry.
+        return np.array(response.embeddings[0].values)
+
+    async def initialize_embeddings(self):
+        """Asynchronously embeds all documents in documents_df. Should be called once."""
+        if self.documents_df.empty:
+            logging.info("RAG documents DataFrame is empty. No embeddings to initialize.")
+            self.embeddings = np.array([])
+            return
+
+        if not client:
+            logging.error("GenAI client not available for RAG embedding initialization.")
+            self.embeddings = np.array([])
+            return
+
+        logging.info(f"Starting RAG document embedding for {len(self.documents_df)} documents...")
+        embedded_docs_list = []
+        for index, row in self.documents_df.iterrows():
+            text_to_embed = row.get('text')
+            if not text_to_embed or not isinstance(text_to_embed, str):
+                logging.warning(f"Skipping document at index {index} due to invalid text: {text_to_embed}")
+                continue
+            try:
+                # Wrap the synchronous SDK call in asyncio.to_thread
+                embedding_array = await asyncio.to_thread(self._embed_single_document_sync, text_to_embed)
+                embedded_docs_list.append(embedding_array)
+            except Exception as e:
+                logging.error(f"Error embedding document text (index {index}) '{str(text_to_embed)[:50]}...': {e}", exc_info=False)  # exc_info=False for brevity in loop
+
+        if not embedded_docs_list:
+            self.embeddings = np.array([])
+            logging.warning("No documents were successfully embedded for RAG.")
+        else:
+            try:
+                self.embeddings = np.vstack(embedded_docs_list)
+                logging.info(f"Successfully embedded {len(embedded_docs_list)} documents for RAG. Embedding matrix shape: {self.embeddings.shape}")
+            except ValueError as ve:  # Handles cases like an empty list or inconsistent shapes if errors weren't caught properly
+                logging.error(f"Error stacking embeddings: {ve}. Check individual embedding errors.", exc_info=True)
+                self.embeddings = np.array([])
+
+    async def retrieve_relevant_info(self, query: str, top_k: int = 3) -> str:
+        """Retrieves relevant document snippets for a given query using vector similarity."""
+        if self.embeddings is None or self.embeddings.size == 0 or self.documents_df.empty:
+            logging.debug("RAG system not initialized or no documents/embeddings available for retrieval.")
+            return ""
+        if not query or not isinstance(query, str):
+            logging.debug("Empty or invalid query for RAG retrieval.")
+            return ""
+        if not client:
+            logging.error("GenAI client not available for RAG query embedding.")
+            return ""
+
+        try:
+            query_vector = await asyncio.to_thread(self._embed_single_document_sync, query)
+        except Exception as e:
+            logging.error(f"Error embedding query '{str(query)[:50]}...': {e}", exc_info=False)
+            return ""
+
+        if query_vector.ndim == 0 or query_vector.size == 0:
+            logging.warning(f"Query vector embedding failed or is empty for query: {str(query)[:50]}")
+            return ""
+        if query_vector.ndim > 1:  # Should be 1D
+            query_vector = query_vector.flatten()
+
+        try:
+            # Cosine similarity is the dot product of normalized vectors.
+            # For simplicity, using the dot product directly. Normalize if true cosine similarity is needed.
+            scores = np.dot(self.embeddings, query_vector)  # self.embeddings (N, D), query_vector (D,) -> scores (N,)
+
+            if scores.size == 0:
+                return ""
+
+            actual_top_k = min(top_k, len(self.documents_df), len(scores))
+            if actual_top_k <= 0: return ""  # Ensure top_k is positive
+
+            # Get indices of the top_k scores in descending order
+            top_indices = np.argsort(scores)[-actual_top_k:][::-1]
+
+            valid_top_indices = [idx for idx in top_indices if 0 <= idx < len(self.documents_df)]
+            if not valid_top_indices: return ""
+
+            # Retrieve the 'text' field from the original DataFrame
+            context_parts = [self.documents_df.iloc[i]['text'] for i in valid_top_indices if 'text' in self.documents_df.columns]
+            context = "\n\n---\n\n".join(context_parts)
+            logging.debug(f"Retrieved RAG context for query '{str(query)[:50]}...':\n{context[:200]}...")
+            return context
+        except Exception as e:
+            logging.error(f"Error during RAG retrieval (dot product/sorting): {e}", exc_info=True)
+            return ""
 
 class EmployerBrandingAgent:
+    """
+    An agent that uses Generative AI to provide insights on employer branding
+    based on provided DataFrames and RAG context.
+    """
+    def __init__(self,
+                 all_dataframes: dict,
+                 rag_documents_df: pd.DataFrame,  # For RAG system
+                 llm_model_name: str,
+                 embedding_model_name: str,  # For RAG system
+                 generation_config_dict: dict,
+                 safety_settings_list_of_dicts: list,
+                 # client_instance,  # Using global client for simplicity now
+                 force_sandbox: bool = False  # Parameter from app.py, currently unused here
+                 ):
+        # self.client = client_instance  # If client were passed
+        self.all_dataframes = {k: df.copy() for k, df in all_dataframes.items()}  # Work with copies
+        self.schemas_representation = self._get_all_schemas_representation()  # Sync method
+
+        self.chat_history = []  # Stores chat in API format: [{"role": "user/model", "parts": [{"text": "..."}]}]
+        # This will be set by app.py before calling process_query.
+
+        self.llm_model_name = llm_model_name
+        self.generation_config_dict = generation_config_dict
+        self.safety_settings_list_of_dicts = safety_settings_list_of_dicts
+
+        self.embedding_model_name = embedding_model_name
+        self.rag_system = AdvancedRAGSystem(rag_documents_df, self.embedding_model_name)
+        # Note: self.rag_system.initialize_embeddings() must be called externally (e.g., in app.py).
+
+        self.force_sandbox = force_sandbox  # Store if needed for tool use later
+        logging.info(f"EmployerBrandingAgent initialized. LLM: {self.llm_model_name}, Embedding: {self.embedding_model_name}. RAG system created.")
+
+    def _get_all_schemas_representation(self) -> str:
+        """Generates a string representation of the schemas of all DataFrames."""
+        schema_descriptions = ["DataFrames available for analysis:"]
        for key, df in self.all_dataframes.items():
+            df_name = f"df_{key}"  # Consistent naming for the agent to refer to
+            columns = ", ".join(df.columns)
+            shape = df.shape
+            if df.empty:
+                schema = f"\n--- DataFrame: {df_name} ---\nStatus: Empty\nShape: {shape}\nColumns: {columns}"
+            else:
+                # Basic stats for numeric columns, first few unique values for object columns
+                sample_info_parts = []
+                for col in df.columns:
+                    if pd.api.types.is_numeric_dtype(df[col]) and not df[col].empty:
+                        sample_info_parts.append(f"{col} (numeric, e.g., mean: {df[col].mean():.2f})")
+                    elif pd.api.types.is_datetime64_any_dtype(df[col]) and not df[col].empty:
+                        sample_info_parts.append(f"{col} (datetime, e.g., min: {df[col].min()}, max: {df[col].max()})")
+                    elif not df[col].empty:
+                        unique_vals = df[col].unique()
+                        display_unique = ', '.join(map(str, unique_vals[:3]))
+                        if len(unique_vals) > 3: display_unique += ", ..."
+                        sample_info_parts.append(f"{col} (object, e.g., {display_unique})")
+                    else:
+                        sample_info_parts.append(f"{col} (empty)")
+
+                schema = (f"\n--- DataFrame: {df_name} ---\nShape: {shape}\nColumns & Sample Info:\n  " + "\n  ".join(sample_info_parts))
            schema_descriptions.append(schema)
        return "\n".join(schema_descriptions)
 
+    async def _build_prompt_for_current_turn(self, raw_user_query: str) -> str:
+        """
+        Constructs the full prompt for the current turn, including system instructions,
+        DataFrame schemas, RAG context, and the user's query.
+        """
+        # System instruction part
+        prompt_parts = [
+            "You are an expert Employer Branding Analyst and a helpful AI assistant. "
+            "Your goal is to provide insightful analysis based on the provided LinkedIn data. "
+            "When asked to generate Pandas code, ensure it is correct, runnable, and clearly explained. "
+            "When providing insights, be specific and refer to the data where possible."
+        ]
+
+        # Schema information
+        prompt_parts.append("\n\n--- AVAILABLE DATA ---")
+        prompt_parts.append(self.schemas_representation)
+
+        # RAG context
+        if self.rag_system.embeddings is not None and self.rag_system.embeddings.size > 0:  # Check if RAG is initialized
+            logging.debug(f"Retrieving RAG context for query: {raw_user_query[:50]}...")
+            rag_context = await self.rag_system.retrieve_relevant_info(raw_user_query)
+            if rag_context:
+                prompt_parts.append("\n\n--- RELEVANT CONTEXTUAL INFORMATION (from documents) ---")
+                prompt_parts.append(rag_context)
+            else:
+                logging.debug("No relevant RAG context found.")
+        else:
+            logging.debug("RAG system not initialized or embeddings not available, skipping RAG context retrieval.")
+
+        # User's current query
+        prompt_parts.append("\n\n--- USER REQUEST ---")
+        prompt_parts.append(f"Based on all the information above, please respond to the following user query:\n{raw_user_query}")
+
+        final_prompt = "\n".join(prompt_parts)
+        logging.debug(f"Built prompt for current turn (first 300 chars): {final_prompt[:300]}")
+        return final_prompt
+
+    async def process_query(self, raw_user_query_this_turn: str) -> str:
+        """
+        Processes the user's query, incorporating chat history, DataFrame schemas, and RAG.
+        The agent's self.chat_history is expected to be set by the calling application (app.py)
+        and should contain the history *before* the current raw_user_query_this_turn.
+        This method returns the AI's response string. app.py will then update the agent's
+        chat history with the raw_user_query_this_turn and this response.
+        """
+        if not client:
+            logging.error("GenAI client not initialized. Cannot process query.")
+            return "Error: The AI Agent is not available due to a configuration issue with the AI service."
+
+        if not raw_user_query_this_turn.strip():
+            return "Please provide a query."
+
+        # 1. Prepare the augmented prompt for the *current* user query.
+        #    This prompt includes system instructions, schemas, RAG context, and the current raw query.
+        augmented_current_user_prompt_text = await self._build_prompt_for_current_turn(raw_user_query_this_turn)
+
+        # 2. Construct the full list of contents for the API call.
+        #    self.chat_history should be in API format: [{"role": "user/model", "parts": [{"text": "..."}]}]
+        #    and contain only the history *before* the current raw_user_query_this_turn.
+        api_call_contents = []
+        if self.chat_history:  # Add previous turns, if any
+            api_call_contents.extend(self.chat_history)
+
+        # Add the current user turn, using the fully augmented prompt as its content
+        api_call_contents.append({"role": "user", "parts": [{"text": augmented_current_user_prompt_text}]})
+
+        logging.debug(f"Sending to GenAI. Total turns in content: {len(api_call_contents)}")
+        if api_call_contents:
+            logging.debug(f"Last turn role: {api_call_contents[-1]['role']}, text start: {api_call_contents[-1]['parts'][0]['text'][:100]}")
+
+        # 3. Prepare the API configuration.
+        # Convert safety settings from a list of dicts to SafetySetting objects if genai_types is available.
+        api_safety_settings = []
+        if genai_types.SafetySetting:
+            for ss_dict in self.safety_settings_list_of_dicts:
+                try:
+                    api_safety_settings.append(genai_types.SafetySetting(**ss_dict))
+                except TypeError:  # Handles HarmCategory/HarmBlockThreshold being plain strings due to an import error
+                    logging.warning(f"Could not create SafetySetting object from dict: {ss_dict}. Using dict directly.")
+                    api_safety_settings.append(ss_dict)  # Fall back to the dict
+        else:  # genai_types not available
+            api_safety_settings = self.safety_settings_list_of_dicts
+
+        api_generation_config = None
+        if genai_types.GenerateContentConfig:
+            try:
+                api_generation_config = genai_types.GenerateContentConfig(
+                    **self.generation_config_dict,
+                    safety_settings=api_safety_settings  # A list of SafetySetting objects (or dicts as a fallback)
+                )
+            except TypeError:
+                logging.warning("Could not create GenerateContentConfig object. Using the raw config dict instead.")
+                # Fallback: pass the raw dict. The 'config' parameter of client.models.generate_content
+                # expects a GenerateContentConfig object, so this is not ideal and the call may still fail;
+                # proceed and let the API surface the error rather than aborting here.
+                api_generation_config = self.generation_config_dict
+        else:  # genai_types.GenerateContentConfig is None (likely import error)
+            logging.error("genai_types.GenerateContentConfig not available. Cannot form API config.")
+            return "Error: AI Agent configuration problem (GenerateContentConfig type missing)."
+
+        # 4. Make the API call (synchronous SDK call wrapped in asyncio.to_thread).
+        try:
+            response = await asyncio.to_thread(
+                client.models.generate_content,  # As per user's snippet
+                model=self.llm_model_name,
+                contents=api_call_contents,
+                config=api_generation_config  # Pass the GenerateContentConfig object
            )
+            # Extract text. User's snippet uses response.text.
+            # Check for blocked content or other issues first.
+            if not response.candidates:
+                block_reason = response.prompt_feedback.block_reason if response.prompt_feedback else "Unknown"
+                logging.warning(f"AI response blocked or empty. Reason: {block_reason}")
+                # You might want to inspect response.prompt_feedback for block reasons.
+                error_message = f"The AI's response was blocked. Reason: {block_reason}."
+                if response.prompt_feedback and response.prompt_feedback.block_reason_message:
+                    error_message += f" Details: {response.prompt_feedback.block_reason_message}"
+                return error_message
 
+            answer = response.text.strip()
+            logging.info(f"Successfully received AI response (first 100 chars): {answer[:100]}")
 
+        except Exception as e:
+            logging.error(f"Error during GenAI call: {e}", exc_info=True)
+            # Check if it's a Google-specific API error for more details:
+            # from google.api_core import exceptions as google_exceptions
+            # if isinstance(e, google_exceptions.GoogleAPIError):
+            #     answer = f"API Error: {e.message}"
+            # else:
+            answer = f"# Error during AI processing:\n{type(e).__name__}: {str(e)}"
+
+        return answer
 
+    def clear_chat_history(self):  # This method is called by app.py
+        """Clears the agent's internal chat history."""
        self.chat_history = []
+        logging.info("EmployerBrandingAgent chat history cleared by request.")
+
+# --- Module-level function for schema display in app.py ---
+def get_all_schemas_representation(all_dataframes: dict) -> str:
+    """
+    Generates a string representation of the schemas of all DataFrames,
+    intended for display in the Gradio UI.
+    This is a standalone function as it's imported directly by app.py.
+    """
+    if not all_dataframes:
+        return "No DataFrames are currently loaded."
+
+    schema_descriptions = ["DataFrames currently available in the application state:"]
+    for key, df in all_dataframes.items():
+        df_name = f"df_{key}"
+        columns = ", ".join(df.columns)
+        shape = df.shape
+        if df.empty:
+            schema = f"\n--- DataFrame: {df_name} ---\nStatus: Empty\nShape: {shape}\nColumns: {columns}"
+        else:
+            # Provide a bit more detail for UI display
+            sample_data_str = df.head(2).to_markdown(index=False)  # Use markdown for better UI rendering
+            schema = (f"\n--- DataFrame: {df_name} ---\nShape: {shape}\nColumns: {columns}\n\n<details><summary>Sample Data (first 2 rows of {df_name}):</summary>\n\n{sample_data_str}\n\n</details>")
+        schema_descriptions.append(schema)
+    return "\n".join(schema_descriptions)
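
For reference, a minimal sketch of how a caller such as app.py might wire this module together, based on the comments in the code above: embeddings are initialized asynchronously before the first query, and the caller owns chat_history in the API format. Everything outside eb_agent_module here (the example DataFrame, the main() wrapper, the sample query) is an assumption for illustration, not part of this commit.

```python
# Hypothetical wiring in a caller such as app.py -- a sketch only; names outside
# eb_agent_module are assumptions and the real app may differ.
import asyncio
import pandas as pd

from eb_agent_module import (
    EmployerBrandingAgent,
    df_rag_documents,
    LLM_MODEL_NAME,
    GEMINI_EMBEDDING_MODEL_NAME,
    GENERATION_CONFIG_PARAMS,
    DEFAULT_SAFETY_SETTINGS,
)

async def main():
    # Example data; the real app would pass its own loaded DataFrames.
    all_dataframes = {"follower_stats": pd.DataFrame({"date": [], "followers": []})}

    agent = EmployerBrandingAgent(
        all_dataframes=all_dataframes,
        rag_documents_df=df_rag_documents,
        llm_model_name=LLM_MODEL_NAME,
        embedding_model_name=GEMINI_EMBEDDING_MODEL_NAME,
        generation_config_dict=GENERATION_CONFIG_PARAMS,
        safety_settings_list_of_dicts=DEFAULT_SAFETY_SETTINGS,
    )

    # The module notes that RAG embeddings must be initialized externally before querying.
    await agent.rag_system.initialize_embeddings()

    # chat_history holds only *previous* turns, in the API format the module documents.
    agent.chat_history = []
    query = "How did follower growth trend last quarter?"
    answer = await agent.process_query(query)

    # The caller is responsible for appending the new turns afterwards.
    agent.chat_history.append({"role": "user", "parts": [{"text": query}]})
    agent.chat_history.append({"role": "model", "parts": [{"text": answer}]})
    print(answer)

if __name__ == "__main__":
    asyncio.run(main())
```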