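"""
Data-preparation pipeline for the retrieval chatbot.

Loads (or initializes) the tokenizer, encoder, and FAISS index, collects
assistant responses from the Taskmaster dialogues, indexes their embeddings,
writes the training data as TFRecords, and persists the query-embeddings
cache and tokenizer back to disk.
"""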
import os
import sys
import faiss
import json
import pickle
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModel
from tqdm.auto import tqdm
from pathlib import Path

# Your existing modules
from chatbot_model import ChatbotConfig, EncoderModel
from tf_data_pipeline import TFDataPipeline
from logger_config import config_logger

logger = config_logger(__name__)

os.environ["TOKENIZERS_PARALLELISM"] = "false"

def main():
    # Constants
    MODELS_DIR = 'new_iteration/data_prep_iterative_models'
    PROCESSED_DATA_DIR = 'new_iteration/processed_outputs'
    CACHE_DIR = 'new_iteration/cache'
    TOKENIZER_DIR = os.path.join(MODELS_DIR, 'tokenizer')
    FAISS_INDICES_DIR = os.path.join(MODELS_DIR, 'faiss_indices')
    TF_RECORD_DIR = 'new_iteration/training_data'
    FAISS_INDEX_PRODUCTION_PATH = os.path.join(FAISS_INDICES_DIR, 'faiss_index_production.index')
    JSON_TRAINING_DATA_PATH = os.path.join(PROCESSED_DATA_DIR, 'taskmaster_dialogues.json')
    CACHE_FILE = os.path.join(CACHE_DIR, 'query_embeddings_cache.pkl')
    TF_RECORD_PATH = os.path.join(TF_RECORD_DIR, 'training_data_3.tfrecord')
    
    # Whether to load the custom fine-tuned model (True) or only the base DistilBERT (False).
    LOAD_CUSTOM_MODEL = True
    NUM_NEG_SAMPLES = 10
    
    # Ensure output directories exist
    os.makedirs(MODELS_DIR, exist_ok=True)
    os.makedirs(PROCESSED_DATA_DIR, exist_ok=True)
    os.makedirs(CACHE_DIR, exist_ok=True)
    os.makedirs(TOKENIZER_DIR, exist_ok=True)
    os.makedirs(FAISS_INDICES_DIR, exist_ok=True)
    os.makedirs(TF_RECORD_DIR, exist_ok=True)
    
    # Initialize config
    config_json = Path(MODELS_DIR) / "config.json"
    if config_json.exists():
        with open(config_json, "r", encoding="utf-8") as f:
            config_dict = json.load(f)
        config = ChatbotConfig.from_dict(config_dict)
        logger.info(f"Loaded ChatbotConfig from {config_json}")
    else:
        config = ChatbotConfig()
        logger.warning("No config.json found. Using default ChatbotConfig.")
    
    config.neg_samples = NUM_NEG_SAMPLES
    
    # Load or initialize tokenizer
    try:
        # Load from disk if the directory already contains a saved tokenizer
        if Path(TOKENIZER_DIR).exists() and list(Path(TOKENIZER_DIR).iterdir()):
            logger.info(f"Loading tokenizer from {TOKENIZER_DIR}")
            tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_DIR)
        else:
            # Initialize from base DistilBERT
            logger.info(f"Loading base tokenizer for {config.pretrained_model}")
            tokenizer = AutoTokenizer.from_pretrained(config.pretrained_model)
            
            # Save to disk
            Path(TOKENIZER_DIR).mkdir(parents=True, exist_ok=True)
            tokenizer.save_pretrained(TOKENIZER_DIR)
            logger.info(f"New tokenizer saved to {TOKENIZER_DIR}")
    except Exception as e:
        logger.error(f"Failed to load or create tokenizer: {e}")
        sys.exit(1)
    
    # Initialize the encoder
    try:
        encoder = EncoderModel(config=config)
        logger.info("EncoderModel initialized successfully.")
        
        if LOAD_CUSTOM_MODEL:
            # Load the DistilBERT submodule from 'shared_encoder'
            shared_encoder_path = Path(MODELS_DIR) / "shared_encoder"
            if shared_encoder_path.exists():
                logger.info(f"Loading DistilBERT submodule from {shared_encoder_path}")
                encoder.pretrained = TFAutoModel.from_pretrained(shared_encoder_path)
            else:
                logger.warning(f"No shared_encoder found at {shared_encoder_path}, using base DistilBERT instead.")

            # Load top-level custom .weights.h5 (projection, dropout, etc.)
            custom_weights_path = Path(MODELS_DIR) / "encoder_custom_weights.weights.h5"
            if custom_weights_path.exists():
                logger.info(f"Loading custom top-level weights from {custom_weights_path}")
                # Build model layers with a dummy forward pass
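                # (Keras only creates a model's weight variables on the first call,
                # so the encoder must be built before load_weights can restore them.)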
                dummy_input = tf.zeros((1, config.max_context_token_limit), dtype=tf.int32)
                _ = encoder(dummy_input, training=False)
                encoder.load_weights(str(custom_weights_path))
                logger.info("Custom encoder weights loaded successfully.")
            else:
                logger.warning(f"Custom weights file not found at {custom_weights_path}. Using only submodule weights.")
        else:
            # Just base DistilBERT with special tokens resized
            logger.info("Using the base DistilBERT without loading custom weights.")
        
        # Resize token embeddings in case we added special tokens
        encoder.pretrained.resize_token_embeddings(len(tokenizer))
        logger.info(f"Token embeddings resized to: {len(tokenizer)}")

    except Exception as e:
        logger.error(f"Failed to initialize EncoderModel: {e}")
        sys.exit(1)
    
    # Load JSON dialogues
    try:
        if not Path(JSON_TRAINING_DATA_PATH).exists():
            logger.warning(f"No dialogues found at {JSON_TRAINING_DATA_PATH}, skipping.")
            dialogues = []
        else:
            dialogues = TFDataPipeline.load_json_training_data(JSON_TRAINING_DATA_PATH, debug_samples=None)
            logger.info(f"Loaded {len(dialogues)} dialogues from {JSON_TRAINING_DATA_PATH}.")
    except Exception as e:
        logger.error(f"Failed to load dialogues: {e}")
        sys.exit(1)
    
    # Load or initialize query_embeddings_cache
    query_embeddings_cache = {}
    if os.path.exists(CACHE_FILE):
        try:
            with open(CACHE_FILE, 'rb') as f:
                query_embeddings_cache = pickle.load(f)
            logger.info(f"Loaded {len(query_embeddings_cache)} query embeddings from {CACHE_FILE}.")
        except Exception as e:
            logger.warning(f"Failed to load query embeddings cache: {e}")
    else:
        logger.info("No existing query embeddings cache found. Starting fresh.")
    
    # Initialize TFDataPipeline
    try:
        # Determine if FAISS index should be loaded or initialized
        if Path(FAISS_INDEX_PRODUCTION_PATH).exists():
            # Load existing index
            logger.info(f"Loading existing FAISS index from {FAISS_INDEX_PRODUCTION_PATH}...")
            faiss_index = faiss.read_index(FAISS_INDEX_PRODUCTION_PATH)
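            # The loaded index's dimensionality must match the encoder's embedding size
            # (config.embedding_dim); otherwise adding new vectors will fail.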
            logger.info("FAISS index loaded successfully.")
        else:
            # Initialize a new FAISS index
            logger.info("No existing FAISS index found. Initializing a new index.")
            dimension = config.embedding_dim  # Ensure this matches your encoder's output
            faiss_index = faiss.IndexFlatIP(dimension)  # Using Inner Product for cosine similarity
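            # Note: inner-product search equals cosine similarity only for L2-normalized
            # embeddings; this assumes TFDataPipeline normalizes vectors before indexing.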
            logger.info(f"Initialized new FAISS index with dimension {dimension}.")

        # Initialize TFDataPipeline with the FAISS index
        data_pipeline = TFDataPipeline(
            config=config,
            tokenizer=tokenizer,
            encoder=encoder,
            index_file_path=FAISS_INDEX_PRODUCTION_PATH,
            response_pool=[],
            max_length=config.max_context_token_limit,
            neg_samples=config.neg_samples,
            query_embeddings_cache=query_embeddings_cache,
            index_type='IndexFlatIP',
            nlist=100,
            max_retries=config.max_retries
        )
        logger.info("TFDataPipeline initialized successfully.")
    except Exception as e:
        logger.error(f"Failed to initialize TFDataPipeline: {e}")
        sys.exit(1)
    
    # Collect unique assistant responses from dialogues
    try:
        if dialogues:
            response_pool = data_pipeline.collect_responses_with_domain(dialogues)
            data_pipeline.response_pool = response_pool
            logger.info(f"Collected {len(response_pool)} unique assistant responses from dialogues.")
        else:
            logger.warning("No dialogues loaded. response_pool remains empty.")
    except Exception as e:
        logger.error(f"Failed to collect responses: {e}")
        sys.exit(1)
    
    # Build the FAISS index with response embeddings.
    # Instead of manually computing embeddings, we use the pipeline method.
    try:
        if data_pipeline.response_pool:
            data_pipeline.build_text_to_domain_map()
            logger.info("Computing and adding response embeddings to FAISS index using TFDataPipeline...")
            data_pipeline.compute_and_index_response_embeddings()
            logger.info("Response embeddings computed and added to FAISS index.")
            
            # Save the updated FAISS index
            data_pipeline.save_faiss_index(FAISS_INDEX_PRODUCTION_PATH)
            
            # Also save the response pool JSON
            response_pool_path = FAISS_INDEX_PRODUCTION_PATH.replace('.index', '_responses.json')
            with open(response_pool_path, 'w', encoding='utf-8') as f:
                json.dump(data_pipeline.response_pool, f, indent=2)
            logger.info(f"Response pool saved to {response_pool_path}.")
        else:
            logger.warning("No responses to embed. Skipping FAISS indexing.")
    
    except Exception as e:
        logger.error(f"Failed to compute or add response embeddings: {e}")
        sys.exit(1)
    
    # Prepare and save training data as TFRecords
    try:
        if dialogues:
            logger.info("Starting data preparation and saving as TFRecord...")
            data_pipeline.prepare_and_save_data(dialogues, TF_RECORD_PATH)
            logger.info(f"Data saved as TFRecord at {TF_RECORD_PATH}.")
        else:
            logger.warning("No dialogues to build TFRecord from. Skipping TFRecord creation.")
    except Exception as e:
        logger.error(f"Failed during data preparation and saving: {e}")
        sys.exit(1)
    
    # Save query embeddings cache
    try:
        with open(CACHE_FILE, 'wb') as f:
            pickle.dump(data_pipeline.query_embeddings_cache, f)
        logger.info(f"Saved {len(data_pipeline.query_embeddings_cache)} query embeddings to {CACHE_FILE}.")
    except Exception as e:
        logger.error(f"Failed to save query embeddings cache: {e}")
        sys.exit(1)
    
    # Save Tokenizer
    try:
        tokenizer.save_pretrained(TOKENIZER_DIR)
        logger.info(f"Tokenizer saved to {TOKENIZER_DIR}.")
    except Exception as e:
        logger.error(f"Failed to save tokenizer: {e}")
        sys.exit(1)
    
    logger.info("Data preparation pipeline completed successfully.")


if __name__ == "__main__":
    main()
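

# Usage (a sketch, assuming the repository layout implied by the paths above):
# run this script directly from the repository root so the relative
# 'new_iteration/' paths resolve, e.g.
#   python <path_to_this_script>.py
# It writes the FAISS index (plus the *_responses.json response pool), the
# TFRecord training data, the query-embeddings cache, and the tokenizer
# under 'new_iteration/'.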