import os
import sys
import json
import pickle
from pathlib import Path

import faiss
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModel
from tqdm.auto import tqdm

# Your existing modules
from chatbot_model import ChatbotConfig, EncoderModel
from tf_data_pipeline import TFDataPipeline
from logger_config import config_logger

logger = config_logger(__name__)

os.environ["TOKENIZERS_PARALLELISM"] = "false"


def main():
    # Constants
    MODELS_DIR = 'new_iteration/data_prep_iterative_models'
    PROCESSED_DATA_DIR = 'new_iteration/processed_outputs'
    CACHE_DIR = 'new_iteration/cache'
    TOKENIZER_DIR = os.path.join(MODELS_DIR, 'tokenizer')
    FAISS_INDICES_DIR = os.path.join(MODELS_DIR, 'faiss_indices')
    TF_RECORD_DIR = 'new_iteration/training_data'
    FAISS_INDEX_PRODUCTION_PATH = os.path.join(FAISS_INDICES_DIR, 'faiss_index_production.index')
    JSON_TRAINING_DATA_PATH = os.path.join(PROCESSED_DATA_DIR, 'taskmaster_dialogues.json')
    CACHE_FILE = os.path.join(CACHE_DIR, 'query_embeddings_cache.pkl')
    TF_RECORD_PATH = os.path.join(TF_RECORD_DIR, 'training_data_3.tfrecord')

    # Decide whether to load the custom fine-tuned model or just base DistilBERT.
    # True for custom, False for base DistilBERT.
    LOAD_CUSTOM_MODEL = True
    NUM_NEG_SAMPLES = 10

    # Ensure output directories exist
    os.makedirs(MODELS_DIR, exist_ok=True)
    os.makedirs(PROCESSED_DATA_DIR, exist_ok=True)
    os.makedirs(CACHE_DIR, exist_ok=True)
    os.makedirs(TOKENIZER_DIR, exist_ok=True)
    os.makedirs(FAISS_INDICES_DIR, exist_ok=True)
    os.makedirs(TF_RECORD_DIR, exist_ok=True)

    # Initialize config
    config_json = Path(MODELS_DIR) / "config.json"
    if config_json.exists():
        with open(config_json, "r", encoding="utf-8") as f:
            config_dict = json.load(f)
        config = ChatbotConfig.from_dict(config_dict)
        logger.info(f"Loaded ChatbotConfig from {config_json}")
    else:
        config = ChatbotConfig()
        logger.warning("No config.json found. Using default ChatbotConfig.")
    config.neg_samples = NUM_NEG_SAMPLES

    # Load or initialize tokenizer
    try:
        # If the directory has a valid tokenizer
        if Path(TOKENIZER_DIR).exists() and list(Path(TOKENIZER_DIR).iterdir()):
            logger.info(f"Loading tokenizer from {TOKENIZER_DIR}")
            tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_DIR)
        else:
            # Initialize from base DistilBERT
            logger.info(f"Loading base tokenizer for {config.pretrained_model}")
            tokenizer = AutoTokenizer.from_pretrained(config.pretrained_model)

            # Save to disk
            Path(TOKENIZER_DIR).mkdir(parents=True, exist_ok=True)
            tokenizer.save_pretrained(TOKENIZER_DIR)
            logger.info(f"New tokenizer saved to {TOKENIZER_DIR}")
    except Exception as e:
        logger.error(f"Failed to load or create tokenizer: {e}")
        sys.exit(1)

    # Initialize the encoder
    try:
        encoder = EncoderModel(config=config)
        logger.info("EncoderModel initialized successfully.")

        if LOAD_CUSTOM_MODEL:
            # Load the DistilBERT submodule from 'shared_encoder'
            shared_encoder_path = Path(MODELS_DIR) / "shared_encoder"
            if shared_encoder_path.exists():
                logger.info(f"Loading DistilBERT submodule from {shared_encoder_path}")
                encoder.pretrained = TFAutoModel.from_pretrained(shared_encoder_path)
            else:
                logger.warning(f"No shared_encoder found at {shared_encoder_path}, using base DistilBERT instead.")

            # Load top-level custom .weights.h5 (projection, dropout, etc.)
            custom_weights_path = Path(MODELS_DIR) / "encoder_custom_weights.weights.h5"
            if custom_weights_path.exists():
                logger.info(f"Loading custom top-level weights from {custom_weights_path}")

                # Build model layers with a dummy forward pass
                dummy_input = tf.zeros((1, config.max_context_token_limit), dtype=tf.int32)
                _ = encoder(dummy_input, training=False)

                encoder.load_weights(str(custom_weights_path))
                logger.info("Custom encoder weights loaded successfully.")
            else:
                logger.warning(f"Custom weights file not found at {custom_weights_path}. Using only submodule weights.")
        else:
            # Just base DistilBERT with special tokens resized
            logger.info("Using the base DistilBERT without loading custom weights.")

        # Resize token embeddings in case we added special tokens
        encoder.pretrained.resize_token_embeddings(len(tokenizer))
        logger.info(f"Token embeddings resized to: {len(tokenizer)}")
    except Exception as e:
        logger.error(f"Failed to initialize EncoderModel: {e}")
        sys.exit(1)

    # Load JSON dialogues
    try:
        if not Path(JSON_TRAINING_DATA_PATH).exists():
            logger.warning(f"No dialogues found at {JSON_TRAINING_DATA_PATH}, skipping.")
            dialogues = []
        else:
            dialogues = TFDataPipeline.load_json_training_data(JSON_TRAINING_DATA_PATH, debug_samples=None)
            logger.info(f"Loaded {len(dialogues)} dialogues from {JSON_TRAINING_DATA_PATH}.")
    except Exception as e:
        logger.error(f"Failed to load dialogues: {e}")
        sys.exit(1)

    # Load or initialize query_embeddings_cache
    query_embeddings_cache = {}
    if os.path.exists(CACHE_FILE):
        try:
            with open(CACHE_FILE, 'rb') as f:
                query_embeddings_cache = pickle.load(f)
            logger.info(f"Loaded {len(query_embeddings_cache)} query embeddings from {CACHE_FILE}.")
        except Exception as e:
            logger.warning(f"Failed to load query embeddings cache: {e}")
    else:
        logger.info("No existing query embeddings cache found. Starting fresh.")

    # Initialize TFDataPipeline
    try:
        # Determine if FAISS index should be loaded or initialized
        if Path(FAISS_INDEX_PRODUCTION_PATH).exists():
            # Load existing index
            logger.info(f"Loading existing FAISS index from {FAISS_INDEX_PRODUCTION_PATH}...")
            faiss_index = faiss.read_index(FAISS_INDEX_PRODUCTION_PATH)
            logger.info("FAISS index loaded successfully.")
        else:
            # Initialize a new FAISS index
            logger.info("No existing FAISS index found. Initializing a new index.")
            dimension = config.embedding_dim  # Ensure this matches your encoder's output
            faiss_index = faiss.IndexFlatIP(dimension)  # Using Inner Product for cosine similarity
            logger.info(f"Initialized new FAISS index with dimension {dimension}.")

        # Initialize TFDataPipeline with the FAISS index
        data_pipeline = TFDataPipeline(
            config=config,
            tokenizer=tokenizer,
            encoder=encoder,
            index_file_path=FAISS_INDEX_PRODUCTION_PATH,
            response_pool=[],
            max_length=config.max_context_token_limit,
            neg_samples=config.neg_samples,
            query_embeddings_cache=query_embeddings_cache,
            index_type='IndexFlatIP',
            nlist=100,
            max_retries=config.max_retries
        )
        logger.info("TFDataPipeline initialized successfully.")
    except Exception as e:
        logger.error(f"Failed to initialize TFDataPipeline: {e}")
        sys.exit(1)

    # Collect unique assistant responses from dialogues
    try:
        if dialogues:
            response_pool = data_pipeline.collect_responses_with_domain(dialogues)
            data_pipeline.response_pool = response_pool
            logger.info(f"Collected {len(response_pool)} unique assistant responses from dialogues.")
        else:
            logger.warning("No dialogues loaded. response_pool remains empty.")
    except Exception as e:
        logger.error(f"Failed to collect responses: {e}")
        sys.exit(1)

    # Build the FAISS index with response embeddings.
    # Instead of manually computing embeddings, we use the pipeline method.
    try:
        if data_pipeline.response_pool:
            data_pipeline.build_text_to_domain_map()
            logger.info("Computing and adding response embeddings to FAISS index using TFDataPipeline...")
            data_pipeline.compute_and_index_response_embeddings()
            logger.info("Response embeddings computed and added to FAISS index.")

            # Save the updated FAISS index
            data_pipeline.save_faiss_index(FAISS_INDEX_PRODUCTION_PATH)

            # Also save the response pool JSON
            response_pool_path = FAISS_INDEX_PRODUCTION_PATH.replace('.index', '_responses.json')
            with open(response_pool_path, 'w', encoding='utf-8') as f:
                json.dump(data_pipeline.response_pool, f, indent=2)
            logger.info(f"Response pool saved to {response_pool_path}.")
        else:
            logger.warning("No responses to embed. Skipping FAISS indexing.")
    except Exception as e:
        logger.error(f"Failed to compute or add response embeddings: {e}")
        sys.exit(1)

    # Prepare and save training data as TFRecords
    try:
        if dialogues:
            logger.info("Starting data preparation and saving as TFRecord...")
            data_pipeline.prepare_and_save_data(dialogues, TF_RECORD_PATH)
            logger.info(f"Data saved as TFRecord at {TF_RECORD_PATH}.")
        else:
            logger.warning("No dialogues to build TFRecord from. Skipping TFRecord creation.")
    except Exception as e:
        logger.error(f"Failed during data preparation and saving: {e}")
        sys.exit(1)

    # Save query embeddings cache
    try:
        with open(CACHE_FILE, 'wb') as f:
            pickle.dump(data_pipeline.query_embeddings_cache, f)
        logger.info(f"Saved {len(data_pipeline.query_embeddings_cache)} query embeddings to {CACHE_FILE}.")
    except Exception as e:
        logger.error(f"Failed to save query embeddings cache: {e}")
        sys.exit(1)

    # Save tokenizer
    try:
        tokenizer.save_pretrained(TOKENIZER_DIR)
        logger.info(f"Tokenizer saved to {TOKENIZER_DIR}.")
    except Exception as e:
        logger.error(f"Failed to save tokenizer: {e}")
        sys.exit(1)

    logger.info("Data preparation pipeline completed successfully.")


if __name__ == "__main__":
    main()