# Author: JoeArmani
# Commit 71ca212 ("restructuring"); original file size 10.5 kB.
import os
import sys
import faiss
import json
import pickle
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModel
from tqdm.auto import tqdm
from pathlib import Path
from chatbot_model import ChatbotConfig, EncoderModel
from tf_data_pipeline import TFDataPipeline
from logger_config import config_logger
# Turn off parallelism in the HuggingFace `tokenizers` backend for this process.
os.environ.update(TOKENIZERS_PARALLELISM="false")
# Module-wide logger, configured through the project's logger_config helper.
logger = config_logger(__name__)
def _run_stage(failure_message, stage, *args):
    """Run ``stage(*args)`` and return its result, aborting the script on failure.

    Any ``Exception`` is logged as ``"<failure_message>: <error>"`` together
    with its traceback (via ``logger.exception``), then the process exits
    with status 1.
    """
    try:
        return stage(*args)
    except Exception as e:
        logger.exception(f"{failure_message}: {e}")
        sys.exit(1)


def _load_config(models_dir, neg_samples):
    """Return the ChatbotConfig stored in ``<models_dir>/config.json``, or the defaults.

    ``neg_samples`` is always written onto the returned config, overriding
    any value loaded from disk.
    """
    config_json = Path(models_dir) / "config.json"
    if config_json.exists():
        with open(config_json, "r", encoding="utf-8") as f:
            config = ChatbotConfig.from_dict(json.load(f))
        logger.info(f"Loaded ChatbotConfig from {config_json}")
    else:
        config = ChatbotConfig()
        logger.warning("No config.json found. Using default ChatbotConfig.")
    config.neg_samples = neg_samples
    return config


def _load_tokenizer(tokenizer_dir, config):
    """Load the tokenizer saved in *tokenizer_dir*; otherwise fetch ``config.pretrained_model``'s and save it there."""
    tokenizer_path = Path(tokenizer_dir)
    # any() stops at the first entry instead of materialising the whole listing.
    if tokenizer_path.exists() and any(tokenizer_path.iterdir()):
        logger.info(f"Loading tokenizer from {tokenizer_dir}")
        return AutoTokenizer.from_pretrained(tokenizer_dir)
    logger.info(f"Loading base tokenizer for {config.pretrained_model}")
    tokenizer = AutoTokenizer.from_pretrained(config.pretrained_model)
    tokenizer_path.mkdir(parents=True, exist_ok=True)
    tokenizer.save_pretrained(tokenizer_dir)
    logger.info(f"New tokenizer saved to {tokenizer_dir}")
    return tokenizer


def _init_encoder(config, tokenizer, shared_encoder_path, custom_weights_path, load_custom_model):
    """Build an EncoderModel, optionally restoring saved weights, sized to *tokenizer*.

    When *load_custom_model* is true, the DistilBERT submodule is replaced by
    the one saved at *shared_encoder_path* (if present), and the top-level
    weights are loaded from *custom_weights_path* (if present). In every case
    the token embeddings are then resized to ``len(tokenizer)``.
    """
    encoder = EncoderModel(config=config)
    logger.info("EncoderModel initialized successfully.")
    if load_custom_model:
        # Load the DistilBERT submodule from 'shared_encoder'
        if shared_encoder_path.exists():
            logger.info(f"Loading DistilBERT submodule from {shared_encoder_path}")
            encoder.pretrained = TFAutoModel.from_pretrained(shared_encoder_path)
        else:
            logger.warning(f"No shared_encoder found at {shared_encoder_path}, using base DistilBERT instead.")
        # Load custom .weights.h5 (projection, dropout, etc.)
        if custom_weights_path.exists():
            logger.info(f"Loading custom top-level weights from {custom_weights_path}")
            # Dummy forward pass forces model build so every layer has variables to load into.
            dummy_input = tf.zeros((1, config.max_context_token_limit), dtype=tf.int32)
            _ = encoder(dummy_input, training=False)
            encoder.load_weights(str(custom_weights_path))
            logger.info("Custom encoder weights loaded successfully.")
        else:
            logger.warning(f"Custom weights file not found at {custom_weights_path}. Using only submodule weights.")
    else:
        # Base DistilBERT with special tokens
        logger.info("Using the base DistilBERT without loading custom weights.")
    # Resize token embeddings in case we added special tokens (EncoderModel class)
    encoder.pretrained.resize_token_embeddings(len(tokenizer))
    logger.info(f"Token embeddings resized to: {len(tokenizer)}")
    return encoder


def _load_dialogues(json_path):
    """Return the dialogues in *json_path*, or an empty list when the file is missing."""
    if not Path(json_path).exists():
        logger.warning(f"No dialogues found at {json_path}, skipping.")
        return []
    dialogues = TFDataPipeline.load_json_training_data(json_path, debug_samples=None)
    logger.info(f"Loaded {len(dialogues)} dialogues from {json_path}.")
    return dialogues


def _newest_mtime(paths):
    """Return the newest modification time of any file at or under *paths*, or None if none exist."""
    mtimes = []
    for path in map(Path, paths):
        if path.is_file():
            mtimes.append(path.stat().st_mtime)
        elif path.is_dir():
            # A directory's own mtime does not change when a file inside is overwritten, so inspect the files.
            mtimes.extend(p.stat().st_mtime for p in path.rglob("*") if p.is_file())
    return max(mtimes, default=None)


def _load_query_embeddings_cache(cache_file, model_paths):
    """Load the pickled query-embedding cache, or return ``{}`` (best effort).

    The cache must be recomputed after each training run (reusing it was a
    known bug source). A cache file older than any model artifact in
    *model_paths* is therefore treated as stale and ignored. Read failures
    are logged as warnings and also yield an empty cache.
    """
    if not os.path.exists(cache_file):
        logger.info("No existing query embeddings cache found. Starting fresh.")
        return {}
    try:
        newest_model_mtime = _newest_mtime(model_paths)
        if newest_model_mtime is not None and os.path.getmtime(cache_file) < newest_model_mtime:
            logger.warning(
                f"Query embeddings cache {cache_file} is older than the model weights; "
                "ignoring it and recomputing embeddings."
            )
            return {}
        # NOTE: pickle.load can execute arbitrary code; this file is only expected to be written by this script.
        with open(cache_file, 'rb') as f:
            cache = pickle.load(f)
        logger.info(f"Loaded {len(cache)} query embeddings from {cache_file}.")
        return cache
    except Exception as e:
        logger.warning(f"Failed to load query embeddings cache: {e}")
        return {}


def _build_data_pipeline(config, tokenizer, encoder, index_path, query_embeddings_cache):
    """Construct the TFDataPipeline, sanity-checking any existing FAISS index first.

    The pipeline receives only *index_path*, not an index object. So an
    existing index is read here purely to verify that its dimension matches
    ``config.embedding_dim``.
    """
    if Path(index_path).exists():
        logger.info(f"Loading existing FAISS index from {index_path}...")
        existing_index = faiss.read_index(index_path)
        logger.info("FAISS index loaded successfully.")
        if existing_index.d != config.embedding_dim:
            logger.warning(
                f"Existing FAISS index dimension {existing_index.d} does not match "
                f"config.embedding_dim {config.embedding_dim}."
            )
    else:
        logger.info(f"No existing FAISS index found at {index_path}.")
    data_pipeline = TFDataPipeline(
        config=config,
        tokenizer=tokenizer,
        encoder=encoder,
        index_file_path=index_path,
        response_pool=[],
        max_length=config.max_context_token_limit,
        neg_samples=config.neg_samples,
        query_embeddings_cache=query_embeddings_cache,
        index_type='IndexFlatIP',
        nlist=100,  # Not used for IndexFlatIP. Retained for future use of IndexIVFFlat
        max_retries=config.max_retries,
    )
    logger.info("TFDataPipeline initialized successfully.")
    return data_pipeline


def _collect_response_pool(data_pipeline, dialogues):
    """Fill ``data_pipeline.response_pool`` with the assistant responses found in *dialogues*."""
    if not dialogues:
        logger.warning("No dialogues loaded. response_pool remains empty.")
        return
    response_pool = data_pipeline.collect_responses_with_domain(dialogues)
    data_pipeline.response_pool = response_pool
    logger.info(f"Collected {len(response_pool)} unique assistant responses from dialogues.")


def _index_responses(data_pipeline, index_path):
    """Embed the response pool into the FAISS index, then save the index and a JSON copy of the pool."""
    if not data_pipeline.response_pool:
        logger.warning("No responses to embed. Skipping FAISS indexing.")
        return
    data_pipeline.build_text_to_domain_map()
    logger.info("Computing and adding response embeddings to FAISS index using TFDataPipeline...")
    data_pipeline.compute_and_index_response_embeddings()
    logger.info("Response embeddings computed and added to FAISS index.")
    data_pipeline.save_faiss_index(index_path)
    # Strip only the trailing extension; str.replace would rewrite every '.index' in the path.
    response_pool_path = os.path.splitext(index_path)[0] + '_responses.json'
    with open(response_pool_path, 'w', encoding='utf-8') as f:
        json.dump(data_pipeline.response_pool, f, indent=2)
    logger.info(f"Response pool saved to {response_pool_path}.")


def _write_tfrecords(data_pipeline, dialogues, tf_record_path):
    """Prepare training examples from *dialogues* and save them as a TFRecord (TensorFlow Record) file."""
    if not dialogues:
        logger.warning("No dialogues to build TFRecord from. Skipping TFRecord creation.")
        return
    logger.info("Starting data preparation and saving as TFRecord...")
    data_pipeline.prepare_and_save_data(dialogues, tf_record_path)
    logger.info(f"Data saved as TFRecord at {tf_record_path}.")


def _save_query_embeddings_cache(data_pipeline, cache_file):
    """Pickle ``data_pipeline.query_embeddings_cache`` to *cache_file*.

    The dump goes to a sibling temp file that is then ``os.replace``-d over
    the target, so an interrupted write cannot truncate an existing cache.
    """
    cache = data_pipeline.query_embeddings_cache
    tmp_path = cache_file + ".tmp"
    with open(tmp_path, 'wb') as f:
        pickle.dump(cache, f)
    os.replace(tmp_path, cache_file)
    logger.info(f"Saved {len(cache)} query embeddings to {cache_file}.")


def _save_tokenizer(tokenizer, tokenizer_dir):
    """Persist *tokenizer* to *tokenizer_dir*."""
    tokenizer.save_pretrained(tokenizer_dir)
    logger.info(f"Tokenizer saved to {tokenizer_dir}.")


def main():
    """Run the data-preparation pipeline end to end.

    Stages, in order:
    1. Load the config, tokenizer and encoder.
    2. Load the JSON dialogues.
    3. Collect the assistant-response pool and index its embeddings in FAISS,
       saving the index plus a ``*_responses.json`` copy of the pool.
    4. Write the TFRecord training data.
    5. Persist the query-embedding cache and the tokenizer.

    A failing stage is logged with its traceback and the process exits with
    status 1.
    """
    MODELS_DIR = 'new_iteration/data_prep_iterative_models'
    PROCESSED_DATA_DIR = 'new_iteration/processed_outputs'
    CACHE_DIR = 'new_iteration/cache'
    TOKENIZER_DIR = os.path.join(MODELS_DIR, 'tokenizer')
    FAISS_INDICES_DIR = os.path.join(MODELS_DIR, 'faiss_indices')
    TF_RECORD_DIR = 'new_iteration/training_data'
    SHARED_ENCODER_DIR = os.path.join(MODELS_DIR, 'shared_encoder')
    CUSTOM_WEIGHTS_PATH = os.path.join(MODELS_DIR, 'encoder_custom_weights.weights.h5')
    FAISS_INDEX_PRODUCTION_PATH = os.path.join(FAISS_INDICES_DIR, 'faiss_index_production.index')
    JSON_TRAINING_DATA_PATH = os.path.join(PROCESSED_DATA_DIR, 'taskmaster_dialogues.json')
    CACHE_FILE = os.path.join(CACHE_DIR, 'query_embeddings_cache.pkl')
    TF_RECORD_PATH = os.path.join(TF_RECORD_DIR, 'training_data_3.tfrecord')
    # Decide whether to load the **custom** model or base DistilBERT (base used for first iteration).
    # True for custom, False for base DistilBERT.
    LOAD_CUSTOM_MODEL = True
    NUM_NEG_SAMPLES = 10

    # Ensure output directories exist
    for directory in (MODELS_DIR, PROCESSED_DATA_DIR, CACHE_DIR, TOKENIZER_DIR, FAISS_INDICES_DIR, TF_RECORD_DIR):
        os.makedirs(directory, exist_ok=True)

    # Config errors are not routed through _run_stage: they propagate uncaught, as before.
    config = _load_config(MODELS_DIR, NUM_NEG_SAMPLES)
    tokenizer = _run_stage("Failed to load or create tokenizer", _load_tokenizer, TOKENIZER_DIR, config)
    encoder = _run_stage(
        "Failed to initialize EncoderModel",
        _init_encoder,
        config,
        tokenizer,
        Path(SHARED_ENCODER_DIR),
        Path(CUSTOM_WEIGHTS_PATH),
        LOAD_CUSTOM_MODEL,
    )
    dialogues = _run_stage("Failed to load dialogues", _load_dialogues, JSON_TRAINING_DATA_PATH)

    # Best effort: a missing, unreadable or stale cache only means embeddings are recomputed.
    query_embeddings_cache = _load_query_embeddings_cache(CACHE_FILE, (SHARED_ENCODER_DIR, CUSTOM_WEIGHTS_PATH))

    data_pipeline = _run_stage(
        "Failed to initialize TFDataPipeline",
        _build_data_pipeline,
        config,
        tokenizer,
        encoder,
        FAISS_INDEX_PRODUCTION_PATH,
        query_embeddings_cache,
    )
    _run_stage("Failed to collect responses", _collect_response_pool, data_pipeline, dialogues)
    _run_stage(
        "Failed to compute or add response embeddings",
        _index_responses,
        data_pipeline,
        FAISS_INDEX_PRODUCTION_PATH,
    )
    _run_stage(
        "Failed during data preparation and saving",
        _write_tfrecords,
        data_pipeline,
        dialogues,
        TF_RECORD_PATH,
    )
    _run_stage("Failed to save query embeddings cache", _save_query_embeddings_cache, data_pipeline, CACHE_FILE)
    _run_stage("Failed to save tokenizer", _save_tokenizer, tokenizer, TOKENIZER_DIR)
    logger.info("Data preparation pipeline completed successfully.")
# Script entry point: run the pipeline only when executed directly, never on import.
if __name__ == '__main__':
    main()