# csc525_retrieval_based_chatbot/unused/build_faiss_index.py
import os
import json
from pathlib import Path

import faiss
import numpy as np
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModel

from chatbot_model import ChatbotConfig, EncoderModel
from tf_data_pipeline import TFDataPipeline
from logger_config import config_logger

logger = config_logger(__name__)

# Silence the HF tokenizers fork-parallelism warning.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
def sanity_check(encoder: EncoderModel, tokenizer: AutoTokenizer, config: ChatbotConfig):
    """
    Perform a quick sanity check to ensure the model is loaded correctly.
    """
    sample_response = "This is a test response."
    encoded_sample = tokenizer(
        [sample_response],
        padding=True,
        truncation=True,
        max_length=config.max_context_token_limit,
        return_tensors='tf'
    )

    # Get embedding
    sample_embedding = encoder(encoded_sample['input_ids'], training=False).numpy()

    # Check shape
    if sample_embedding.shape[1] != config.embedding_dim:
        logger.error(
            f"Embedding dimension mismatch: Expected {config.embedding_dim}, "
            f"got {sample_embedding.shape[1]}"
        )
        raise ValueError("Embedding dimension mismatch.")
    else:
        logger.info("Embedding dimension matches the configuration.")

    # Check normalization
    embedding_norm = np.linalg.norm(sample_embedding, axis=1)
    if not np.allclose(embedding_norm, 1.0, atol=1e-5):
        logger.error("Embeddings are not properly normalized.")
        raise ValueError("Embeddings are not normalized.")
    else:
        logger.info("Embeddings are properly normalized.")

    logger.info("Sanity check passed: Model loaded correctly and outputs are as expected.")

def build_faiss_index():
    """
    Rebuild the FAISS index by:
      1) Loading config.json
      2) Initializing the encoder and loading the pretrained submodule + custom weights
      3) Loading the tokenizer from disk
      4) Creating a TFDataPipeline
      5) Setting the pipeline's response_pool from a JSON file
      6) Computing and indexing embeddings via pipeline.compute_and_index_response_embeddings()
      7) Saving the FAISS index
    """
    # Directories
    MODELS_DIR = Path("models")
    FAISS_DIR = MODELS_DIR / "faiss_indices"
    FAISS_INDEX_PATH = FAISS_DIR / "faiss_index_production.index"
    RESPONSES_PATH = FAISS_DIR / "faiss_index_production_responses.json"
    TOKENIZER_DIR = MODELS_DIR / "tokenizer"
    SHARED_ENCODER_DIR = MODELS_DIR / "shared_encoder"
    CUSTOM_WEIGHTS_PATH = MODELS_DIR / "encoder_custom_weights.weights.h5"
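
    # Expected on-disk layout implied by the paths above:
    #
    #   models/
    #     config.json
    #     encoder_custom_weights.weights.h5
    #     shared_encoder/        <- HF-format pretrained submodule
    #     tokenizer/             <- HF-format tokenizer files
    #     faiss_indices/
    #       faiss_index_production.index
    #       faiss_index_production_responses.json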

    # 1) Load ChatbotConfig
    config_path = MODELS_DIR / "config.json"
    if config_path.exists():
        with open(config_path, "r", encoding="utf-8") as f:
            config_dict = json.load(f)
        config = ChatbotConfig.from_dict(config_dict)
        logger.info(f"Loaded ChatbotConfig from {config_path}")
    else:
        config = ChatbotConfig()
        logger.warning(f"No config.json found at {config_path}. Using default ChatbotConfig.")

    # 2) Initialize the EncoderModel
    encoder = EncoderModel(config=config)
    logger.info("EncoderModel instantiated (empty).")

    # Overwrite the submodule from the 'shared_encoder' directory
    if SHARED_ENCODER_DIR.exists():
        logger.info(f"Loading DistilBERT submodule from {SHARED_ENCODER_DIR}...")
        encoder.pretrained = TFAutoModel.from_pretrained(str(SHARED_ENCODER_DIR))
        logger.info("Loaded HF submodule into encoder.pretrained.")
    else:
        logger.warning(f"No shared_encoder directory at {SHARED_ENCODER_DIR}. Using default pretrained model.")

    # Build the model once, then load custom weights (projection, etc.)
    dummy_input = tf.zeros((1, config.max_context_token_limit), dtype=tf.int32)
    _ = encoder(dummy_input, training=False)  # forward pass builds the layers
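    # Subclassed Keras models create their variables lazily on first call, so
    # the dummy forward pass above is what actually materializes the layers;
    # TF raises an error if load_weights() is called on a subclassed model
    # before its variables exist. The same pattern in miniature:
    #
    #   model = EncoderModel(config=config)          # no variables yet
    #   _ = model(tf.zeros((1, 8), dtype=tf.int32))  # variables created here
    #   model.load_weights("some.weights.h5")        # now safe to load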
    if CUSTOM_WEIGHTS_PATH.exists():
        logger.info(f"Loading custom top-level weights from {CUSTOM_WEIGHTS_PATH}")
        encoder.load_weights(str(CUSTOM_WEIGHTS_PATH))
        logger.info("Custom top-level weights loaded successfully.")
    else:
        logger.warning(f"Custom weights file not found at {CUSTOM_WEIGHTS_PATH}.")

    # 3) Load tokenizer
    if TOKENIZER_DIR.exists():
        logger.info(f"Loading tokenizer from {TOKENIZER_DIR}")
        tokenizer = AutoTokenizer.from_pretrained(str(TOKENIZER_DIR))
    else:
        logger.warning(f"No tokenizer dir at {TOKENIZER_DIR}; falling back to the default HF tokenizer.")
        tokenizer = AutoTokenizer.from_pretrained(config.pretrained_model)

    # 4) Quick sanity check
    sanity_check(encoder, tokenizer, config)

    # 5) Prepare a TFDataPipeline
    pipeline = TFDataPipeline(
        config=config,
        tokenizer=tokenizer,
        encoder=encoder,
        index_file_path=str(FAISS_INDEX_PATH),
        response_pool=[],
        max_length=config.max_context_token_limit,
        query_embeddings_cache={},
        neg_samples=config.neg_samples,
        index_type='IndexFlatIP',
        nlist=100,
        max_retries=config.max_retries
    )
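
    # Note on the index parameters (standard FAISS semantics): 'IndexFlatIP'
    # is an exact, brute-force inner-product index, so with the unit-norm
    # embeddings verified in sanity_check() its scores are cosine similarities.
    # 'nlist' (the number of coarse clusters) only applies to IVF-style
    # indices such as 'IndexIVFFlat'; with a flat index it presumably goes
    # unused by the pipeline.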

    # 6) Load the existing response pool
    if not RESPONSES_PATH.exists():
        logger.error(f"Response pool JSON file not found at {RESPONSES_PATH}")
        raise FileNotFoundError(f"No response pool JSON at {RESPONSES_PATH}")
    with open(RESPONSES_PATH, "r", encoding="utf-8") as f:
        response_pool = json.load(f)
    logger.info(f"Loaded {len(response_pool)} responses from {RESPONSES_PATH}")
    pipeline.response_pool = response_pool  # assign to pipeline

    # 7) Build (or rebuild) the FAISS index via the pipeline method,
    #    which computes the embeddings and calls index.add in one place
    logger.info("Starting to compute and index response embeddings via TFDataPipeline...")
    pipeline.compute_and_index_response_embeddings()

    # 8) Save the rebuilt FAISS index
    pipeline.save_faiss_index(str(FAISS_INDEX_PATH))

    # Verify the saved index by reading it back
    loaded_index = faiss.read_index(str(FAISS_INDEX_PATH))
    logger.info(f"Verified the rebuilt FAISS index has {loaded_index.ntotal} vectors.")

    return loaded_index, pipeline.response_pool
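
# Rough smoke test for a rebuilt index (hypothetical helper, not called
# anywhere; the response-pool entry type depends on what the JSON stores).
# It encodes a query with the same encoder/tokenizer and retrieves the
# nearest responses, e.g.:
#   _demo_query(encoder, tokenizer, config, loaded_index, response_pool, "hello")
def _demo_query(encoder: EncoderModel, tokenizer: AutoTokenizer, config: ChatbotConfig,
                index: "faiss.Index", response_pool: list, query: str, top_k: int = 3):
    """Encode `query` and return the top_k (score, response) pairs from the index."""
    encoded = tokenizer(
        [query],
        padding=True,
        truncation=True,
        max_length=config.max_context_token_limit,
        return_tensors='tf'
    )
    query_embedding = encoder(encoded['input_ids'], training=False).numpy().astype("float32")
    scores, ids = index.search(query_embedding, top_k)
    return [(float(s), response_pool[i]) for s, i in zip(scores[0], ids[0])]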

if __name__ == "__main__":
    build_faiss_index()