# File size: 6,495 Bytes
# Revision: 71ca212
# # NOTE(review): the whole script in this file is commented out, so none of
# # it executes. It rebuilds the production FAISS index from the saved encoder,
# # tokenizer and response pool (see build_faiss_index below).
# import os
# import json
# from pathlib import Path
# import faiss
# import numpy as np
# import tensorflow as tf
# from transformers import AutoTokenizer, TFAutoModel
# from tqdm.auto import tqdm
# from chatbot_model import ChatbotConfig, EncoderModel
# from tf_data_pipeline import TFDataPipeline
# from logger_config import config_logger
# logger = config_logger(__name__)
# # Presumably set to silence the HuggingFace tokenizers parallelism/fork
# # warning — confirm intent.
# os.environ["TOKENIZERS_PARALLELISM"] = "false"
# def sanity_check(encoder: EncoderModel, tokenizer: AutoTokenizer, config: ChatbotConfig):
#     """
#     Perform a quick sanity check to ensure the model is loaded correctly.
#
#     Encodes one fixed sample sentence and verifies that the resulting
#     embedding has ``config.embedding_dim`` columns and unit L2 norm
#     (``np.allclose`` against 1.0 with ``atol=1e-5``).
#
#     Args:
#         encoder: Model under test; called as ``encoder(input_ids, training=False)``.
#         tokenizer: Tokenizer used to encode the sample sentence.
#         config: Supplies ``max_context_token_limit`` and ``embedding_dim``.
#
#     Raises:
#         ValueError: If the embedding width differs from ``config.embedding_dim``
#             or the embedding is not unit-normalized.
#     """
#     sample_response = "This is a test response."
#     encoded_sample = tokenizer(
#         [sample_response],
#         padding=True,
#         truncation=True,
#         max_length=config.max_context_token_limit,
#         return_tensors='tf'
#     )
#     # Get embedding
#     # NOTE(review): only input_ids are passed; the attention mask is dropped.
#     # With a single sample, padding=True adds no pad tokens, so this is
#     # presumably harmless here — confirm EncoderModel does not need a mask.
#     sample_embedding = encoder(encoded_sample['input_ids'], training=False).numpy()
#     # Check shape
#     # (indexing shape[1] assumes a 2-D output, one row per input — TODO confirm)
#     if sample_embedding.shape[1] != config.embedding_dim:
#         logger.error(
#             f"Embedding dimension mismatch: Expected {config.embedding_dim}, "
#             f"got {sample_embedding.shape[1]}"
#         )
#         raise ValueError("Embedding dimension mismatch.")
#     else:
#         logger.info("Embedding dimension matches the configuration.")
#     # Check normalization: each row is expected to have unit L2 norm.
#     embedding_norm = np.linalg.norm(sample_embedding, axis=1)
#     if not np.allclose(embedding_norm, 1.0, atol=1e-5):
#         logger.error("Embeddings are not properly normalized.")
#         raise ValueError("Embeddings are not normalized.")
#     else:
#         logger.info("Embeddings are properly normalized.")
#     logger.info("Sanity check passed: Model loaded correctly and outputs are as expected.")
# def build_faiss_index():
#     """
#     Rebuild the FAISS index by:
#     1) Loading your config.json (falling back to a default ChatbotConfig)
#     2) Initializing encoder + loading submodule & custom weights
#     3) Loading tokenizer from disk (falling back to config.pretrained_model)
#     4) Running sanity_check on the encoder's output embeddings
#     5) Creating a TFDataPipeline
#     6) Setting the pipeline's response_pool from a JSON file
#     7) Using pipeline.compute_and_index_response_embeddings()
#     8) Saving the FAISS index and re-reading it to verify the vector count
#
#     Returns:
#         tuple: ``(loaded_index, response_pool)`` — the FAISS index re-read
#         from ``FAISS_INDEX_PATH`` and the response list assigned to the
#         pipeline.
#
#     Raises:
#         FileNotFoundError: If the response-pool JSON file does not exist.
#         ValueError: Propagated from ``sanity_check`` when the encoder output
#             fails the dimension or normalization check.
#     """
#     # Directories
#     # NOTE(review): these paths are relative, so they resolve against the
#     # current working directory, not this file's location.
#     MODELS_DIR = Path("models")
#     FAISS_DIR = MODELS_DIR / "faiss_indices"
#     FAISS_INDEX_PATH = FAISS_DIR / "faiss_index_production.index"
#     RESPONSES_PATH = FAISS_DIR / "faiss_index_production_responses.json"
#     TOKENIZER_DIR = MODELS_DIR / "tokenizer"
#     SHARED_ENCODER_DIR = MODELS_DIR / "shared_encoder"
#     CUSTOM_WEIGHTS_PATH = MODELS_DIR / "encoder_custom_weights.weights.h5"
#     # 1) Load ChatbotConfig
#     config_path = MODELS_DIR / "config.json"
#     if config_path.exists():
#         with open(config_path, "r", encoding="utf-8") as f:
#             config_dict = json.load(f)
#         config = ChatbotConfig.from_dict(config_dict)
#         logger.info(f"Loaded ChatbotConfig from {config_path}")
#     else:
#         config = ChatbotConfig()
#         logger.warning(f"No config.json found at {config_path}. Using default ChatbotConfig.")
#     # 2) Initialize the EncoderModel
#     encoder = EncoderModel(config=config)
#     logger.info("EncoderModel instantiated (empty).")
#     # Overwrite the submodule from 'shared_encoder' directory
#     # NOTE(review): the log text says "DistilBERT", but TFAutoModel loads
#     # whatever architecture is saved in SHARED_ENCODER_DIR.
#     if SHARED_ENCODER_DIR.exists():
#         logger.info(f"Loading DistilBERT submodule from {SHARED_ENCODER_DIR}...")
#         encoder.pretrained = TFAutoModel.from_pretrained(str(SHARED_ENCODER_DIR))
#         logger.info("Loaded HF submodule into encoder.pretrained.")
#     else:
#         logger.warning(f"No shared_encoder directory at {SHARED_ENCODER_DIR}. Using default pretrained model.")
#     # Build model once, then load custom weights (projection, etc.)
#     # One forward pass on a dummy batch creates the variables that
#     # load_weights restores into.
#     dummy_input = tf.zeros((1, config.max_context_token_limit), dtype=tf.int32)
#     _ = encoder(dummy_input, training=False) # builds the layers
#     # NOTE(review): a missing weights file only logs a warning; sanity_check
#     # below is then the only guard against indexing with untrained weights.
#     if CUSTOM_WEIGHTS_PATH.exists():
#         logger.info(f"Loading custom top-level weights from {CUSTOM_WEIGHTS_PATH}")
#         encoder.load_weights(str(CUSTOM_WEIGHTS_PATH))
#         logger.info("Custom top-level weights loaded successfully.")
#     else:
#         logger.warning(f"Custom weights file not found at {CUSTOM_WEIGHTS_PATH}.")
#     # 3) Load tokenizer
#     if TOKENIZER_DIR.exists():
#         logger.info(f"Loading tokenizer from {TOKENIZER_DIR}")
#         tokenizer = AutoTokenizer.from_pretrained(str(TOKENIZER_DIR))
#     else:
#         logger.warning(f"No tokenizer dir at {TOKENIZER_DIR}, falling back to default HF tokenizer.")
#         tokenizer = AutoTokenizer.from_pretrained(config.pretrained_model)
#     # 4) Quick sanity check
#     sanity_check(encoder, tokenizer, config)
#     # 5) Prepare a TFDataPipeline
#     # NOTE(review): nlist is presumably only meaningful for IVF index types;
#     # with 'IndexFlatIP' it is likely ignored — confirm in TFDataPipeline.
#     pipeline = TFDataPipeline(
#         config=config,
#         tokenizer=tokenizer,
#         encoder=encoder,
#         index_file_path=str(FAISS_INDEX_PATH),
#         response_pool=[],
#         max_length=config.max_context_token_limit,
#         query_embeddings_cache={},
#         neg_samples=config.neg_samples,
#         index_type='IndexFlatIP',
#         nlist=100,
#         max_retries=config.max_retries
#     )
#     # 6) Load the existing response pool
#     if not RESPONSES_PATH.exists():
#         logger.error(f"Response pool JSON file not found at {RESPONSES_PATH}")
#         raise FileNotFoundError(f"No response pool JSON at {RESPONSES_PATH}")
#     with open(RESPONSES_PATH, "r", encoding="utf-8") as f:
#         response_pool = json.load(f)
#     logger.info(f"Loaded {len(response_pool)} responses from {RESPONSES_PATH}")
#     pipeline.response_pool = response_pool # assign to pipeline
#     # 7) Build (or rebuild) the FAISS index from pipeline method
#     # This does all the compute-embeddings + index.add in one place
#     logger.info("Starting to compute and index response embeddings via TFDataPipeline...")
#     pipeline.compute_and_index_response_embeddings()
#     # 8) Save the rebuilt FAISS index
#     # Written to the same path the pipeline was configured with, so any
#     # existing index file there is presumably overwritten.
#     pipeline.save_faiss_index(str(FAISS_INDEX_PATH))
#     # Verify: re-read the file from disk and report its vector count.
#     loaded_index = faiss.read_index(str(FAISS_INDEX_PATH))
#     logger.info(f"Verified the rebuilt FAISS index has {loaded_index.ntotal} vectors.")
#     return loaded_index, pipeline.response_pool
# if __name__ == "__main__":
#     # The returned (index, response_pool) tuple is discarded; the useful
#     # effect of running the script is the index file written to disk.
#     build_faiss_index()