import os
import json
from pathlib import Path

import faiss
import numpy as np
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModel

from chatbot_model import ChatbotConfig, EncoderModel
from tf_data_pipeline import TFDataPipeline
from logger_config import config_logger

logger = config_logger(__name__)
os.environ["TOKENIZERS_PARALLELISM"] = "false"


def sanity_check(encoder: EncoderModel, tokenizer: AutoTokenizer, config: ChatbotConfig):
    """Perform a quick sanity check to ensure the model is loaded correctly."""
    sample_response = "This is a test response."
    encoded_sample = tokenizer(
        [sample_response],
        padding=True,
        truncation=True,
        max_length=config.max_context_token_limit,
        return_tensors='tf'
    )

    # Get the embedding for the sample response
    sample_embedding = encoder(encoded_sample['input_ids'], training=False).numpy()

    # Check the embedding dimension
    if sample_embedding.shape[1] != config.embedding_dim:
        logger.error(
            f"Embedding dimension mismatch: expected {config.embedding_dim}, "
            f"got {sample_embedding.shape[1]}"
        )
        raise ValueError("Embedding dimension mismatch.")
    logger.info("Embedding dimension matches the configuration.")

    # Check that the embeddings are L2-normalized
    embedding_norm = np.linalg.norm(sample_embedding, axis=1)
    if not np.allclose(embedding_norm, 1.0, atol=1e-5):
        logger.error("Embeddings are not properly normalized.")
        raise ValueError("Embeddings are not normalized.")
    logger.info("Embeddings are properly normalized.")

    logger.info("Sanity check passed: model loaded correctly and outputs are as expected.")


def build_faiss_index():
    """
    Rebuild the FAISS index by:
      1) Loading config.json
      2) Initializing the encoder and loading the submodule + custom weights
      3) Loading the tokenizer from disk
      4) Running a quick sanity check
      5) Creating a TFDataPipeline
      6) Setting the pipeline's response_pool from a JSON file
      7) Computing and indexing response embeddings via the pipeline
      8) Saving and verifying the FAISS index
    """
    # Directories and file paths
    MODELS_DIR = Path("models")
    FAISS_DIR = MODELS_DIR / "faiss_indices"
    FAISS_INDEX_PATH = FAISS_DIR / "faiss_index_production.index"
    RESPONSES_PATH = FAISS_DIR / "faiss_index_production_responses.json"
    TOKENIZER_DIR = MODELS_DIR / "tokenizer"
    SHARED_ENCODER_DIR = MODELS_DIR / "shared_encoder"
    CUSTOM_WEIGHTS_PATH = MODELS_DIR / "encoder_custom_weights.weights.h5"

    # 1) Load ChatbotConfig
    config_path = MODELS_DIR / "config.json"
    if config_path.exists():
        with open(config_path, "r", encoding="utf-8") as f:
            config_dict = json.load(f)
        config = ChatbotConfig.from_dict(config_dict)
        logger.info(f"Loaded ChatbotConfig from {config_path}")
    else:
        config = ChatbotConfig()
        logger.warning(f"No config.json found at {config_path}. Using default ChatbotConfig.")

    # 2) Initialize the EncoderModel
    encoder = EncoderModel(config=config)
    logger.info("EncoderModel instantiated (empty).")

    # Overwrite the submodule from the 'shared_encoder' directory
    if SHARED_ENCODER_DIR.exists():
        logger.info(f"Loading DistilBERT submodule from {SHARED_ENCODER_DIR}...")
        encoder.pretrained = TFAutoModel.from_pretrained(str(SHARED_ENCODER_DIR))
        logger.info("Loaded HF submodule into encoder.pretrained.")
    else:
        logger.warning(f"No shared_encoder directory at {SHARED_ENCODER_DIR}. Using default pretrained model.")

    # Build the model once with a dummy input, then load custom weights (projection layer, etc.)
    dummy_input = tf.zeros((1, config.max_context_token_limit), dtype=tf.int32)
    _ = encoder(dummy_input, training=False)  # builds the layers

    if CUSTOM_WEIGHTS_PATH.exists():
        logger.info(f"Loading custom top-level weights from {CUSTOM_WEIGHTS_PATH}")
        encoder.load_weights(str(CUSTOM_WEIGHTS_PATH))
        logger.info("Custom top-level weights loaded successfully.")
    else:
        logger.warning(f"Custom weights file not found at {CUSTOM_WEIGHTS_PATH}.")

    # 3) Load the tokenizer
    if TOKENIZER_DIR.exists():
        logger.info(f"Loading tokenizer from {TOKENIZER_DIR}")
        tokenizer = AutoTokenizer.from_pretrained(str(TOKENIZER_DIR))
    else:
        logger.warning(f"No tokenizer dir at {TOKENIZER_DIR}, falling back to default HF tokenizer.")
        tokenizer = AutoTokenizer.from_pretrained(config.pretrained_model)

    # 4) Quick sanity check
    sanity_check(encoder, tokenizer, config)

    # 5) Prepare a TFDataPipeline
    pipeline = TFDataPipeline(
        config=config,
        tokenizer=tokenizer,
        encoder=encoder,
        index_file_path=str(FAISS_INDEX_PATH),
        response_pool=[],
        max_length=config.max_context_token_limit,
        query_embeddings_cache={},
        neg_samples=config.neg_samples,
        index_type='IndexFlatIP',
        nlist=100,
        max_retries=config.max_retries
    )

    # 6) Load the existing response pool
    if not RESPONSES_PATH.exists():
        logger.error(f"Response pool JSON file not found at {RESPONSES_PATH}")
        raise FileNotFoundError(f"No response pool JSON at {RESPONSES_PATH}")

    with open(RESPONSES_PATH, "r", encoding="utf-8") as f:
        response_pool = json.load(f)
    logger.info(f"Loaded {len(response_pool)} responses from {RESPONSES_PATH}")
    pipeline.response_pool = response_pool

    # 7) Build (or rebuild) the FAISS index via the pipeline method, which
    #    computes the embeddings and adds them to the index in one place
    logger.info("Starting to compute and index response embeddings via TFDataPipeline...")
    pipeline.compute_and_index_response_embeddings()

    # 8) Save the rebuilt FAISS index
    pipeline.save_faiss_index(str(FAISS_INDEX_PATH))

    # Verify the index can be read back and report its size
    loaded_index = faiss.read_index(str(FAISS_INDEX_PATH))
    logger.info(f"Verified the rebuilt FAISS index has {loaded_index.ntotal} vectors.")

    return loaded_index, pipeline.response_pool


if __name__ == "__main__":
    build_faiss_index()
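
# Note on the sanity check's normalization requirement: with index_type='IndexFlatIP',
# FAISS ranks by raw inner product, which equals cosine similarity only for
# unit-norm vectors. A minimal sketch of that relationship, assuming float32
# embeddings (illustrative only, not part of this script's flow):
#
#   emb = encoder(input_ids, training=False).numpy().astype("float32")
#   faiss.normalize_L2(emb)             # in-place L2 normalization
#   scores, ids = index.search(emb, k)  # inner product == cosine similarity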
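
# A minimal retrieval sketch against the rebuilt index, assuming the same
# encoder/tokenizer/config loaded above and that the response pool aligns 1:1
# with the index rows; 'query' is a hypothetical input string:
#
#   index, responses = build_faiss_index()
#   enc = tokenizer([query], padding=True, truncation=True,
#                   max_length=config.max_context_token_limit, return_tensors='tf')
#   q_emb = encoder(enc['input_ids'], training=False).numpy().astype("float32")
#   scores, ids = index.search(q_emb, 5)  # top-5 nearest responses
#   top_responses = [responses[i] for i in ids[0]]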