import os
import sys
import json
import pickle
from pathlib import Path

import faiss
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModel
from tqdm.auto import tqdm

from chatbot_model import ChatbotConfig, EncoderModel
from tf_data_pipeline import TFDataPipeline
from logger_config import config_logger

logger = config_logger(__name__)

os.environ["TOKENIZERS_PARALLELISM"] = "false"


def main():
    MODELS_DIR = 'new_iteration/data_prep_iterative_models'
    PROCESSED_DATA_DIR = 'new_iteration/processed_outputs'
    CACHE_DIR = 'new_iteration/cache'
    TOKENIZER_DIR = os.path.join(MODELS_DIR, 'tokenizer')
    FAISS_INDICES_DIR = os.path.join(MODELS_DIR, 'faiss_indices')
    TF_RECORD_DIR = 'new_iteration/training_data'
    FAISS_INDEX_PRODUCTION_PATH = os.path.join(FAISS_INDICES_DIR, 'faiss_index_production.index')
    JSON_TRAINING_DATA_PATH = os.path.join(PROCESSED_DATA_DIR, 'taskmaster_dialogues.json')
    CACHE_FILE = os.path.join(CACHE_DIR, 'query_embeddings_cache.pkl')
    TF_RECORD_PATH = os.path.join(TF_RECORD_DIR, 'training_data_3.tfrecord')

    # Decide whether to load the **custom** model or base DistilBERT (base used for first iteration).
    # True for custom, False for base DistilBERT.
    LOAD_CUSTOM_MODEL = True
    NUM_NEG_SAMPLES = 10

    # Ensure output directories exist
    os.makedirs(MODELS_DIR, exist_ok=True)
    os.makedirs(PROCESSED_DATA_DIR, exist_ok=True)
    os.makedirs(CACHE_DIR, exist_ok=True)
    os.makedirs(TOKENIZER_DIR, exist_ok=True)
    os.makedirs(FAISS_INDICES_DIR, exist_ok=True)
    os.makedirs(TF_RECORD_DIR, exist_ok=True)

    # Init config
    config_json = Path(MODELS_DIR) / "config.json"
    if config_json.exists():
        with open(config_json, "r", encoding="utf-8") as f:
            config_dict = json.load(f)
        config = ChatbotConfig.from_dict(config_dict)
        logger.info(f"Loaded ChatbotConfig from {config_json}")
    else:
        config = ChatbotConfig()
        logger.warning("No config.json found. Using default ChatbotConfig.")

    # Ensure negative samples are set
    config.neg_samples = NUM_NEG_SAMPLES

    # Load or init tokenizer
    try:
        if Path(TOKENIZER_DIR).exists() and list(Path(TOKENIZER_DIR).iterdir()):
            logger.info(f"Loading tokenizer from {TOKENIZER_DIR}")
            tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_DIR)
        else:
            logger.info(f"Loading base tokenizer for {config.pretrained_model}")
            tokenizer = AutoTokenizer.from_pretrained(config.pretrained_model)

            Path(TOKENIZER_DIR).mkdir(parents=True, exist_ok=True)
            tokenizer.save_pretrained(TOKENIZER_DIR)
            logger.info(f"New tokenizer saved to {TOKENIZER_DIR}")
    except Exception as e:
        logger.error(f"Failed to load or create tokenizer: {e}")
        sys.exit(1)

    # Init the encoder
    try:
        encoder = EncoderModel(config=config)
        logger.info("EncoderModel initialized successfully.")

        if LOAD_CUSTOM_MODEL:
            # Load the DistilBERT submodule from 'shared_encoder'
            shared_encoder_path = Path(MODELS_DIR) / "shared_encoder"
            if shared_encoder_path.exists():
                logger.info(f"Loading DistilBERT submodule from {shared_encoder_path}")
                encoder.pretrained = TFAutoModel.from_pretrained(shared_encoder_path)
            else:
                logger.warning(f"No shared_encoder found at {shared_encoder_path}, using base DistilBERT instead.")

            # Load custom .weights.h5 (projection, dropout, etc.)
            custom_weights_path = Path(MODELS_DIR) / "encoder_custom_weights.weights.h5"
            if custom_weights_path.exists():
                logger.info(f"Loading custom top-level weights from {custom_weights_path}")

                # Dummy forward pass forces the model to build so all layers exist before loading weights
                dummy_input = tf.zeros((1, config.max_context_token_limit), dtype=tf.int32)
                _ = encoder(dummy_input, training=False)

                encoder.load_weights(str(custom_weights_path))
                logger.info("Custom encoder weights loaded successfully.")
            else:
                logger.warning(f"Custom weights file not found at {custom_weights_path}. Using only submodule weights.")
        else:
            # Base DistilBERT with special tokens
            logger.info("Using the base DistilBERT without loading custom weights.")

        # Resize token embeddings in case special tokens were added (see EncoderModel)
        encoder.pretrained.resize_token_embeddings(len(tokenizer))
        logger.info(f"Token embeddings resized to: {len(tokenizer)}")
    except Exception as e:
        logger.error(f"Failed to initialize EncoderModel: {e}")
        sys.exit(1)

    # Load JSON dialogues
    try:
        if not Path(JSON_TRAINING_DATA_PATH).exists():
            logger.warning(f"No dialogues found at {JSON_TRAINING_DATA_PATH}, skipping.")
            dialogues = []
        else:
            dialogues = TFDataPipeline.load_json_training_data(JSON_TRAINING_DATA_PATH, debug_samples=None)
            logger.info(f"Loaded {len(dialogues)} dialogues from {JSON_TRAINING_DATA_PATH}.")
    except Exception as e:
        logger.error(f"Failed to load dialogues: {e}")
        sys.exit(1)

    # Load or init query_embeddings_cache.
    # NOTE: recompute the cache after each training run; reusing stale embeddings was a bug source.
    query_embeddings_cache = {}
    if os.path.exists(CACHE_FILE):
        try:
            with open(CACHE_FILE, 'rb') as f:
                query_embeddings_cache = pickle.load(f)
            logger.info(f"Loaded {len(query_embeddings_cache)} query embeddings from {CACHE_FILE}.")
        except Exception as e:
            logger.warning(f"Failed to load query embeddings cache: {e}")
    else:
        logger.info("No existing query embeddings cache found. Starting fresh.")

    # Initialize TFDataPipeline
    try:
        # Load or init FAISS index
        if Path(FAISS_INDEX_PRODUCTION_PATH).exists():
            logger.info(f"Loading existing FAISS index from {FAISS_INDEX_PRODUCTION_PATH}...")
            faiss_index = faiss.read_index(FAISS_INDEX_PRODUCTION_PATH)
            logger.info("FAISS index loaded successfully.")
        else:
            logger.info("No existing FAISS index found. Initializing a new index.")
            dimension = config.embedding_dim  # Must match the encoder's output dimension
            faiss_index = faiss.IndexFlatIP(dimension)  # Inner product (cosine similarity for L2-normalized embeddings)
            logger.info(f"Initialized new FAISS index with dimension {dimension}.")

        # Init TFDataPipeline, pointing it at the FAISS index file
        data_pipeline = TFDataPipeline(
            config=config,
            tokenizer=tokenizer,
            encoder=encoder,
            index_file_path=FAISS_INDEX_PRODUCTION_PATH,
            response_pool=[],
            max_length=config.max_context_token_limit,
            neg_samples=config.neg_samples,
            query_embeddings_cache=query_embeddings_cache,
            index_type='IndexFlatIP',
            nlist=100,  # Not used for IndexFlatIP; retained for future use of IndexIVFFlat
            max_retries=config.max_retries
        )
        logger.info("TFDataPipeline initialized successfully.")
    except Exception as e:
        logger.error(f"Failed to initialize TFDataPipeline: {e}")
        sys.exit(1)

    # Collect response pool from dialogues
    try:
        if dialogues:
            response_pool = data_pipeline.collect_responses_with_domain(dialogues)
            data_pipeline.response_pool = response_pool
            logger.info(f"Collected {len(response_pool)} unique assistant responses from dialogues.")
        else:
            logger.warning("No dialogues loaded. response_pool remains empty.")
response_pool remains empty.") except Exception as e: logger.error(f"Failed to collect responses: {e}") sys.exit(1) # Build FAISS index with response embeddings try: if data_pipeline.response_pool: data_pipeline.build_text_to_domain_map() logger.info("Computing and adding response embeddings to FAISS index using TFDataPipeline...") data_pipeline.compute_and_index_response_embeddings() logger.info("Response embeddings computed and added to FAISS index.") # Save the FAISS index data_pipeline.save_faiss_index(FAISS_INDEX_PRODUCTION_PATH) # Also save response pool JSON response_pool_path = FAISS_INDEX_PRODUCTION_PATH.replace('.index', '_responses.json') with open(response_pool_path, 'w', encoding='utf-8') as f: json.dump(data_pipeline.response_pool, f, indent=2) logger.info(f"Response pool saved to {response_pool_path}.") else: logger.warning("No responses to embed. Skipping FAISS indexing.") except Exception as e: logger.error(f"Failed to compute or add response embeddings: {e}") sys.exit(1) # Prepare training data as TFRecords (TensforFlow Record format) try: if dialogues: logger.info("Starting data preparation and saving as TFRecord...") data_pipeline.prepare_and_save_data(dialogues, TF_RECORD_PATH) logger.info(f"Data saved as TFRecord at {TF_RECORD_PATH}.") else: logger.warning("No dialogues to build TFRecord from. Skipping TFRecord creation.") except Exception as e: logger.error(f"Failed during data preparation and saving: {e}") sys.exit(1) # Save query embeddings cache try: with open(CACHE_FILE, 'wb') as f: pickle.dump(data_pipeline.query_embeddings_cache, f) logger.info(f"Saved {len(data_pipeline.query_embeddings_cache)} query embeddings to {CACHE_FILE}.") except Exception as e: logger.error(f"Failed to save query embeddings cache: {e}") sys.exit(1) # Save Tokenizer try: tokenizer.save_pretrained(TOKENIZER_DIR) logger.info(f"Tokenizer saved to {TOKENIZER_DIR}.") except Exception as e: logger.error(f"Failed to save tokenizer: {e}") sys.exit(1) logger.info("Data preparation pipeline completed successfully.") if __name__ == "__main__": main()