File size: 10,468 Bytes
74af405
 
 
5b413d1
74af405
7a0020b
 
74af405
7a0020b
74af405
 
 
 
 
 
 
 
 
7a0020b
 
 
74af405
 
7a0020b
74af405
7a0020b
74af405
7a0020b
 
71ca212
7a0020b
71ca212
7a0020b
74af405
 
 
 
 
 
 
 
 
71ca212
7a0020b
 
 
 
 
 
 
 
 
 
71ca212
7a0020b
74af405
71ca212
74af405
7a0020b
 
 
 
 
 
 
 
 
 
74af405
7a0020b
74af405
 
71ca212
74af405
 
 
7a0020b
 
 
 
 
 
 
 
 
 
71ca212
7a0020b
 
 
71ca212
 
7a0020b
 
71ca212
7a0020b
 
 
 
 
71ca212
7a0020b
 
71ca212
74af405
 
7a0020b
74af405
5b413d1
74af405
 
 
 
7a0020b
 
 
 
 
 
74af405
 
 
 
71ca212
7a0020b
 
 
74af405
 
 
7a0020b
 
 
 
74af405
 
 
71ca212
7a0020b
 
 
 
 
 
 
 
 
 
71ca212
74af405
 
 
 
7a0020b
74af405
 
 
 
5b413d1
71ca212
74af405
 
 
 
 
 
 
71ca212
74af405
7a0020b
 
 
 
 
 
74af405
 
 
 
71ca212
74af405
7a0020b
 
 
 
 
 
71ca212
7a0020b
 
71ca212
7a0020b
 
 
 
 
 
74af405
 
7a0020b
74af405
 
71ca212
74af405
7a0020b
 
 
 
 
 
74af405
 
 
 
71ca212
74af405
 
 
 
 
 
 
 
7a0020b
74af405
 
 
 
 
 
 
 
7a0020b
 
74af405
7a0020b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
import os
import sys
import faiss
import json
import pickle
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModel
from tqdm.auto import tqdm
from pathlib import Path
from chatbot_model import ChatbotConfig, EncoderModel
from tf_data_pipeline import TFDataPipeline
from logger_config import config_logger

# Module-level logger obtained from the project's logger_config helper.
logger = config_logger(__name__)

# Set the TOKENIZERS_PARALLELISM environment variable to "false" for this process.
# NOTE(review): presumably this silences/avoids the Hugging Face tokenizers
# parallelism-after-fork warning — confirm against the tokenizers docs.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

def _abort(message):
    """Log ``message`` with the active traceback, then exit with status 1.

    Must be called from inside an ``except`` block so ``logger.exception``
    has an exception to attach. Assumes ``config_logger`` returns a standard
    ``logging.Logger`` (it is already used via ``info``/``warning``/``error``).
    """
    logger.exception(message)
    sys.exit(1)


def _load_config(models_dir):
    """Return a ChatbotConfig read from ``<models_dir>/config.json``, or a default one."""
    config_json = Path(models_dir) / "config.json"
    if config_json.exists():
        with open(config_json, "r", encoding="utf-8") as f:
            config_dict = json.load(f)
        config = ChatbotConfig.from_dict(config_dict)
        logger.info(f"Loaded ChatbotConfig from {config_json}")
        return config

    config = ChatbotConfig()
    logger.warning("No config.json found. Using default ChatbotConfig.")
    return config


def _load_tokenizer(tokenizer_dir, config):
    """Load the tokenizer saved in ``tokenizer_dir``, or fetch and save the base one."""
    tokenizer_path = Path(tokenizer_dir)
    # A directory that exists but is empty counts as "no saved tokenizer".
    if tokenizer_path.exists() and any(tokenizer_path.iterdir()):
        logger.info(f"Loading tokenizer from {tokenizer_dir}")
        return AutoTokenizer.from_pretrained(tokenizer_dir)

    logger.info(f"Loading base tokenizer for {config.pretrained_model}")
    tokenizer = AutoTokenizer.from_pretrained(config.pretrained_model)

    tokenizer_path.mkdir(parents=True, exist_ok=True)
    tokenizer.save_pretrained(tokenizer_dir)
    logger.info(f"New tokenizer saved to {tokenizer_dir}")
    return tokenizer


def _build_encoder(config, tokenizer, models_dir, load_custom_model):
    """Create the EncoderModel and optionally load custom weights from ``models_dir``.

    When ``load_custom_model`` is true, the ``shared_encoder`` submodule and the
    ``encoder_custom_weights.weights.h5`` file are loaded if they exist; each
    missing piece only produces a warning. Token embeddings are always resized
    to ``len(tokenizer)`` afterwards.
    """
    encoder = EncoderModel(config=config)
    logger.info("EncoderModel initialized successfully.")

    if load_custom_model:
        # Load the DistilBERT submodule from 'shared_encoder'
        shared_encoder_path = Path(models_dir) / "shared_encoder"
        if shared_encoder_path.exists():
            logger.info(f"Loading DistilBERT submodule from {shared_encoder_path}")
            encoder.pretrained = TFAutoModel.from_pretrained(shared_encoder_path)
        else:
            logger.warning(f"No shared_encoder found at {shared_encoder_path}, using base DistilBERT instead.")

        # Load custom .weights.h5 (projection, dropout, etc.)
        custom_weights_path = Path(models_dir) / "encoder_custom_weights.weights.h5"
        if custom_weights_path.exists():
            logger.info(f"Loading custom top-level weights from {custom_weights_path}")

            # Dummy forward pass forces model build so every layer has
            # variables before load_weights() is called.
            dummy_input = tf.zeros((1, config.max_context_token_limit), dtype=tf.int32)
            _ = encoder(dummy_input, training=False)

            encoder.load_weights(str(custom_weights_path))
            logger.info("Custom encoder weights loaded successfully.")
        else:
            logger.warning(f"Custom weights file not found at {custom_weights_path}. Using only submodule weights.")
    else:
        # Base DistilBERT with special tokens
        logger.info("Using the base DistilBERT without loading custom weights.")

    # Resize token embeddings in case we added special tokens (EncoderModel class)
    encoder.pretrained.resize_token_embeddings(len(tokenizer))
    logger.info(f"Token embeddings resized to: {len(tokenizer)}")
    return encoder


def _load_dialogues(json_path):
    """Return the dialogues stored at ``json_path``, or an empty list if the file is missing."""
    if not Path(json_path).exists():
        logger.warning(f"No dialogues found at {json_path}, skipping.")
        return []

    dialogues = TFDataPipeline.load_json_training_data(json_path, debug_samples=None)
    logger.info(f"Loaded {len(dialogues)} dialogues from {json_path}.")
    return dialogues


def _load_query_cache(cache_file):
    """Return the pickled query-embedding cache, or an empty dict if absent/unreadable.

    NOTE (from the original author): the cache must be recomputed after each
    training run; a stale cache was a source of bugs. Nothing here invalidates
    it, so remove ``cache_file`` manually after retraining.
    """
    if not os.path.exists(cache_file):
        logger.info("No existing query embeddings cache found. Starting fresh.")
        return {}

    try:
        # SECURITY: pickle.load executes arbitrary code from the file. Only
        # use cache files produced by this pipeline on a trusted machine.
        with open(cache_file, 'rb') as f:
            query_embeddings_cache = pickle.load(f)
    except Exception as e:
        # Best effort: an unreadable cache is not fatal, start with an empty one.
        logger.warning(f"Failed to load query embeddings cache: {e}")
        return {}

    logger.info(f"Loaded {len(query_embeddings_cache)} query embeddings from {cache_file}.")
    return query_embeddings_cache


def _check_faiss_index(index_path, config):
    """Read the existing FAISS index at ``index_path`` or construct a fresh IndexFlatIP.

    NOTE(review): the resulting index object is discarded — TFDataPipeline is
    only given ``index_file_path``, not this object. The calls are retained so
    an unreadable index file still aborts at this step and the log output is
    unchanged. Confirm that TFDataPipeline loads the index itself, then this
    helper can be removed.
    """
    if Path(index_path).exists():
        logger.info(f"Loading existing FAISS index from {index_path}...")
        faiss.read_index(index_path)
        logger.info("FAISS index loaded successfully.")
    else:
        logger.info("No existing FAISS index found. Initializing a new index.")
        dimension = config.embedding_dim  # Ensure this matches your encoder's output
        faiss.IndexFlatIP(dimension)  # Using Inner Product for cosine similarity
        logger.info(f"Initialized new FAISS index with dimension {dimension}.")


def _build_pipeline(config, tokenizer, encoder, index_path, query_embeddings_cache):
    """Construct the TFDataPipeline with an initially empty response pool."""
    return TFDataPipeline(
        config=config,
        tokenizer=tokenizer,
        encoder=encoder,
        index_file_path=index_path,
        response_pool=[],
        max_length=config.max_context_token_limit,
        neg_samples=config.neg_samples,
        query_embeddings_cache=query_embeddings_cache,
        index_type='IndexFlatIP',
        nlist=100,  # Not used for IndexFlatIP. Retained for future use of IndexIVFFlat
        max_retries=config.max_retries,
    )


def _index_and_save_responses(data_pipeline, index_path):
    """Embed the pipeline's response pool into FAISS and persist index + pool JSON."""
    if not data_pipeline.response_pool:
        logger.warning("No responses to embed. Skipping FAISS indexing.")
        return

    data_pipeline.build_text_to_domain_map()
    logger.info("Computing and adding response embeddings to FAISS index using TFDataPipeline...")
    data_pipeline.compute_and_index_response_embeddings()
    logger.info("Response embeddings computed and added to FAISS index.")

    # Save the FAISS index
    data_pipeline.save_faiss_index(index_path)

    # Also save response pool JSON next to the index. splitext only touches the
    # final extension, unlike str.replace which would rewrite every '.index'
    # occurrence anywhere in the path.
    index_root, _ = os.path.splitext(index_path)
    response_pool_path = f"{index_root}_responses.json"
    with open(response_pool_path, 'w', encoding='utf-8') as f:
        json.dump(data_pipeline.response_pool, f, indent=2)
    logger.info(f"Response pool saved to {response_pool_path}.")


def main():
    """Run the data-preparation pipeline end to end.

    Steps, in order: create output directories, load the config, tokenizer and
    encoder, load the JSON dialogues and the query-embedding cache, build the
    TFDataPipeline, collect and index assistant responses in FAISS, write the
    training TFRecord, then save the cache and tokenizer.

    Any failure in a step (other than an unreadable cache file) is logged with
    its traceback and terminates the process via ``sys.exit(1)``.
    """
    MODELS_DIR = 'new_iteration/data_prep_iterative_models'
    PROCESSED_DATA_DIR = 'new_iteration/processed_outputs'
    CACHE_DIR = 'new_iteration/cache'
    TOKENIZER_DIR = os.path.join(MODELS_DIR, 'tokenizer')
    FAISS_INDICES_DIR = os.path.join(MODELS_DIR, 'faiss_indices')
    TF_RECORD_DIR = 'new_iteration/training_data'
    FAISS_INDEX_PRODUCTION_PATH = os.path.join(FAISS_INDICES_DIR, 'faiss_index_production.index')
    JSON_TRAINING_DATA_PATH = os.path.join(PROCESSED_DATA_DIR, 'taskmaster_dialogues.json')
    CACHE_FILE = os.path.join(CACHE_DIR, 'query_embeddings_cache.pkl')
    TF_RECORD_PATH = os.path.join(TF_RECORD_DIR, 'training_data_3.tfrecord')

    # Decide whether to load the **custom** model or base DistilBERT (Base used for first iteration).
    # True for custom, False for base DistilBERT.
    LOAD_CUSTOM_MODEL = True
    NUM_NEG_SAMPLES = 10

    # Ensure output directories exist
    for directory in (
        MODELS_DIR,
        PROCESSED_DATA_DIR,
        CACHE_DIR,
        TOKENIZER_DIR,
        FAISS_INDICES_DIR,
        TF_RECORD_DIR,
    ):
        os.makedirs(directory, exist_ok=True)

    # Init config
    config = _load_config(MODELS_DIR)

    # Ensure negative samples are set
    config.neg_samples = NUM_NEG_SAMPLES

    # Load or init tokenizer
    try:
        tokenizer = _load_tokenizer(TOKENIZER_DIR, config)
    except Exception as e:
        _abort(f"Failed to load or create tokenizer: {e}")

    # Init the encoder
    try:
        encoder = _build_encoder(config, tokenizer, MODELS_DIR, LOAD_CUSTOM_MODEL)
    except Exception as e:
        _abort(f"Failed to initialize EncoderModel: {e}")

    # Load JSON dialogues
    try:
        dialogues = _load_dialogues(JSON_TRAINING_DATA_PATH)
    except Exception as e:
        _abort(f"Failed to load dialogues: {e}")

    # Load or init query_embeddings_cache (never fatal; see helper docstring)
    query_embeddings_cache = _load_query_cache(CACHE_FILE)

    # Initialize TFDataPipeline
    try:
        _check_faiss_index(FAISS_INDEX_PRODUCTION_PATH, config)
        data_pipeline = _build_pipeline(
            config,
            tokenizer,
            encoder,
            FAISS_INDEX_PRODUCTION_PATH,
            query_embeddings_cache,
        )
        logger.info("TFDataPipeline initialized successfully.")
    except Exception as e:
        _abort(f"Failed to initialize TFDataPipeline: {e}")

    # Collect response pool from dialogues
    try:
        if dialogues:
            response_pool = data_pipeline.collect_responses_with_domain(dialogues)
            data_pipeline.response_pool = response_pool
            logger.info(f"Collected {len(response_pool)} unique assistant responses from dialogues.")
        else:
            logger.warning("No dialogues loaded. response_pool remains empty.")
    except Exception as e:
        _abort(f"Failed to collect responses: {e}")

    # Build FAISS index with response embeddings
    try:
        _index_and_save_responses(data_pipeline, FAISS_INDEX_PRODUCTION_PATH)
    except Exception as e:
        _abort(f"Failed to compute or add response embeddings: {e}")

    # Prepare training data as TFRecords (TensorFlow Record format)
    try:
        if dialogues:
            logger.info("Starting data preparation and saving as TFRecord...")
            data_pipeline.prepare_and_save_data(dialogues, TF_RECORD_PATH)
            logger.info(f"Data saved as TFRecord at {TF_RECORD_PATH}.")
        else:
            logger.warning("No dialogues to build TFRecord from. Skipping TFRecord creation.")
    except Exception as e:
        _abort(f"Failed during data preparation and saving: {e}")

    # Save query embeddings cache
    try:
        with open(CACHE_FILE, 'wb') as f:
            pickle.dump(data_pipeline.query_embeddings_cache, f)
        logger.info(f"Saved {len(data_pipeline.query_embeddings_cache)} query embeddings to {CACHE_FILE}.")
    except Exception as e:
        _abort(f"Failed to save query embeddings cache: {e}")

    # Save Tokenizer
    try:
        tokenizer.save_pretrained(TOKENIZER_DIR)
        logger.info(f"Tokenizer saved to {TOKENIZER_DIR}.")
    except Exception as e:
        _abort(f"Failed to save tokenizer: {e}")

    logger.info("Data preparation pipeline completed successfully.")


# Run the pipeline only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()