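"""Offline data-preparation pipeline for the retrieval chatbot.

Loads augmented dialogues, embeds the assistant response pool, builds and
saves a FAISS index, writes the training data as TFRecords, and persists
the tokenizer and query-embedding cache for later training runs.
"""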
import os
import sys
import json
import pickle

import faiss
from transformers import AutoTokenizer

from chatbot_model import ChatbotConfig, EncoderModel
from tf_data_pipeline import TFDataPipeline
from logger_config import config_logger

logger = config_logger(__name__)

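# Avoid fork-related deadlock warnings from the HuggingFace tokenizers library.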
os.environ["TOKENIZERS_PARALLELISM"] = "false"

def cleanup_test_indices(faiss_dir, test_prefix='test_'):
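    """Delete test-prefixed FAISS index files from faiss_dir.

    Manual cleanup helper; not invoked by main().
    """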
    test_files = [f for f in os.listdir(faiss_dir) if f.startswith(test_prefix)]
    for file in test_files:
        file_path = os.path.join(faiss_dir, file)
        os.remove(file_path)
        logger.info(f"Removed test FAISS index file: {file_path}")

def main():
    # Constants
    MODELS_DIR = 'models'
    PROCESSED_DATA_DIR = 'processed_outputs'
    CACHE_DIR = 'cache'
    TOKENIZER_DIR = os.path.join(MODELS_DIR, 'tokenizer')
    FAISS_INDICES_DIR = os.path.join(MODELS_DIR, 'faiss_indices')
    TF_RECORD_DIR = 'training_data'
    FAISS_INDEX_PRODUCTION_PATH = os.path.join(FAISS_INDICES_DIR, 'faiss_index_production.index')
    FAISS_INDEX_TEST_PATH = os.path.join(FAISS_INDICES_DIR, 'faiss_index_test.index')
    ENVIRONMENT = 'production'  # or 'test'
    if ENVIRONMENT == 'test':
        FAISS_INDEX_PATH = FAISS_INDEX_TEST_PATH
    else:
        FAISS_INDEX_PATH = FAISS_INDEX_PRODUCTION_PATH
    JSON_TRAINING_DATA_PATH = os.path.join(PROCESSED_DATA_DIR, 'augmented_dialogues.json')
    CACHE_FILE = os.path.join(CACHE_DIR, 'query_embeddings_cache.pkl')
    TF_RECORD_PATH = os.path.join(TF_RECORD_DIR, 'training_data.tfrecord')
    DEBUG_SAMPLES = None  # set to an int to cap how many dialogues are loaded (debugging)
    
    # Ensure output directories exist
    os.makedirs(MODELS_DIR, exist_ok=True)
    os.makedirs(PROCESSED_DATA_DIR, exist_ok=True)
    os.makedirs(CACHE_DIR, exist_ok=True)
    os.makedirs(TOKENIZER_DIR, exist_ok=True)
    os.makedirs(FAISS_INDICES_DIR, exist_ok=True)
    os.makedirs(TF_RECORD_DIR, exist_ok=True)
    
    # Initialize configuration
    config = ChatbotConfig()
    logger.info(f"Chatbot Configuration: {config}")
    
    # Initialize tokenizer and add special tokens
    try:
        tokenizer = AutoTokenizer.from_pretrained(config.pretrained_model)
        logger.info(f"Tokenizer '{config.pretrained_model}' loaded successfully.")
        tokenizer.add_special_tokens({'additional_special_tokens': ['<EMPTY_NEGATIVE>']})
        logger.info("Added special tokens to tokenizer.")
    except Exception as e:
        logger.error(f"Failed to load tokenizer: {e}")
        sys.exit(1)
    
    # Initialize encoder model and resize token embeddings
    try:
        encoder = EncoderModel(config=config)
        logger.info("EncoderModel initialized successfully.")
        encoder.pretrained.resize_token_embeddings(len(tokenizer))
        logger.info(f"Token embeddings resized to: {len(tokenizer)}")
    except Exception as e:
        logger.error(f"Failed to initialize EncoderModel: {e}")
        sys.exit(1)
    
    # Load JSON dialogues
    try:
        dialogues = TFDataPipeline.load_json_training_data(JSON_TRAINING_DATA_PATH, DEBUG_SAMPLES)
        logger.info(f"Loaded {len(dialogues)} dialogues from {JSON_TRAINING_DATA_PATH}.")
    except Exception as e:
        logger.error(f"Failed to load dialogues: {e}")
        sys.exit(1)
    
    # Load or initialize query_embeddings_cache
    try:
        if os.path.exists(CACHE_FILE):
            with open(CACHE_FILE, 'rb') as f:
                query_embeddings_cache = pickle.load(f)
            logger.info(f"Loaded {len(query_embeddings_cache)} query embeddings from {CACHE_FILE}.")
        else:
            query_embeddings_cache = {}
            logger.info("Initialized empty query embeddings cache.")
    except Exception as e:
        logger.error(f"Failed to load or initialize query embeddings cache: {e}")
        sys.exit(1)
    
    # Initialize TFDataPipeline
    try:
        data_pipeline = TFDataPipeline(
            config=config,
            tokenizer=tokenizer,
            encoder=encoder,
            index_file_path=FAISS_INDEX_PATH,
            response_pool=[],
            max_length=config.max_context_token_limit,
            neg_samples=config.neg_samples,
            query_embeddings_cache=query_embeddings_cache,
            index_type='IndexFlatIP',  # exact (brute-force) inner-product search
            nlist=100,                 # IVF partition count (not used by a flat index)
            max_retries=config.max_retries
        )
        logger.info("TFDataPipeline initialized successfully.")
    except Exception as e:
        logger.error(f"Failed to initialize TFDataPipeline: {e}")
        sys.exit(1)
    
    # Collect unique assistant responses from dialogues
    try:
        response_pool = data_pipeline.collect_responses(dialogues)
        data_pipeline.response_pool = response_pool
        logger.info(f"Collected {len(response_pool)} unique assistant responses from dialogues.")
    except Exception as e:
        logger.error(f"Failed to collect responses: {e}")
        sys.exit(1)
    
    # Compute and add response embeddings to FAISS index
    try:
        logger.info("Computing and adding response embeddings to FAISS index...")
        data_pipeline.compute_and_index_response_embeddings()
        logger.info("Response embeddings computed and added to FAISS index.")
    except Exception as e:
        logger.error(f"Failed to compute or add response embeddings: {e}")
        sys.exit(1)
    
    # Save FAISS index and response pool
    try:
        logger.info(f"Saving FAISS index to {FAISS_INDEX_PATH}...")
        faiss.write_index(data_pipeline.index, FAISS_INDEX_PATH)
        logger.info("FAISS index saved successfully.")
        
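        # FAISS search returns integer row ids, not text; saving the response
        # pool keeps the id-to-response mapping alongside the index.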
        response_pool_path = FAISS_INDEX_PATH.replace('.index', '_responses.json')
        with open(response_pool_path, 'w', encoding='utf-8') as f:
            json.dump(data_pipeline.response_pool, f, indent=2)
        logger.info(f"Response pool saved to {response_pool_path}.")
    except Exception as e:
        logger.error(f"Failed to save FAISS index: {e}")
        sys.exit(1)
    
    # Prepare and save training data as TFRecords
    try:
        logger.info("Starting data preparation and saving as TFRecord...")
        data_pipeline.prepare_and_save_data(dialogues, TF_RECORD_PATH)
        logger.info(f"Data saved as TFRecord at {TF_RECORD_PATH}.")
    except Exception as e:
        logger.error(f"Failed during data preparation and saving: {e}")
        sys.exit(1)
    
    # Save query embeddings cache
    try:
        with open(CACHE_FILE, 'wb') as f:
            pickle.dump(data_pipeline.query_embeddings_cache, f)
        logger.info(f"Saved {len(data_pipeline.query_embeddings_cache)} query embeddings to {CACHE_FILE}.")
    except Exception as e:
        logger.error(f"Failed to save query embeddings cache: {e}")
        sys.exit(1)
    
    # Save Tokenizer (including special tokens)
    try:
        tokenizer.save_pretrained(TOKENIZER_DIR)
        logger.info(f"Tokenizer saved to {TOKENIZER_DIR}.")
    except Exception as e:
        logger.error(f"Failed to save tokenizer: {e}")
        sys.exit(1)
    
    logger.info("Data preparation pipeline completed successfully.")
    
if __name__ == "__main__":
    main()