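"""
Data preparation pipeline for the retrieval-based chatbot.

Loads (or initializes) the ChatbotConfig, tokenizer, and EncoderModel,
reads Taskmaster dialogues from JSON, builds/updates a FAISS index of
response embeddings via TFDataPipeline, writes training data as TFRecords,
and persists the query-embeddings cache and tokenizer.
"""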
import os
import sys
import faiss
import json
import pickle
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModel
from tqdm.auto import tqdm
from pathlib import Path

# Your existing modules
from chatbot_model import ChatbotConfig, EncoderModel
from tf_data_pipeline import TFDataPipeline
from logger_config import config_logger

logger = config_logger(__name__)

os.environ["TOKENIZERS_PARALLELISM"] = "false"


def main():
    # Constants
    MODELS_DIR = 'new_iteration/data_prep_iterative_models'
    PROCESSED_DATA_DIR = 'new_iteration/processed_outputs'
    CACHE_DIR = 'new_iteration/cache'
    TOKENIZER_DIR = os.path.join(MODELS_DIR, 'tokenizer')
    FAISS_INDICES_DIR = os.path.join(MODELS_DIR, 'faiss_indices')
    TF_RECORD_DIR = 'new_iteration/training_data'
    FAISS_INDEX_PRODUCTION_PATH = os.path.join(FAISS_INDICES_DIR, 'faiss_index_production.index')
    JSON_TRAINING_DATA_PATH = os.path.join(PROCESSED_DATA_DIR, 'taskmaster_dialogues.json')
    CACHE_FILE = os.path.join(CACHE_DIR, 'query_embeddings_cache.pkl')
    TF_RECORD_PATH = os.path.join(TF_RECORD_DIR, 'training_data_3.tfrecord')

    # Decide whether to load the custom fine-tuned model or just base DistilBERT.
    # True for custom, False for base DistilBERT.
    LOAD_CUSTOM_MODEL = True
    NUM_NEG_SAMPLES = 10

    # Ensure output directories exist
    os.makedirs(MODELS_DIR, exist_ok=True)
    os.makedirs(PROCESSED_DATA_DIR, exist_ok=True)
    os.makedirs(CACHE_DIR, exist_ok=True)
    os.makedirs(TOKENIZER_DIR, exist_ok=True)
    os.makedirs(FAISS_INDICES_DIR, exist_ok=True)
    os.makedirs(TF_RECORD_DIR, exist_ok=True)
    # Initialize config
    config_json = Path(MODELS_DIR) / "config.json"
    if config_json.exists():
        with open(config_json, "r", encoding="utf-8") as f:
            config_dict = json.load(f)
        config = ChatbotConfig.from_dict(config_dict)
        logger.info(f"Loaded ChatbotConfig from {config_json}")
    else:
        config = ChatbotConfig()
        logger.warning("No config.json found. Using default ChatbotConfig.")
    config.neg_samples = NUM_NEG_SAMPLES
    # Load or initialize tokenizer
    try:
        # If the directory has a valid tokenizer
        if Path(TOKENIZER_DIR).exists() and list(Path(TOKENIZER_DIR).iterdir()):
            logger.info(f"Loading tokenizer from {TOKENIZER_DIR}")
            tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_DIR)
        else:
            # Initialize from base DistilBERT
            logger.info(f"Loading base tokenizer for {config.pretrained_model}")
            tokenizer = AutoTokenizer.from_pretrained(config.pretrained_model)

            # Save to disk
            Path(TOKENIZER_DIR).mkdir(parents=True, exist_ok=True)
            tokenizer.save_pretrained(TOKENIZER_DIR)
            logger.info(f"New tokenizer saved to {TOKENIZER_DIR}")
    except Exception as e:
        logger.error(f"Failed to load or create tokenizer: {e}")
        sys.exit(1)
    # Initialize the encoder
    try:
        encoder = EncoderModel(config=config)
        logger.info("EncoderModel initialized successfully.")

        if LOAD_CUSTOM_MODEL:
            # Load the DistilBERT submodule from 'shared_encoder'
            shared_encoder_path = Path(MODELS_DIR) / "shared_encoder"
            if shared_encoder_path.exists():
                logger.info(f"Loading DistilBERT submodule from {shared_encoder_path}")
                encoder.pretrained = TFAutoModel.from_pretrained(shared_encoder_path)
            else:
                logger.warning(f"No shared_encoder found at {shared_encoder_path}, using base DistilBERT instead.")

            # Load top-level custom .weights.h5 (projection, dropout, etc.)
            custom_weights_path = Path(MODELS_DIR) / "encoder_custom_weights.weights.h5"
            if custom_weights_path.exists():
                logger.info(f"Loading custom top-level weights from {custom_weights_path}")

                # Build model layers with a dummy forward pass
                dummy_input = tf.zeros((1, config.max_context_token_limit), dtype=tf.int32)
                _ = encoder(dummy_input, training=False)

                encoder.load_weights(str(custom_weights_path))
                logger.info("Custom encoder weights loaded successfully.")
            else:
                logger.warning(f"Custom weights file not found at {custom_weights_path}. Using only submodule weights.")
        else:
            # Just base DistilBERT with special tokens resized
            logger.info("Using the base DistilBERT without loading custom weights.")

        # Resize token embeddings in case we added special tokens
        # (a no-op when the loaded encoder already matches the tokenizer size)
        encoder.pretrained.resize_token_embeddings(len(tokenizer))
        logger.info(f"Token embeddings resized to: {len(tokenizer)}")
    except Exception as e:
        logger.error(f"Failed to initialize EncoderModel: {e}")
        sys.exit(1)
    # Load JSON dialogues
    try:
        if not Path(JSON_TRAINING_DATA_PATH).exists():
            logger.warning(f"No dialogues found at {JSON_TRAINING_DATA_PATH}, skipping.")
            dialogues = []
        else:
            dialogues = TFDataPipeline.load_json_training_data(JSON_TRAINING_DATA_PATH, debug_samples=None)
            logger.info(f"Loaded {len(dialogues)} dialogues from {JSON_TRAINING_DATA_PATH}.")
    except Exception as e:
        logger.error(f"Failed to load dialogues: {e}")
        sys.exit(1)
    # Load or initialize query_embeddings_cache
    query_embeddings_cache = {}
    if os.path.exists(CACHE_FILE):
        try:
            with open(CACHE_FILE, 'rb') as f:
                query_embeddings_cache = pickle.load(f)
            logger.info(f"Loaded {len(query_embeddings_cache)} query embeddings from {CACHE_FILE}.")
        except Exception as e:
            logger.warning(f"Failed to load query embeddings cache: {e}")
    else:
        logger.info("No existing query embeddings cache found. Starting fresh.")
    # Initialize TFDataPipeline
    try:
        # Determine if FAISS index should be loaded or initialized
        if Path(FAISS_INDEX_PRODUCTION_PATH).exists():
            # Load existing index
            logger.info(f"Loading existing FAISS index from {FAISS_INDEX_PRODUCTION_PATH}...")
            faiss_index = faiss.read_index(FAISS_INDEX_PRODUCTION_PATH)
            logger.info("FAISS index loaded successfully.")
        else:
            # Initialize a new FAISS index
            logger.info("No existing FAISS index found. Initializing a new index.")
            dimension = config.embedding_dim  # Ensure this matches your encoder's output
            faiss_index = faiss.IndexFlatIP(dimension)  # Using Inner Product for cosine similarity
            logger.info(f"Initialized new FAISS index with dimension {dimension}.")
        # Initialize TFDataPipeline with the FAISS index
        data_pipeline = TFDataPipeline(
            config=config,
            tokenizer=tokenizer,
            encoder=encoder,
            index_file_path=FAISS_INDEX_PRODUCTION_PATH,
            response_pool=[],
            max_length=config.max_context_token_limit,
            neg_samples=config.neg_samples,
            query_embeddings_cache=query_embeddings_cache,
            index_type='IndexFlatIP',
            nlist=100,
            max_retries=config.max_retries
        )
        logger.info("TFDataPipeline initialized successfully.")
    except Exception as e:
        logger.error(f"Failed to initialize TFDataPipeline: {e}")
        sys.exit(1)
    # 7) Collect unique assistant responses from dialogues
    try:
        if dialogues:
            response_pool = data_pipeline.collect_responses_with_domain(dialogues)
            data_pipeline.response_pool = response_pool
            logger.info(f"Collected {len(response_pool)} unique assistant responses from dialogues.")
        else:
            logger.warning("No dialogues loaded. response_pool remains empty.")
    except Exception as e:
        logger.error(f"Failed to collect responses: {e}")
        sys.exit(1)
    # 8) Build the FAISS index with response embeddings
    #    Instead of manually computing embeddings, we use the pipeline method
    try:
        if data_pipeline.response_pool:
            data_pipeline.build_text_to_domain_map()
            logger.info("Computing and adding response embeddings to FAISS index using TFDataPipeline...")
            data_pipeline.compute_and_index_response_embeddings()
            logger.info("Response embeddings computed and added to FAISS index.")

            # Save the updated FAISS index
            data_pipeline.save_faiss_index(FAISS_INDEX_PRODUCTION_PATH)

            # Also save the response pool JSON
            response_pool_path = FAISS_INDEX_PRODUCTION_PATH.replace('.index', '_responses.json')
            with open(response_pool_path, 'w', encoding='utf-8') as f:
                json.dump(data_pipeline.response_pool, f, indent=2)
            logger.info(f"Response pool saved to {response_pool_path}.")
        else:
            logger.warning("No responses to embed. Skipping FAISS indexing.")
    except Exception as e:
        logger.error(f"Failed to compute or add response embeddings: {e}")
        sys.exit(1)
    # 9) Prepare and save training data as TFRecords
    try:
        if dialogues:
            logger.info("Starting data preparation and saving as TFRecord...")
            data_pipeline.prepare_and_save_data(dialogues, TF_RECORD_PATH)
            logger.info(f"Data saved as TFRecord at {TF_RECORD_PATH}.")
        else:
            logger.warning("No dialogues to build TFRecord from. Skipping TFRecord creation.")
    except Exception as e:
        logger.error(f"Failed during data preparation and saving: {e}")
        sys.exit(1)
    # 10) Save query embeddings cache
    try:
        with open(CACHE_FILE, 'wb') as f:
            pickle.dump(data_pipeline.query_embeddings_cache, f)
        logger.info(f"Saved {len(data_pipeline.query_embeddings_cache)} query embeddings to {CACHE_FILE}.")
    except Exception as e:
        logger.error(f"Failed to save query embeddings cache: {e}")
        sys.exit(1)

    # Save Tokenizer
    try:
        tokenizer.save_pretrained(TOKENIZER_DIR)
        logger.info(f"Tokenizer saved to {TOKENIZER_DIR}.")
    except Exception as e:
        logger.error(f"Failed to save tokenizer: {e}")
        sys.exit(1)

    logger.info("Data preparation pipeline completed successfully.")


if __name__ == "__main__":
    main()