# csc525_retrieval_based_chatbot / build_faiss_index.py
# Author: JoeArmani
# Last commit: 7a0020b ("updates - new iteration with type token")
import os
import json
from pathlib import Path
import faiss
import numpy as np
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModel
from tqdm.auto import tqdm
from chatbot_model import ChatbotConfig, EncoderModel
from tf_data_pipeline import TFDataPipeline
from logger_config import config_logger
logger = config_logger(__name__)
# Turn off the HuggingFace `tokenizers` library's internal parallelism for this
# process through its environment switch.
os.environ.update(TOKENIZERS_PARALLELISM="false")
def sanity_check(encoder: EncoderModel, tokenizer: AutoTokenizer, config: ChatbotConfig):
    """
    Push one probe sentence through the encoder and validate the output.

    The returned embedding is checked for two properties:
      * its second dimension equals ``config.embedding_dim``;
      * every row has unit L2 norm (absolute tolerance 1e-5).

    Args:
        encoder: Model invoked as ``encoder(input_ids, training=False)``; the
            result must expose ``.numpy()``.
        tokenizer: Tokenizer used to encode the probe sentence into TF tensors.
        config: Supplies ``max_context_token_limit`` (truncation length) and
            ``embedding_dim`` (expected embedding width).

    Raises:
        ValueError: If the embedding width or its normalization is wrong.
    """
    probe_batch = tokenizer(
        ["This is a test response."],
        padding=True,
        truncation=True,
        max_length=config.max_context_token_limit,
        return_tensors='tf',
    )
    vectors = encoder(probe_batch['input_ids'], training=False).numpy()

    # Width check: the index is built with config.embedding_dim columns.
    actual_dim = vectors.shape[1]
    expected_dim = config.embedding_dim
    if actual_dim != expected_dim:
        logger.error(
            f"Embedding dimension mismatch: Expected {expected_dim}, "
            f"got {actual_dim}"
        )
        raise ValueError("Embedding dimension mismatch.")
    logger.info("Embedding dimension matches the configuration.")

    # Norm check: each embedding row must have unit length.
    row_norms = np.linalg.norm(vectors, axis=1)
    if not np.allclose(row_norms, 1.0, atol=1e-5):
        logger.error("Embeddings are not properly normalized.")
        raise ValueError("Embeddings are not normalized.")
    logger.info("Embeddings are properly normalized.")

    logger.info("Sanity check passed: Model loaded correctly and outputs are as expected.")
def _load_chatbot_config(config_path: Path) -> "ChatbotConfig":
    """Load ChatbotConfig from ``config_path``; use the defaults if the file is absent."""
    if not config_path.exists():
        logger.warning(f"No config.json found at {config_path}. Using default ChatbotConfig.")
        return ChatbotConfig()
    with open(config_path, "r", encoding="utf-8") as f:
        config_dict = json.load(f)
    config = ChatbotConfig.from_dict(config_dict)
    logger.info(f"Loaded ChatbotConfig from {config_path}")
    return config


def _load_encoder(config, shared_encoder_dir: Path, custom_weights_path: Path) -> "EncoderModel":
    """Build the EncoderModel, swap in the saved HF submodule and load custom top-level weights."""
    encoder = EncoderModel(config=config)
    logger.info("EncoderModel instantiated (empty).")

    # Overwrite the pretrained submodule with the copy saved next to the model.
    if shared_encoder_dir.exists():
        logger.info(f"Loading DistilBERT submodule from {shared_encoder_dir}...")
        encoder.pretrained = TFAutoModel.from_pretrained(str(shared_encoder_dir))
        logger.info("Loaded HF submodule into encoder.pretrained.")
    else:
        logger.warning(f"No shared_encoder directory at {shared_encoder_dir}. Using default pretrained model.")

    # One forward pass on a dummy batch builds the layers, so load_weights has
    # variables (projection, etc.) to restore into.
    dummy_input = tf.zeros((1, config.max_context_token_limit), dtype=tf.int32)
    _ = encoder(dummy_input, training=False)

    if custom_weights_path.exists():
        logger.info(f"Loading custom top-level weights from {custom_weights_path}")
        encoder.load_weights(str(custom_weights_path))
        logger.info("Custom top-level weights loaded successfully.")
    else:
        logger.warning(f"Custom weights file not found at {custom_weights_path}.")
    return encoder


def _load_tokenizer(tokenizer_dir: Path, config):
    """Load the tokenizer saved on disk, else fall back to ``config.pretrained_model``."""
    if tokenizer_dir.exists():
        logger.info(f"Loading tokenizer from {tokenizer_dir}")
        return AutoTokenizer.from_pretrained(str(tokenizer_dir))
    logger.warning(f"No tokenizer dir at {tokenizer_dir}, falling back to default HF tokenizer.")
    # NOTE(review): an earlier revision added the '<EMPTY_NEGATIVE>' special
    # token here. Confirm the fallback tokenizer matches the vocabulary the
    # encoder was trained with before relying on this path.
    return AutoTokenizer.from_pretrained(config.pretrained_model)


def _load_response_pool(responses_path: Path):
    """Read the response pool JSON; raise FileNotFoundError if it is missing."""
    if not responses_path.exists():
        logger.error(f"Response pool JSON file not found at {responses_path}")
        raise FileNotFoundError(f"No response pool JSON at {responses_path}")
    with open(responses_path, "r", encoding="utf-8") as f:
        response_pool = json.load(f)
    logger.info(f"Loaded {len(response_pool)} responses from {responses_path}")
    if not response_pool:
        logger.warning(f"Response pool at {responses_path} is empty; the index will have no vectors.")
    return response_pool


def build_faiss_index(models_dir=Path("models")):
    """
    Rebuild the production FAISS index from the saved encoder and response pool.

    Steps:
      1) Load the response pool JSON first, so a missing file fails fast
         before any model is loaded.
      2) Load ChatbotConfig from ``config.json`` (defaults if absent).
      3) Instantiate the EncoderModel, load the saved HF submodule and the
         custom top-level weights.
      4) Load the tokenizer from disk (or fall back to ``config.pretrained_model``).
      5) Run ``sanity_check`` on the encoder.
      6) Create a TFDataPipeline, attach the response pool and compute and
         index the response embeddings.
      7) Save the index, then re-read it from disk to verify it.

    Args:
        models_dir: Root directory holding ``config.json``, ``tokenizer/``,
            ``shared_encoder/``, the custom weights file and ``faiss_indices/``.
            Defaults to ``models`` relative to the working directory.

    Returns:
        tuple: ``(loaded_index, response_pool)`` holding the index as re-read
        from disk and the pipeline's response pool.

    Raises:
        FileNotFoundError: If the response pool JSON does not exist.
        ValueError: Propagated from ``sanity_check`` when the encoder output is
            malformed.
    """
    models_dir = Path(models_dir)
    faiss_dir = models_dir / "faiss_indices"
    faiss_index_path = faiss_dir / "faiss_index_production.index"
    responses_path = faiss_dir / "faiss_index_production_responses.json"
    tokenizer_dir = models_dir / "tokenizer"
    shared_encoder_dir = models_dir / "shared_encoder"
    custom_weights_path = models_dir / "encoder_custom_weights.weights.h5"
    config_path = models_dir / "config.json"

    # 1) Cheap precondition first: avoid loading large models only to fail here.
    response_pool = _load_response_pool(responses_path)

    # 2-4) Config, encoder and tokenizer.
    config = _load_chatbot_config(config_path)
    encoder = _load_encoder(config, shared_encoder_dir, custom_weights_path)
    tokenizer = _load_tokenizer(tokenizer_dir, config)

    # 5) Quick sanity check of the loaded encoder.
    sanity_check(encoder, tokenizer, config)

    # 6) Pipeline; the pool is attached after construction, as before.
    pipeline = TFDataPipeline(
        config=config,
        tokenizer=tokenizer,
        encoder=encoder,
        index_file_path=str(faiss_index_path),
        response_pool=[],
        max_length=config.max_context_token_limit,
        query_embeddings_cache={},
        neg_samples=config.neg_samples,
        index_type='IndexFlatIP',
        nlist=100,
        max_retries=config.max_retries
    )
    pipeline.response_pool = response_pool

    logger.info("Starting to compute and index response embeddings via TFDataPipeline...")
    pipeline.compute_and_index_response_embeddings()

    # 7) Save, then verify by reading the file back.
    pipeline.save_faiss_index(str(faiss_index_path))
    loaded_index = faiss.read_index(str(faiss_index_path))
    logger.info(f"Verified the rebuilt FAISS index has {loaded_index.ntotal} vectors.")
    # Retrieval presumably maps FAISS ids back to positions in response_pool
    # (confirm in TFDataPipeline), so a count mismatch likely means misaligned
    # lookups.
    if loaded_index.ntotal != len(pipeline.response_pool):
        logger.warning(
            f"FAISS index size ({loaded_index.ntotal}) does not match the number "
            f"of responses ({len(pipeline.response_pool)})."
        )
    return loaded_index, pipeline.response_pool
if __name__ == "__main__":
    # Script entry point: rebuild and save the FAISS index. The returned
    # (index, response_pool) tuple is not used when run from the command line.
    build_faiss_index()