# File size: 6,316 Bytes
# 7a0020b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
import os
import json
from pathlib import Path

import faiss
import numpy as np
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModel
from tqdm.auto import tqdm

from chatbot_model import ChatbotConfig, EncoderModel
from tf_data_pipeline import TFDataPipeline
from logger_config import config_logger

logger = config_logger(__name__)

# Switch off the Hugging Face `tokenizers` library's internal parallelism for this
# process. NOTE(review): presumably to avoid its fork/parallelism warning — confirm.
os.environ.update(TOKENIZERS_PARALLELISM="false")

def sanity_check(encoder: EncoderModel, tokenizer: AutoTokenizer, config: ChatbotConfig):
    """
    Smoke-test a freshly loaded encoder before it is used for indexing.

    One fixed sentence is tokenized and embedded. The resulting embedding must
    (a) have ``config.embedding_dim`` columns and (b) have unit L2 norm per row
    (within ``atol=1e-5``).

    Args:
        encoder: Model under test; called as ``encoder(input_ids, training=False)``
            and expected to return a tensor exposing ``.numpy()``.
        tokenizer: Tokenizer used to produce ``input_ids`` for the probe sentence.
        config: Supplies ``max_context_token_limit`` (truncation length) and
            ``embedding_dim`` (expected embedding width).

    Raises:
        ValueError: If the embedding width differs from ``config.embedding_dim``
            or the embedding rows are not unit-normalized.
    """
    probe_batch = tokenizer(
        ["This is a test response."],
        padding=True,
        truncation=True,
        max_length=config.max_context_token_limit,
        return_tensors='tf',
    )
    vectors = encoder(probe_batch['input_ids'], training=False).numpy()

    # Width check: the index is built with vectors of exactly this size.
    actual_dim = vectors.shape[1]
    expected_dim = config.embedding_dim
    if actual_dim == expected_dim:
        logger.info("Embedding dimension matches the configuration.")
    else:
        logger.error(
            f"Embedding dimension mismatch: Expected {expected_dim}, "
            f"got {actual_dim}"
        )
        raise ValueError("Embedding dimension mismatch.")

    # Unit-norm check: every row should have L2 norm ~1.0.
    row_norms = np.linalg.norm(vectors, axis=1)
    if np.allclose(row_norms, 1.0, atol=1e-5):
        logger.info("Embeddings are properly normalized.")
    else:
        logger.error("Embeddings are not properly normalized.")
        raise ValueError("Embeddings are not normalized.")

    logger.info("Sanity check passed: Model loaded correctly and outputs are as expected.")

def _load_chatbot_config(config_path: Path) -> ChatbotConfig:
    """Return the ChatbotConfig stored at ``config_path``, or a default one if the file is missing."""
    if not config_path.exists():
        logger.warning(f"No config.json found at {config_path}. Using default ChatbotConfig.")
        return ChatbotConfig()
    with open(config_path, "r", encoding="utf-8") as f:
        config_dict = json.load(f)
    config = ChatbotConfig.from_dict(config_dict)
    logger.info(f"Loaded ChatbotConfig from {config_path}")
    return config


def _load_response_pool(responses_path: Path):
    """
    Read the response pool JSON that the rebuilt index is built from.

    Raises:
        FileNotFoundError: If ``responses_path`` does not exist.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    if not responses_path.exists():
        logger.error(f"Response pool JSON file not found at {responses_path}")
        raise FileNotFoundError(f"No response pool JSON at {responses_path}")
    with open(responses_path, "r", encoding="utf-8") as f:
        response_pool = json.load(f)
    logger.info(f"Loaded {len(response_pool)} responses from {responses_path}")
    return response_pool


def _load_encoder(config: ChatbotConfig, shared_encoder_dir: Path, custom_weights_path: Path) -> EncoderModel:
    """Instantiate the EncoderModel, swap in the saved HF submodule, build it, then load custom weights."""
    encoder = EncoderModel(config=config)
    logger.info("EncoderModel instantiated (empty).")

    # Replace the default pretrained submodule with the copy saved next to the model, if any.
    if shared_encoder_dir.exists():
        logger.info(f"Loading DistilBERT submodule from {shared_encoder_dir}...")
        encoder.pretrained = TFAutoModel.from_pretrained(str(shared_encoder_dir))
        logger.info("Loaded HF submodule into encoder.pretrained.")
    else:
        logger.warning(f"No shared_encoder directory at {shared_encoder_dir}. Using default pretrained model.")

    # Run one forward pass so the layers are built before load_weights() assigns to them.
    dummy_input = tf.zeros((1, config.max_context_token_limit), dtype=tf.int32)
    encoder(dummy_input, training=False)

    if custom_weights_path.exists():
        logger.info(f"Loading custom top-level weights from {custom_weights_path}")
        encoder.load_weights(str(custom_weights_path))
        logger.info("Custom top-level weights loaded successfully.")
    else:
        logger.warning(f"Custom weights file not found at {custom_weights_path}.")
    return encoder


def _load_tokenizer(config: ChatbotConfig, tokenizer_dir: Path):
    """Load the saved tokenizer from ``tokenizer_dir``, falling back to ``config.pretrained_model``."""
    if tokenizer_dir.exists():
        logger.info(f"Loading tokenizer from {tokenizer_dir}")
        return AutoTokenizer.from_pretrained(str(tokenizer_dir))
    logger.warning(f"No tokenizer dir at {tokenizer_dir}, falling back to default HF tokenizer.")
    # NOTE(review): an earlier version added an '<EMPTY_NEGATIVE>' special token here. If the
    # saved encoder relies on that token, the fallback tokenizer's vocabulary will not match.
    return AutoTokenizer.from_pretrained(config.pretrained_model)


def build_faiss_index(models_dir="models"):
    """
    Rebuild the production FAISS index from the artifacts saved under ``models_dir``.

    Steps:
      1) Load ``config.json`` (falling back to a default ChatbotConfig).
      2) Load the response pool JSON. This happens before any model loading, so a
         missing or malformed file fails fast instead of after the expensive setup.
      3) Instantiate the EncoderModel, swap in the saved ``shared_encoder``
         submodule, build it, and load the custom top-level weights.
      4) Load the tokenizer (saved copy, or ``config.pretrained_model``).
      5) Run ``sanity_check`` on the encoder's output.
      6) Create a TFDataPipeline, attach the response pool, and call
         ``compute_and_index_response_embeddings()``.
      7) Save the index, read it back from disk, and warn if its vector count
         differs from the number of responses.

    Args:
        models_dir: Directory containing ``config.json``, ``tokenizer/``,
            ``shared_encoder/``, ``encoder_custom_weights.weights.h5`` and
            ``faiss_indices/``. Defaults to ``"models"`` (relative to the current
            working directory), the previously hard-coded location.

    Returns:
        A ``(faiss_index, response_pool)`` tuple: the index as read back from
        disk, and the pipeline's ``response_pool``.

    Raises:
        FileNotFoundError: If the response pool JSON does not exist.
        json.JSONDecodeError: If the config or response pool JSON is invalid.
        ValueError: If ``sanity_check`` rejects the encoder's output.
    """
    models_dir = Path(models_dir)
    faiss_dir = models_dir / "faiss_indices"
    faiss_index_path = faiss_dir / "faiss_index_production.index"
    responses_path = faiss_dir / "faiss_index_production_responses.json"

    config = _load_chatbot_config(models_dir / "config.json")
    response_pool = _load_response_pool(responses_path)

    encoder = _load_encoder(
        config,
        shared_encoder_dir=models_dir / "shared_encoder",
        custom_weights_path=models_dir / "encoder_custom_weights.weights.h5",
    )
    tokenizer = _load_tokenizer(config, models_dir / "tokenizer")

    sanity_check(encoder, tokenizer, config)

    pipeline = TFDataPipeline(
        config=config,
        tokenizer=tokenizer,
        encoder=encoder,
        index_file_path=str(faiss_index_path),
        response_pool=[],
        max_length=config.max_context_token_limit,
        query_embeddings_cache={},
        neg_samples=config.neg_samples,
        index_type='IndexFlatIP',
        nlist=100,
        max_retries=config.max_retries
    )
    pipeline.response_pool = response_pool  # assigned after construction, as before

    # Compute embeddings and add them to the index in one pipeline call.
    logger.info("Starting to compute and index response embeddings via TFDataPipeline...")
    pipeline.compute_and_index_response_embeddings()

    pipeline.save_faiss_index(str(faiss_index_path))

    # Re-read from disk so the check covers what was actually persisted.
    loaded_index = faiss.read_index(str(faiss_index_path))
    logger.info(f"Verified the rebuilt FAISS index has {loaded_index.ntotal} vectors.")
    expected_count = len(pipeline.response_pool)
    if loaded_index.ntotal != expected_count:
        # NOTE(review): search hits are presumably mapped back to responses by
        # position, so a size mismatch likely means misaligned results — confirm.
        logger.warning(
            f"FAISS index holds {loaded_index.ntotal} vectors but the response pool "
            f"has {expected_count} entries; index ids may not line up with responses."
        )

    return loaded_index, pipeline.response_pool

if __name__ == "__main__":
    # Script entry point. The returned (index, responses) tuple is only useful to
    # programmatic callers, so it is deliberately discarded here.
    build_faiss_index()