import pandas as pd
import json
import os
import asyncio
import logging
import numpy as np
import textwrap  # Used by _build_system_prompt (textwrap.dedent)
from typing import Dict, List, Optional, Any

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(module)s - %(message)s')

try:
    from google import genai
    from google.genai import types # Assuming this provides necessary types like SafetySetting, HarmCategory etc.
    from google.genai import errors
    # If GenerationConfig or EmbedContentConfig are from a different submodule, adjust imports.
    # For google-generativeai, GenerationConfig is often passed as a dict or genai.types.GenerationConfig
    # and EmbedContentConfig might be implicit or part of task_type.
    GENAI_AVAILABLE = True
    logging.info("Google Generative AI library imported successfully.")
except ImportError:
    logging.warning("Google Generative AI library not found. Please install it: pip install google-generativeai")
    GENAI_AVAILABLE = False
    
    # Dummy classes for graceful degradation (simplified)
    class genai:
        Client = None
        # If using google-generativeai, these would be different:
        # GenerativeModel = None 
        # def configure(*args, **kwargs): pass
        # def embed_content(*args, **kwargs): return {}

    class types: # Placeholder for types used in the original code
        EmbedContentConfig = None # Placeholder
        GenerationConfig = None # Placeholder
        SafetySetting = None
        Candidate = type('Candidate', (), {'FinishReason': type('FinishReason', (), {'STOP': 'STOP'})}) # Dummy for FinishReason

        class HarmCategory:
            HARM_CATEGORY_UNSPECIFIED = "HARM_CATEGORY_UNSPECIFIED"
            HARM_CATEGORY_HATE_SPEECH = "HARM_CATEGORY_HATE_SPEECH"
            HARM_CATEGORY_HARASSMENT = "HARM_CATEGORY_HARASSMENT"
            HARM_CATEGORY_SEXUALLY_EXPLICIT = "HARM_CATEGORY_SEXUALLY_EXPLICIT"
            HARM_CATEGORY_DANGEROUS_CONTENT = "HARM_CATEGORY_DANGEROUS_CONTENT"
            
        class HarmBlockThreshold:
            BLOCK_NONE = "BLOCK_NONE"
            BLOCK_LOW_AND_ABOVE = "BLOCK_LOW_AND_ABOVE"
            BLOCK_MEDIUM_AND_ABOVE = "BLOCK_MEDIUM_AND_ABOVE"
            BLOCK_ONLY_HIGH = "BLOCK_ONLY_HIGH"
        
        class generation_types: # Dummy for BlockedPromptException
            BlockedPromptException = type('BlockedPromptException', (Exception,), {})


# --- Custom Exceptions ---
class ValidationError(Exception):
    """Custom validation error for agent inputs"""
    pass

class RateLimitError(Exception): # Not used, but kept
    """Placeholder for rate limit errors."""
    pass

class AgentNotReadyError(Exception):
    """Agent is not properly initialized"""
    pass

# --- Configuration Constants ---
GEMINI_API_KEY = os.getenv('GEMINI_API_KEY', "")
# google-generativeai accepts the bare model name; client.models.generate_content may
# require the "models/" prefix, which the call sites below add when it is missing.
LLM_MODEL_NAME = "gemini-1.5-flash-latest"
GEMINI_EMBEDDING_MODEL_NAME = "text-embedding-004"  # "models/" prefix likewise added at call time

GENERATION_CONFIG_PARAMS = {
    "temperature": 0.7,
    "top_p": 0.95,
    "top_k": 40,
    "max_output_tokens": 8192, # Ensure this is supported
    "candidate_count": 1,
}

DEFAULT_SAFETY_SETTINGS = [] # User can populate this with {'category': HarmCategory.HARM_CATEGORY_X, 'threshold': HarmBlockThreshold.BLOCK_Y}
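# Example population (illustrative only; verify the enum names against the installed SDK):
# DEFAULT_SAFETY_SETTINGS = [
#     {'category': types.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
#      'threshold': types.HarmBlockThreshold.BLOCK_ONLY_HIGH},
#     {'category': types.HarmCategory.HARM_CATEGORY_HARASSMENT,
#      'threshold': types.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE},
# ]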

# Default RAG documents
DEFAULT_RAG_DOCUMENTS = pd.DataFrame({
    'text': [
        "Employer branding focuses on how an organization is perceived as an employer by potential and current employees.",
        "Key metrics for employer branding include employee engagement, candidate quality, and retention rates.",
        "LinkedIn is a crucial platform for showcasing company culture and attracting talent.",
        "Analyzing follower demographics and post engagement helps refine employer branding strategies.",
        "Content strategy should align with company values to attract the right talent.",
        "Employee advocacy programs can significantly boost employer brand reach and authenticity."
    ]
})
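# A caller-supplied corpus uses the same single-'text'-column shape, e.g. (illustrative):
#   my_docs = pd.DataFrame({'text': ["Our EVP centers on growth and flexibility."]})
#   agent = EmployerBrandingAgent(rag_documents_df=my_docs)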

# --- Client Initialization ---
client = None
if GEMINI_API_KEY and GENAI_AVAILABLE:
    try:
        # This call is specific to the google-genai SDK; with google-generativeai it would be genai.configure(api_key=...).
        client = genai.Client(api_key=GEMINI_API_KEY)
        logging.info("Google GenAI client initialized successfully (using genai.Client).")
    except Exception as e:
        logging.error(f"Failed to initialize Google GenAI client (using genai.Client): {e}")
        client = None
else:
    if not GEMINI_API_KEY:
        logging.warning("GEMINI_API_KEY environment variable not set.")
    if not GENAI_AVAILABLE:
        logging.warning("Google GenAI library not available.")


# --- Utility function to get DataFrame schema representation ---
def get_df_schema_representation(df: pd.DataFrame, df_name: str) -> str:
    """Generates a string representation of a DataFrame's schema and a small sample."""
    if not isinstance(df, pd.DataFrame):
        return f"Item '{df_name}' is not a DataFrame.\n"
    if df.empty:
        return f"DataFrame '{df_name}': Empty\n"
    
    schema_parts = [f"DataFrame '{df_name}':"]
    schema_parts.append(f"  Shape: {df.shape}")
    schema_parts.append("  Columns:")
    for col in df.columns:
        col_type = str(df[col].dtype)
        null_count = df[col].isnull().sum()
        unique_count = df[col].nunique()
        schema_parts.append(f"    - {col} (Type: {col_type}, Nulls: {null_count}/{len(df)}, Uniques: {unique_count})")
    
    if not df.empty:
        schema_parts.append("  Sample Data (first 2 rows):")
        try:
            sample_df_str = df.head(2).to_string(index=True, max_colwidth=50) # Show index for context
            indented_sample_df = "\n".join(["    " + line for line in sample_df_str.split('\n')])
            schema_parts.append(indented_sample_df)
        except Exception as e:
            schema_parts.append(f"    Could not generate sample data: {e}")
            
    return "\n".join(schema_parts) + "\n"

def get_all_schemas_representation(dataframes: Dict[str, pd.DataFrame]) -> str:
    """Generates a string representation of all DataFrame schemas."""
    if not dataframes:
        return "No DataFrames available to the agent."
    
    full_representation = ["=== Available DataFrame Schemas for Analysis ==="]
    for name, df_instance in dataframes.items():
        full_representation.append(get_df_schema_representation(df_instance, name))
    return "\n".join(full_representation)

class AdvancedRAGSystem:
    def __init__(self, documents_df: pd.DataFrame, embedding_model_name: str):
        self.documents_df = documents_df.copy() if not documents_df.empty else DEFAULT_RAG_DOCUMENTS.copy()
        # Ensure 'text' column exists
        if 'text' not in self.documents_df.columns and not self.documents_df.empty:
            logging.warning("'text' column not found in RAG documents. RAG might not work.")
            # Create an empty text column if df is not empty but lacks it, to prevent errors later
            self.documents_df['text'] = ""

        self.embedding_model_name = embedding_model_name # e.g., "models/text-embedding-004" or just "text-embedding-004"
        self.embeddings: Optional[np.ndarray] = None
        self.is_initialized = False
        logging.info(f"AdvancedRAGSystem initialized with {len(self.documents_df)} documents. Model: {self.embedding_model_name}")

    def _embed_single_document_sync(self, text: str) -> Optional[np.ndarray]:
        if not client:
            raise ConnectionError("GenAI client not initialized for RAG embedding.")
        if not text or not isinstance(text, str):
            logging.warning("Cannot embed empty or non-string text for RAG.")
            return None
        
        try:
            # Standard google-generativeai call:
            # embedding_response = genai.embed_content(
            # model=self.embedding_model_name, # e.g., "models/text-embedding-004"
            # content=text,
            # task_type="RETRIEVAL_DOCUMENT" # or "SEMANTIC_SIMILARITY"
            # )
            # return np.array(embedding_response['embedding'])

            # Using the provided client.models.embed_content structure:
            # This might require specific types for config.
            embed_config_payload = None
            if GENAI_AVAILABLE and hasattr(types, 'EmbedContentConfig'):
                # The task_type for EmbedContentConfig might differ, e.g. "SEMANTIC_SIMILARITY" or "RETRIEVAL_DOCUMENT".
                embed_config_payload = types.EmbedContentConfig(task_type="RETRIEVAL_DOCUMENT")


            response = client.models.embed_content( # This is the user's original call structure
                model=f"models/{self.embedding_model_name}" if not self.embedding_model_name.startswith("models/") else self.embedding_model_name,
                contents=text, # Original used 'contents', genai.embed_content uses 'content'
                config=embed_config_payload # Original passed 'config'
            )
            
            # Adapt response parsing based on actual client.models.embed_content behavior
            if hasattr(response, 'embeddings') and isinstance(response.embeddings, list) and len(response.embeddings) > 0:
                # google-genai returns a list of ContentEmbedding objects whose float
                # vector lives in `.values`; fall back to the raw item otherwise.
                first_embedding = response.embeddings[0]
                return np.array(first_embedding.values if hasattr(first_embedding, 'values') else first_embedding)
            elif hasattr(response, 'embedding'):  # Common for genai.embed_content (dict-style response)
                return np.array(response.embedding)
            else:
                logging.error(f"Unexpected embedding response format: {response}")
                return None
        except Exception as e:
            logging.error(f"Error in _embed_single_document_sync for text '{text[:50]}...': {e}", exc_info=True)
            raise

    async def initialize_embeddings(self):
        if self.documents_df.empty or 'text' not in self.documents_df.columns:
            logging.warning("RAG documents DataFrame is empty or lacks 'text' column. Skipping embedding.")
            self.embeddings = np.array([])
            self.is_initialized = True # Initialized, but with no embeddings
            return

        if not client and not (GENAI_AVAILABLE and os.getenv('GEMINI_API_KEY')): # Check if standard genai can be used
            logging.error("GenAI client not available for RAG embedding initialization.")
            self.embeddings = np.array([])
            return

        logging.info(f"Starting RAG document embedding for {len(self.documents_df)} documents...")
        embedded_docs_list = []
        
        for index, row in self.documents_df.iterrows():
            text_to_embed = row.get('text', '')
            if not text_to_embed or not isinstance(text_to_embed, str):
                logging.warning(f"Skipping RAG document at index {index} due to invalid/empty text.")
                continue
            
            try:
                # Use asyncio.to_thread for the synchronous embedding call
                embedding_array = await asyncio.to_thread(self._embed_single_document_sync, text_to_embed)
                if embedding_array is not None and embedding_array.size > 0:
                    embedded_docs_list.append(embedding_array)
                else:
                    logging.warning(f"Empty or failed embedding for RAG document at index {index}.")
            except Exception as e:
                logging.error(f"Error embedding RAG document at index {index}: {e}")
                continue # Continue with other documents

        if not embedded_docs_list:
            self.embeddings = np.array([])
            logging.warning("No RAG documents were successfully embedded.")
        else:
            try:
                # Ensure all embeddings have the same shape before vstack
                first_shape = embedded_docs_list[0].shape
                if not all(emb.shape == first_shape for emb in embedded_docs_list):
                    logging.error("Inconsistent embedding shapes found. Cannot stack for RAG.")
                    # Attempt to filter out malformed embeddings if possible, or fail
                    # For now, we'll fail stacking if shapes are inconsistent.
                    self.embeddings = np.array([])
                    return # Exit if shapes are inconsistent

                self.embeddings = np.vstack(embedded_docs_list)
                logging.info(f"Successfully embedded {len(embedded_docs_list)} RAG documents. Embeddings shape: {self.embeddings.shape}")
            except ValueError as ve:
                logging.error(f"Error stacking embeddings (likely due to inconsistent shapes): {ve}")
                self.embeddings = np.array([])
        
        self.is_initialized = True


    def _calculate_cosine_similarity(self, embeddings_matrix: np.ndarray, query_vector: np.ndarray) -> np.ndarray:
        if embeddings_matrix.ndim == 1: # Handle case of single document embedding
            embeddings_matrix = embeddings_matrix.reshape(1, -1)
        if query_vector.ndim == 1:
            query_vector = query_vector.reshape(1, -1)
        
        if embeddings_matrix.size == 0 or query_vector.size == 0:
            return np.array([])
            
        # Normalize rows of the embeddings matrix and the query vector. The epsilon
        # already guards against division by zero, so no `where=` mask is needed
        # (np.divide with `where` but without `out` leaves unselected entries uninitialized).
        norm_matrix = np.linalg.norm(embeddings_matrix, axis=1, keepdims=True)
        normalized_embeddings_matrix = embeddings_matrix / (norm_matrix + 1e-8)

        norm_query = np.linalg.norm(query_vector, axis=1, keepdims=True)
        normalized_query_vector = query_vector / (norm_query + 1e-8)
        
        # Calculate dot product
        return np.dot(normalized_embeddings_matrix, normalized_query_vector.T).flatten()
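
    # Sanity-check sketch (illustrative values, not executed):
    #   m = np.array([[1.0, 0.0], [0.0, 1.0]])  # two unit document vectors
    #   q = np.array([1.0, 0.0])                # query vector
    #   self._calculate_cosine_similarity(m, q) -> approximately [1.0, 0.0]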


    async def retrieve_relevant_info(self, query: str, top_k: int = 3, min_similarity: float = 0.3) -> str:
        if not self.is_initialized:
            logging.debug("RAG system not initialized. Cannot retrieve info.")
            return ""
        if self.embeddings is None or self.embeddings.size == 0:
            logging.debug("RAG embeddings not available. Cannot retrieve info.")
            return ""
        if not query or not isinstance(query, str):
            logging.debug("Empty or invalid query for RAG retrieval.")
            return ""

        if not client and not (GENAI_AVAILABLE and os.getenv('GEMINI_API_KEY')):
            logging.error("GenAI client not available for RAG query embedding.")
            return ""

        try:
            query_vector = await asyncio.to_thread(self._embed_single_document_sync, query) # Embed query
            if query_vector is None or query_vector.size == 0:
                logging.warning("Query vector embedding failed or is empty for RAG.")
                return ""

            similarity_scores = self._calculate_cosine_similarity(self.embeddings, query_vector)
            if similarity_scores.size == 0:
                return ""

            relevant_indices = np.where(similarity_scores >= min_similarity)[0]
            if len(relevant_indices) == 0:
                logging.debug(f"No RAG documents met minimum similarity threshold of {min_similarity} for query: '{query[:50]}...'")
                return ""

            # Get scores for relevant documents and sort
            relevant_scores = similarity_scores[relevant_indices]
            # Argsort returns indices to sort relevant_scores; apply to relevant_indices
            sorted_relevant_indices_of_original = relevant_indices[np.argsort(relevant_scores)[::-1]]
            
            top_indices = sorted_relevant_indices_of_original[:top_k]

            context_parts = []
            if 'text' in self.documents_df.columns:
                for i in top_indices:
                    if 0 <= i < len(self.documents_df):
                        context_parts.append(self.documents_df.iloc[i]['text'])
            
            context = "\n\n---\n\n".join(context_parts)
            logging.debug(f"Retrieved RAG context with {len(context_parts)} documents for query: '{query[:50]}...'")
            return context
            
        except Exception as e:
            logging.error(f"Error during RAG retrieval for query '{query[:50]}...': {e}", exc_info=True)
            return ""

class EmployerBrandingAgent:
    def __init__(self,
                 all_dataframes: Optional[Dict[str, pd.DataFrame]] = None,
                 rag_documents_df: Optional[pd.DataFrame] = None,
                 llm_model_name: str = LLM_MODEL_NAME,
                 embedding_model_name: str = GEMINI_EMBEDDING_MODEL_NAME,
                 generation_config_dict: Optional[Dict] = None,
                 safety_settings_list: Optional[List] = None): # safety_settings_list expects list of dicts or SafetySetting objects
        
        self.all_dataframes = {k: v.copy() for k, v in (all_dataframes or {}).items()} # Deep copy
        
        _rag_docs_df = rag_documents_df if rag_documents_df is not None else DEFAULT_RAG_DOCUMENTS.copy()
        self.rag_system = AdvancedRAGSystem(_rag_docs_df, embedding_model_name)
        
        self.llm_model_name = llm_model_name
        self.embedding_model_name = embedding_model_name  # Stored so get_status can report it
        self.generation_config_dict = generation_config_dict or GENERATION_CONFIG_PARAMS
        
        # Ensure safety settings are in the correct format if using google-generativeai directly
        self.safety_settings_list = []
        if safety_settings_list and GENAI_AVAILABLE and hasattr(types, 'SafetySetting'):
            for ss_dict in safety_settings_list:
                try:
                    # Assuming ss_dict is like {'category': HarmCategory.XYZ, 'threshold': HarmBlockThreshold.ABC}
                    self.safety_settings_list.append(types.SafetySetting(category=ss_dict['category'], threshold=ss_dict['threshold']))
                except Exception as e:
                    logging.warning(f"Could not convert safety setting dict to SafetySetting object: {ss_dict} - {e}")
        elif safety_settings_list:  # If not using types.SafetySetting, pass through as-is (e.g. for client.models)
            self.safety_settings_list = safety_settings_list


        self.chat_history: List[Dict[str, str]] = [] # Stores {"role": "user/model", "content": "..."}
        self.is_ready = False
        self.llm_model_instance = None # For google-generativeai

        if GENAI_AVAILABLE and client is None and GEMINI_API_KEY:  # genai.Client failed; fall back to google-generativeai
            try:
                genai.configure(api_key=GEMINI_API_KEY)
                self.llm_model_instance = genai.GenerativeModel(self.llm_model_name)
                logging.info(f"Initialized GenerativeModel '{self.llm_model_name}' via google-generativeai.")
            except Exception as e:
                logging.error(f"Failed to initialize google-generativeai.GenerativeModel: {e}")
        
        logging.info(f"EmployerBrandingAgent initialized. LLM: {self.llm_model_name}. RAG docs: {len(self.rag_system.documents_df)}. DataFrames: {list(self.all_dataframes.keys())}")

    async def initialize(self) -> bool:
        """Initializes asynchronous components of the agent, primarily RAG embeddings."""
        try:
            if not client and not self.llm_model_instance:  # Check if any LLM access is configured
                logging.error("Cannot initialize agent: no GenAI access (neither genai.Client nor google-generativeai) is available/configured.")
                return False
            
            await self.rag_system.initialize_embeddings() # This sets rag_system.is_initialized
            self.is_ready = self.rag_system.is_initialized # Agent is ready if RAG is (even if RAG has no docs)
            logging.info(f"EmployerBrandingAgent.initialize completed. RAG initialized: {self.rag_system.is_initialized}. Agent ready: {self.is_ready}")
            return True
        except Exception as e:
            logging.error(f"Error during EmployerBrandingAgent.initialize: {e}", exc_info=True)
            self.is_ready = False
            return False

    def _get_dataframes_summary(self) -> str:
        return get_all_schemas_representation(self.all_dataframes)

    def _build_system_prompt(self) -> str:
        # This prompt provides overall guidance to the LLM.
        return textwrap.dedent("""
        You are an expert Employer Branding Analyst AI. Your primary function is to analyze LinkedIn data provided (follower statistics, post performance, mentions) and offer actionable insights, data-driven recommendations, and if requested, Python Pandas code snippets for further analysis.

        When providing insights or recommendations:
        - Be specific and base your conclusions on the data summaries and context provided.
        - Structure responses clearly, perhaps using bullet points for key findings or actions.
        - Focus on practical advice that can help improve employer branding efforts.

        When asked to generate Pandas code:
        - Assume the data is available in pandas DataFrames named exactly as in the 'Available DataFrame Schemas' section (e.g., `df_follower_stats`, `df_posts`).
        - Generate executable Python code using pandas.
        - Ensure the code is directly relevant to the user's query and the available data.
        - Briefly explain what the code does.
        - If a query implies data not present in the schemas, state that and do not attempt to fabricate code for it.
        - Do not generate code that modifies DataFrames in place unless explicitly asked. Prefer returning new DataFrames or Series.
        - Handle potential errors in data (e.g., missing values if relevant to the operation) gracefully if simple to do so.
        - Output the code in a single, copy-pasteable block.

        Always refer to the provided DataFrame schemas to understand available columns and data types. Do not hallucinate columns or data.
        If a query is ambiguous or requires data not present, ask for clarification or state the limitation.
        """).strip()

    async def _generate_response(self, current_user_query: str) -> str:
        """
        Generates a response from the LLM based on the current query, system prompts,
        data summaries, RAG context, and the agent's chat history.
        Assumes self.chat_history is already populated by app.py and includes the current_user_query as the last entry.
        """
        if not self.is_ready:
            return "Agent is not ready. Please initialize."
        if not client and not self.llm_model_instance:
            return "Error: AI service is not available. Check API configuration."
    
        try:
            system_prompt_text = self._build_system_prompt()
            data_summary_text = self._get_dataframes_summary()
            rag_context_text = await self.rag_system.retrieve_relevant_info(current_user_query, top_k=2, min_similarity=0.25)
    
            # Construct the messages for the LLM API call
            llm_messages = []
    
            # 1. System-level instructions and context (as a first "user" turn)
            initial_context_prompt = (
                f"{system_prompt_text}\n\n"
                f"## Available Data Overview:\n{data_summary_text}\n\n"
                f"## Relevant Background Information (if any):\n{rag_context_text if rag_context_text else 'No specific background information retrieved for this query.'}\n\n"
                f"Given this context, please respond to the user queries that follow in the chat history."
            )
            llm_messages.append({"role": "user", "parts": [{"text": initial_context_prompt}]})
            # 2. Priming assistant message
            llm_messages.append({"role": "model", "parts": [{"text": "Understood. I have reviewed the context and data overview. I am ready to assist with your Employer Branding analysis based on our conversation."}]})
            
            # 3. Append the actual conversation history (already includes the current user query)
            for entry in self.chat_history:
                llm_messages.append({"role": entry["role"], "parts": [{"text": entry["content"]}]})
    
            # --- Make the API call ---
            response_text = ""
            if self.llm_model_instance:  # Standard google-generativeai usage
                logging.debug(f"Using google-generativeai.GenerativeModel.generate_content_async for LLM call. History length: {len(llm_messages)}")
                
                # Prepare generation config and safety settings for google-generativeai
                gen_config_payload = self.generation_config_dict
                safety_settings_payload = self.safety_settings_list
                
                if GENAI_AVAILABLE and hasattr(types, 'GenerationConfig') and not isinstance(self.generation_config_dict, types.GenerationConfig):
                    try:
                        gen_config_payload = types.GenerationConfig(**self.generation_config_dict)
                    except Exception as e:
                        logging.warning(f"Could not convert gen_config_dict to types.GenerationConfig: {e}")
                
                api_response = await self.llm_model_instance.generate_content_async(
                    contents=llm_messages,
                    generation_config=gen_config_payload,
                    safety_settings=safety_settings_payload
                )
                response_text = api_response.text
                
            elif client:  # google.genai client usage
                logging.debug(f"Using client.models.generate_content for LLM call. History length: {len(llm_messages)}")
                
                # Convert messages to the format expected by google.genai client
                # The client expects a simpler contents format
                contents = []
                for msg in llm_messages:
                    if msg["role"] == "user":
                        contents.append(msg["parts"][0]["text"])
                    elif msg["role"] == "model":
                        # Model turns are flattened into prefixed plain text so the
                        # simpler contents format still carries conversational context.
                        contents.append(f"Assistant: {msg['parts'][0]['text']}")
                
                # Create the config object with both generation config and safety settings
                config_dict = {}
                
                # Add generation config parameters
                if self.generation_config_dict:
                    for key, value in self.generation_config_dict.items():
                        config_dict[key] = value
                
                # Add safety settings
                if self.safety_settings_list:
                    # Convert safety settings to the correct format if needed
                    safety_settings = []
                    for ss in self.safety_settings_list:
                        if isinstance(ss, dict):
                            # Convert dict to types.SafetySetting
                            safety_settings.append(types.SafetySetting(
                                category=ss.get('category'),
                                threshold=ss.get('threshold')
                            ))
                        else:
                            safety_settings.append(ss)
                    config_dict['safety_settings'] = safety_settings
                
                # Create the config object
                config = types.GenerateContentConfig(**config_dict)
                
                model_path = f"models/{self.llm_model_name}" if not self.llm_model_name.startswith("models/") else self.llm_model_name
                
                api_response = await asyncio.to_thread(
                    client.models.generate_content,
                    model=model_path,
                    contents=contents,  # Simplified contents format
                    config=config       # Using config parameter instead of separate generation_config and safety_settings
                )
                
                # Parse response from client.models structure
                if api_response.candidates and api_response.candidates[0].content and api_response.candidates[0].content.parts:
                    response_text_parts = [part.text for part in api_response.candidates[0].content.parts if hasattr(part, 'text')]
                    response_text = "".join(response_text_parts).strip()
                else:
                    # Handle blocked or empty responses
                    if hasattr(api_response, 'prompt_feedback') and api_response.prompt_feedback and api_response.prompt_feedback.block_reason:
                        logging.warning(f"Prompt blocked by client.models: {api_response.prompt_feedback.block_reason}")
                        return f"I'm sorry, your request was blocked. Reason: {api_response.prompt_feedback.block_reason_message or api_response.prompt_feedback.block_reason}"
                    if api_response.candidates and hasattr(api_response.candidates[0], 'finish_reason'):
                        finish_reason = api_response.candidates[0].finish_reason
                        if hasattr(types.Candidate, 'FinishReason') and finish_reason != types.Candidate.FinishReason.STOP:
                            logging.warning(f"Content generation stopped by client.models due to: {finish_reason}. Safety: {getattr(api_response.candidates[0], 'safety_ratings', 'N/A')}")
                            return f"I couldn't complete the response. Reason: {finish_reason}. Please try rephrasing."
                    return "I apologize, but I couldn't generate a response from client.models."
            else:
                raise ConnectionError("No valid LLM client or model instance available.")
    
            return response_text.strip()
    
        except Exception as e:
            error_message = str(e).lower()
            
            # Check if it's a blocked prompt error by examining the error message
            if any(keyword in error_message for keyword in ['blocked', 'safety', 'filter', 'prohibited']):
                logging.error(f"Blocked prompt from LLM: {e}", exc_info=True)
                return f"I'm sorry, your request was blocked by the safety filter. Please rephrase your query. Details: {e}"
            else:
                logging.error(f"Error in _generate_response: {e}", exc_info=True)
                return f"I encountered an error while processing your request: {type(e).__name__} - {str(e)}"


    def _validate_query(self, query: str) -> bool:
        if not query or not isinstance(query, str) or len(query.strip()) < 3:
            logging.warning(f"Invalid query: too short or not a string. Query: '{query}'")
            return False
        if len(query) > 3000:  # Upper bound on accepted query length
            logging.warning(f"Invalid query: too long. Length: {len(query)}")
            return False
        return True

    async def process_query(self, user_query: str) -> str:
        """
        Processes the user's query.
        It relies on self.chat_history being set externally (by app.py) to include the full
        conversation context, including the current user_query as the last "user" message.
        This method then calls _generate_response to get the AI's reply.
        It does NOT modify self.chat_history itself; app.py is responsible for that based on Gradio state.
        """
        if not self._validate_query(user_query):
            # This user_query is the one from Gradio input, also the last one in self.chat_history
            return "Please provide a valid query (3 to 3000 characters)."
        
        if not self.is_ready:
            logging.warning("process_query called but agent is not ready. Attempting re-initialization.")
            # This is a fallback. Ideally, initialize is called once and confirmed.
            init_success = await self.initialize()
            if not init_success:
                return "The agent is not properly initialized and could not be started. Please check configuration and logs."
        
        # user_query is the current text from the input box.
        # self.chat_history (set by app.py) should already contain this user_query as the last message.
        # We pass user_query to _generate_response primarily for RAG context retrieval for the current turn.
        response_text = await self._generate_response(user_query)
        return response_text


    def update_dataframes(self, new_dataframes: Dict[str, pd.DataFrame]):
        """Updates the agent's DataFrames. Does not automatically re-initialize RAG or LLM."""
        self.all_dataframes = {k: v.copy() for k, v in new_dataframes.items()} # Deep copy
        logging.info(f"Agent DataFrames updated. Keys: {list(self.all_dataframes.keys())}")
        # Note: If RAG documents depend on these DataFrames, RAG might need re-initialization.
        # For now, RAG uses a static document set.

    def clear_chat_history(self):
        """Clears the agent's internal chat history. App.py should also clear Gradio state."""
        self.chat_history = []
        logging.info("EmployerBrandingAgent internal chat history cleared.")

    def get_status(self) -> Dict[str, Any]:
        return {
            "is_ready": self.is_ready,
            "has_api_key": bool(GEMINI_API_KEY),
            "genai_available": GENAI_AVAILABLE,
            "client_type": "genai.Client" if client else ("google-generativeai" if self.llm_model_instance else "None"),
            "rag_initialized": self.rag_system.is_initialized,
            "num_dataframes": len(self.all_dataframes),
            "dataframe_keys": list(self.all_dataframes.keys()),
            "num_rag_documents": len(self.rag_system.documents_df) if self.rag_system.documents_df is not None else 0,
            "llm_model_name": self.llm_model_name,
            "embedding_model_name": self.embedding_model_name
        }

# --- Functions for Gradio integration (if needed directly, but app.py handles instantiation) ---
def create_agent_instance(dataframes: Optional[Dict[str, pd.DataFrame]] = None,
                          rag_docs: Optional[pd.DataFrame] = None) -> EmployerBrandingAgent:
    logging.info("Creating new EmployerBrandingAgent instance via helper function.")
    return EmployerBrandingAgent(all_dataframes=dataframes, rag_documents_df=rag_docs)

async def initialize_agent_async(agent: EmployerBrandingAgent) -> bool:
    logging.info("Initializing agent via async helper function.")
    return await agent.initialize()
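
# Wiring sketch (hedged -- the actual app.py wiring may differ):
#   agent = create_agent_instance(dataframes={'posts': df_posts})  # df_posts: hypothetical DataFrame
#   ok = asyncio.run(initialize_agent_async(agent))  # embeds RAG docs and sets agent.is_ready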


if __name__ == "__main__":
    async def test_agent_logic():
        print("--- Testing Employer Branding Agent ---")
        if not GEMINI_API_KEY:
            print("GEMINI_API_KEY not set. Skipping live API tests.")
            return

        sample_dfs = {
            "followers": pd.DataFrame({'date': pd.to_datetime(['2023-01-01']), 'count': [100]}),
            "posts": pd.DataFrame({'title': ['My first post'], 'likes': [10]})
        }
        
        # Test RAG document loading
        custom_rag = pd.DataFrame({'text': ["Custom RAG context about LinkedIn engagement."]})

        agent = EmployerBrandingAgent(
            all_dataframes=sample_dfs,
            rag_documents_df=custom_rag,
            llm_model_name=LLM_MODEL_NAME,
            embedding_model_name=GEMINI_EMBEDDING_MODEL_NAME
        )
        print("Agent Status (pre-init):", agent.get_status())

        init_success = await agent.initialize()
        print(f"Agent Initialization Success: {init_success}")
        print("Agent Status (post-init):", agent.get_status())

        if not init_success:
            print("Agent initialization failed. Cannot proceed with query test.")
            return

        # Simulate app.py setting history
        test_query1 = "What are the key columns in my followers data?"
        agent.chat_history = [{"role": "user", "content": test_query1}] # app.py would do this
        
        print(f"\nProcessing Query 1: '{test_query1}'")
        response1 = await agent.process_query(user_query=test_query1) # Pass current query for RAG etc.
        print(f"Agent Response 1:\n{response1}")
        
        # Simulate app.py updating history for next turn
        agent.chat_history.append({"role": "model", "content": response1}) 
        
        test_query2 = "Generate pandas code to get the total follower count."
        agent.chat_history.append({"role": "user", "content": test_query2})

        print(f"\nProcessing Query 2: '{test_query2}'")
        response2 = await agent.process_query(user_query=test_query2)
        print(f"Agent Response 2:\n{response2}")

        agent.chat_history.append({"role": "model", "content": response2})
        print("\nFinal Agent Chat History (internal):")
        for item in agent.chat_history:
            print(f"- {item['role']}: {item['content'][:100]}...")
            
        print("\n--- Test Complete ---")

    asyncio.run(test_agent_logic())