# import os
# import json
# from pathlib import Path
# import faiss
# import numpy as np
# import tensorflow as tf
# from transformers import AutoTokenizer, TFAutoModel
# from tqdm.auto import tqdm
# from chatbot_model import ChatbotConfig, EncoderModel
# from tf_data_pipeline import TFDataPipeline
# from logger_config import config_logger
# logger = config_logger(__name__)
# os.environ["TOKENIZERS_PARALLELISM"] = "false"
# def sanity_check(encoder: EncoderModel, tokenizer: AutoTokenizer, config: ChatbotConfig):
#     """
#     Perform a quick sanity check to ensure the model is loaded correctly.
#     """
#     sample_response = "This is a test response."
#     encoded_sample = tokenizer(
#         [sample_response],
#         padding=True,
#         truncation=True,
#         max_length=config.max_context_token_limit,
#         return_tensors='tf'
#     )
#     # Get embedding
#     sample_embedding = encoder(encoded_sample['input_ids'], training=False).numpy()
#     # Check shape
#     if sample_embedding.shape[1] != config.embedding_dim:
#         logger.error(
#             f"Embedding dimension mismatch: Expected {config.embedding_dim}, "
#             f"got {sample_embedding.shape[1]}"
#         )
#         raise ValueError("Embedding dimension mismatch.")
#     else:
#         logger.info("Embedding dimension matches the configuration.")
#     # Check normalization
#     embedding_norm = np.linalg.norm(sample_embedding, axis=1)
#     if not np.allclose(embedding_norm, 1.0, atol=1e-5):
#         logger.error("Embeddings are not properly normalized.")
#         raise ValueError("Embeddings are not normalized.")
#     else:
#         logger.info("Embeddings are properly normalized.")
#     logger.info("Sanity check passed: Model loaded correctly and outputs are as expected.")
# def build_faiss_index():
#     """
#     Rebuild the FAISS index by:
#     1) Loading your config.json
#     2) Initializing encoder + loading submodule & custom weights
#     3) Loading tokenizer from disk
#     4) Creating a TFDataPipeline
#     5) Setting the pipeline's response_pool from a JSON file
#     6) Using pipeline.compute_and_index_response_embeddings()
#     7) Saving the FAISS index
#     """
#     # Directories
#     MODELS_DIR = Path("models")
#     FAISS_DIR = MODELS_DIR / "faiss_indices"
#     FAISS_INDEX_PATH = FAISS_DIR / "faiss_index_production.index"
#     RESPONSES_PATH = FAISS_DIR / "faiss_index_production_responses.json"
#     TOKENIZER_DIR = MODELS_DIR / "tokenizer"
#     SHARED_ENCODER_DIR = MODELS_DIR / "shared_encoder"
#     CUSTOM_WEIGHTS_PATH = MODELS_DIR / "encoder_custom_weights.weights.h5"
#     # 1) Load ChatbotConfig
#     config_path = MODELS_DIR / "config.json"
#     if config_path.exists():
#         with open(config_path, "r", encoding="utf-8") as f:
#             config_dict = json.load(f)
#         config = ChatbotConfig.from_dict(config_dict)
#         logger.info(f"Loaded ChatbotConfig from {config_path}")
#     else:
#         config = ChatbotConfig()
#         logger.warning(f"No config.json found at {config_path}. Using default ChatbotConfig.")
#     # 2) Initialize the EncoderModel
#     encoder = EncoderModel(config=config)
#     logger.info("EncoderModel instantiated (empty).")
#     # Overwrite the submodule from 'shared_encoder' directory
#     if SHARED_ENCODER_DIR.exists():
#         logger.info(f"Loading DistilBERT submodule from {SHARED_ENCODER_DIR}...")
#         encoder.pretrained = TFAutoModel.from_pretrained(str(SHARED_ENCODER_DIR))
#         logger.info("Loaded HF submodule into encoder.pretrained.")
#     else:
#         logger.warning(f"No shared_encoder directory at {SHARED_ENCODER_DIR}. Using default pretrained model.")
#     # Build model once, then load custom weights (projection, etc.)
#     dummy_input = tf.zeros((1, config.max_context_token_limit), dtype=tf.int32)
#     _ = encoder(dummy_input, training=False)  # builds the layers
#     if CUSTOM_WEIGHTS_PATH.exists():
#         logger.info(f"Loading custom top-level weights from {CUSTOM_WEIGHTS_PATH}")
#         encoder.load_weights(str(CUSTOM_WEIGHTS_PATH))
#         logger.info("Custom top-level weights loaded successfully.")
#     else:
#         logger.warning(f"Custom weights file not found at {CUSTOM_WEIGHTS_PATH}.")
#     # 3) Load tokenizer
#     if TOKENIZER_DIR.exists():
#         logger.info(f"Loading tokenizer from {TOKENIZER_DIR}")
#         tokenizer = AutoTokenizer.from_pretrained(str(TOKENIZER_DIR))
#     else:
#         logger.warning(f"No tokenizer dir at {TOKENIZER_DIR}, falling back to default HF tokenizer.")
#         tokenizer = AutoTokenizer.from_pretrained(config.pretrained_model)
#     # 4) Quick sanity check
#     sanity_check(encoder, tokenizer, config)
#     # 5) Prepare a TFDataPipeline
#     pipeline = TFDataPipeline(
#         config=config,
#         tokenizer=tokenizer,
#         encoder=encoder,
#         index_file_path=str(FAISS_INDEX_PATH),
#         response_pool=[],
#         max_length=config.max_context_token_limit,
#         query_embeddings_cache={},
#         neg_samples=config.neg_samples,
#         index_type='IndexFlatIP',
#         nlist=100,
#         max_retries=config.max_retries
#     )
#     # 6) Load the existing response pool
#     if not RESPONSES_PATH.exists():
#         logger.error(f"Response pool JSON file not found at {RESPONSES_PATH}")
#         raise FileNotFoundError(f"No response pool JSON at {RESPONSES_PATH}")
#     with open(RESPONSES_PATH, "r", encoding="utf-8") as f:
#         response_pool = json.load(f)
#     logger.info(f"Loaded {len(response_pool)} responses from {RESPONSES_PATH}")
#     pipeline.response_pool = response_pool  # assign to pipeline
#     # 7) Build (or rebuild) the FAISS index from pipeline method
#     # This does all the compute-embeddings + index.add in one place
#     logger.info("Starting to compute and index response embeddings via TFDataPipeline...")
#     pipeline.compute_and_index_response_embeddings()
#     # 8) Save the rebuilt FAISS index
#     pipeline.save_faiss_index(str(FAISS_INDEX_PATH))
#     # Verify
#     loaded_index = faiss.read_index(str(FAISS_INDEX_PATH))
#     logger.info(f"Verified the rebuilt FAISS index has {loaded_index.ntotal} vectors.")
#     return loaded_index, pipeline.response_pool
# if __name__ == "__main__":
#     build_faiss_index()