import os
import json
from pathlib import Path

import faiss
import numpy as np
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModel

from chatbot_model import ChatbotConfig, EncoderModel
from tf_data_pipeline import TFDataPipeline
from logger_config import config_logger

logger = config_logger(__name__)

# Disable tokenizer parallelism to avoid HF tokenizers' fork-safety warnings.
os.environ["TOKENIZERS_PARALLELISM"] = "false"


def sanity_check(encoder: EncoderModel, tokenizer: AutoTokenizer, config: ChatbotConfig):
    """
    Quick sanity check that the model loaded correctly: verifies the embedding
    dimension and that the output embeddings are L2-normalized.
    """
    sample_response = "This is a test response."
    encoded_sample = tokenizer(
        [sample_response],
        padding=True,
        truncation=True,
        max_length=config.max_context_token_limit,
        return_tensors='tf'
    )
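
    # Forward pass through the encoder; it is expected to return one embedding
    # per input, of shape (batch, embedding_dim), as the checks below verify.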
    sample_embedding = encoder(encoded_sample['input_ids'], training=False).numpy()
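
    # Check the embedding dimension against the configuration.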
    if sample_embedding.shape[1] != config.embedding_dim:
        logger.error(
            f"Embedding dimension mismatch: Expected {config.embedding_dim}, "
            f"got {sample_embedding.shape[1]}"
        )
        raise ValueError("Embedding dimension mismatch.")
    logger.info("Embedding dimension matches the configuration.")
    embedding_norm = np.linalg.norm(sample_embedding, axis=1)
    if not np.allclose(embedding_norm, 1.0, atol=1e-5):
        logger.error("Embeddings are not properly normalized.")
        raise ValueError("Embeddings are not normalized.")
    logger.info("Embeddings are properly normalized.")

    logger.info("Sanity check passed: Model loaded correctly and outputs are as expected.")


def build_faiss_index():
    """
    Rebuild the FAISS index by:
      1) Loading config.json
      2) Initializing the encoder and loading the submodule & custom weights
      3) Loading the tokenizer from disk
      4) Creating a TFDataPipeline
      5) Setting the pipeline's response_pool from a JSON file
      6) Using pipeline.compute_and_index_response_embeddings()
      7) Saving and verifying the FAISS index
    """
    MODELS_DIR = Path("models")
    FAISS_DIR = MODELS_DIR / "faiss_indices"
    FAISS_INDEX_PATH = FAISS_DIR / "faiss_index_production.index"
    RESPONSES_PATH = FAISS_DIR / "faiss_index_production_responses.json"
    TOKENIZER_DIR = MODELS_DIR / "tokenizer"
    SHARED_ENCODER_DIR = MODELS_DIR / "shared_encoder"
    CUSTOM_WEIGHTS_PATH = MODELS_DIR / "encoder_custom_weights.weights.h5"
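
    # 1) Load the persisted ChatbotConfig, falling back to defaults.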
    config_path = MODELS_DIR / "config.json"
    if config_path.exists():
        with open(config_path, "r", encoding="utf-8") as f:
            config_dict = json.load(f)
        config = ChatbotConfig.from_dict(config_dict)
        logger.info(f"Loaded ChatbotConfig from {config_path}")
    else:
        config = ChatbotConfig()
        logger.warning(f"No config.json found at {config_path}. Using default ChatbotConfig.")
    encoder = EncoderModel(config=config)
    logger.info("EncoderModel instantiated (no weights loaded yet).")
    if SHARED_ENCODER_DIR.exists():
        logger.info(f"Loading DistilBERT submodule from {SHARED_ENCODER_DIR}...")
        encoder.pretrained = TFAutoModel.from_pretrained(str(SHARED_ENCODER_DIR))
        logger.info("Loaded HF submodule into encoder.pretrained.")
    else:
        logger.warning(f"No shared_encoder directory at {SHARED_ENCODER_DIR}. Using default pretrained model.")
    dummy_input = tf.zeros((1, config.max_context_token_limit), dtype=tf.int32)
    _ = encoder(dummy_input, training=False)

    if CUSTOM_WEIGHTS_PATH.exists():
        logger.info(f"Loading custom top-level weights from {CUSTOM_WEIGHTS_PATH}")
        encoder.load_weights(str(CUSTOM_WEIGHTS_PATH))
        logger.info("Custom top-level weights loaded successfully.")
    else:
        logger.warning(f"Custom weights file not found at {CUSTOM_WEIGHTS_PATH}.")
    if TOKENIZER_DIR.exists():
        logger.info(f"Loading tokenizer from {TOKENIZER_DIR}")
        tokenizer = AutoTokenizer.from_pretrained(str(TOKENIZER_DIR))
    else:
        logger.warning(f"No tokenizer dir at {TOKENIZER_DIR}, falling back to default HF tokenizer.")
        tokenizer = AutoTokenizer.from_pretrained(config.pretrained_model)
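
    # Verify the restored encoder produces embeddings with the expected
    # dimension and normalization before building the index.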
    sanity_check(encoder, tokenizer, config)
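
    # 4) Assemble the data pipeline around the restored encoder and tokenizer.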
    pipeline = TFDataPipeline(
        config=config,
        tokenizer=tokenizer,
        encoder=encoder,
        index_file_path=str(FAISS_INDEX_PATH),
        response_pool=[],
        max_length=config.max_context_token_limit,
        query_embeddings_cache={},
        neg_samples=config.neg_samples,
        index_type='IndexFlatIP',
        nlist=100,
        max_retries=config.max_retries
    )
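
    # 5) Populate the pipeline's response pool from the saved JSON file.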
    if not RESPONSES_PATH.exists():
        logger.error(f"Response pool JSON file not found at {RESPONSES_PATH}")
        raise FileNotFoundError(f"No response pool JSON at {RESPONSES_PATH}")

    with open(RESPONSES_PATH, "r", encoding="utf-8") as f:
        response_pool = json.load(f)
    logger.info(f"Loaded {len(response_pool)} responses from {RESPONSES_PATH}")

    pipeline.response_pool = response_pool
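
    # 6) Encode the response pool and add the embeddings to the FAISS index.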
    logger.info("Starting to compute and index response embeddings via TFDataPipeline...")
    pipeline.compute_and_index_response_embeddings()
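
    # 7) Persist the rebuilt index, then re-read it to confirm it is intact.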
    pipeline.save_faiss_index(str(FAISS_INDEX_PATH))

    loaded_index = faiss.read_index(str(FAISS_INDEX_PATH))
    logger.info(f"Verified the rebuilt FAISS index has {loaded_index.ntotal} vectors.")

    return loaded_index, pipeline.response_pool
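

# A minimal, optional sketch for spot-checking retrieval against the rebuilt
# index. The function name, the sample query, and k are illustrative only; it
# assumes the same encoder/tokenizer interface used throughout this script.
def smoke_test_query(index, response_pool, encoder, tokenizer, config, k: int = 5):
    """
    Encode a single query with the same encoder and retrieve the top-k
    responses. Because sanity_check() confirms embeddings are L2-normalized,
    the IndexFlatIP inner-product scores are cosine similarities.
    """
    encoded = tokenizer(
        ["How do I reset my password?"],
        padding=True,
        truncation=True,
        max_length=config.max_context_token_limit,
        return_tensors='tf'
    )
    query_embedding = encoder(encoded['input_ids'], training=False).numpy().astype("float32")

    # FAISS returns (scores, indices); each index points into the response pool.
    scores, indices = index.search(query_embedding, k)
    for rank, (score, idx) in enumerate(zip(scores[0], indices[0]), start=1):
        logger.info(f"Rank {rank} (score={score:.4f}): {response_pool[idx]}")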


if __name__ == "__main__":
    build_faiss_index()