# csc525_retrieval_based_chatbot / run_data_preparer.py
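"""Data preparation pipeline for the retrieval-based chatbot.

Loads augmented dialogue JSON, builds a FAISS index over the pool of unique
assistant responses, writes training examples to a TFRecord file, and saves
the tokenizer and the query-embedding cache for later reuse.
"""
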
import os
import sys
import pickle

import faiss
from transformers import AutoTokenizer
from tqdm.auto import tqdm

from chatbot_model import ChatbotConfig, EncoderModel
from environment_setup import EnvironmentSetup
from tf_data_pipeline import TFDataPipeline
from logger_config import config_logger

logger = config_logger(__name__)
os.environ["TOKENIZERS_PARALLELISM"] = "false"

def cleanup_test_indices(faiss_dir, test_prefix='test_'):
    """Remove FAISS index files created by test runs (files starting with `test_prefix`)."""
    test_files = [f for f in os.listdir(faiss_dir) if f.startswith(test_prefix)]
    for file in test_files:
        file_path = os.path.join(faiss_dir, file)
        os.remove(file_path)
        logger.info(f"Removed test FAISS index file: {file_path}")

def main():
    """Run the full data-preparation pipeline for the retrieval-based chatbot."""
    # Constants
    MODELS_DIR = 'models'
    PROCESSED_DATA_DIR = 'processed_outputs'
    CACHE_DIR = 'cache'
    TOKENIZER_DIR = os.path.join(MODELS_DIR, 'tokenizer')
    FAISS_INDICES_DIR = os.path.join(MODELS_DIR, 'faiss_indices')
    TF_RECORD_DIR = 'training_data'
    FAISS_INDEX_PRODUCTION_PATH = os.path.join(FAISS_INDICES_DIR, 'faiss_index_production.index')
    FAISS_INDEX_TEST_PATH = os.path.join(FAISS_INDICES_DIR, 'faiss_index_test.index')
    ENVIRONMENT = 'production'  # or 'test'

    if ENVIRONMENT == 'test':
        FAISS_INDEX_PATH = FAISS_INDEX_TEST_PATH
    else:
        FAISS_INDEX_PATH = FAISS_INDEX_PRODUCTION_PATH

    JSON_TRAINING_DATA_PATH = os.path.join(PROCESSED_DATA_DIR, 'augmented_dialogues.json')
    CACHE_FILE = os.path.join(CACHE_DIR, 'query_embeddings_cache.pkl')
    TF_RECORD_PATH = os.path.join(TF_RECORD_DIR, 'training_data.tfrecord')
    DEBUG_SAMPLES = None  # optionally limit the number of dialogues loaded (None = all)
    # Ensure output directories exist
    os.makedirs(MODELS_DIR, exist_ok=True)
    os.makedirs(PROCESSED_DATA_DIR, exist_ok=True)
    os.makedirs(CACHE_DIR, exist_ok=True)
    os.makedirs(TOKENIZER_DIR, exist_ok=True)
    os.makedirs(FAISS_INDICES_DIR, exist_ok=True)
    os.makedirs(TF_RECORD_DIR, exist_ok=True)
    # Initialize configuration
    config = ChatbotConfig()
    logger.info(f"Chatbot Configuration: {config}")
    # Initialize tokenizer
    try:
        tokenizer = AutoTokenizer.from_pretrained(config.pretrained_model)
        logger.info(f"Tokenizer '{config.pretrained_model}' loaded successfully.")
    except Exception as e:
        logger.error(f"Failed to load tokenizer: {e}")
        sys.exit(1)

    # Add special tokens
    try:
        tokenizer.add_special_tokens({'additional_special_tokens': ['<EMPTY_NEGATIVE>']})
        logger.info("Added special tokens to tokenizer.")
    except Exception as e:
        logger.error(f"Failed to add special tokens: {e}")
        sys.exit(1)
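
    # Adding '<EMPTY_NEGATIVE>' grows the tokenizer vocabulary by one entry, so the
    # encoder's token-embedding matrix is resized below to keep the two in sync.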
    # Initialize encoder model
    try:
        encoder = EncoderModel(config=config)
        logger.info("EncoderModel initialized successfully.")
    except Exception as e:
        logger.error(f"Failed to initialize EncoderModel: {e}")
        sys.exit(1)

    # Resize token embeddings in encoder to match tokenizer
    try:
        encoder.pretrained.resize_token_embeddings(len(tokenizer))
        logger.info(f"Token embeddings resized to: {len(tokenizer)}")
    except Exception as e:
        logger.error(f"Failed to resize token embeddings: {e}")
        sys.exit(1)
    # Load JSON dialogues
    try:
        dialogues = TFDataPipeline.load_json_training_data(JSON_TRAINING_DATA_PATH, DEBUG_SAMPLES)
        logger.info(f"Loaded {len(dialogues)} dialogues from {JSON_TRAINING_DATA_PATH}.")
    except Exception as e:
        logger.error(f"Failed to load dialogues: {e}")
        sys.exit(1)
    # Load or initialize query_embeddings_cache
    try:
        if os.path.exists(CACHE_FILE):
            with open(CACHE_FILE, 'rb') as f:
                query_embeddings_cache = pickle.load(f)
            logger.info(f"Loaded {len(query_embeddings_cache)} query embeddings from {CACHE_FILE}.")
        else:
            query_embeddings_cache = {}
            logger.info("Initialized empty query embeddings cache.")
    except Exception as e:
        logger.error(f"Failed to load or initialize query embeddings cache: {e}")
        sys.exit(1)
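
    # The cache is assumed to map each query (or a hash of it) to its precomputed
    # embedding so repeated runs skip re-encoding; see TFDataPipeline for the exact
    # keying scheme.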
    # Initialize TFDataPipeline
    try:
        data_pipeline = TFDataPipeline(
            config=config,
            tokenizer=tokenizer,
            encoder=encoder,
            index_file_path=FAISS_INDEX_PATH,
            response_pool=[],
            max_length=config.max_context_token_limit,
            neg_samples=config.neg_samples,
            query_embeddings_cache=query_embeddings_cache,
            max_retries=config.max_retries
        )
        logger.info("TFDataPipeline initialized successfully.")
    except Exception as e:
        logger.error(f"Failed to initialize TFDataPipeline: {e}")
        sys.exit(1)
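
    # Note: response_pool is passed in empty above and populated by collect_responses()
    # below; the FAISS index at index_file_path is likewise filled later by
    # _compute_and_index_response_embeddings().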
    # Collect unique assistant responses from dialogues
    try:
        response_pool = data_pipeline.collect_responses(dialogues)
        data_pipeline.response_pool = response_pool
        logger.info(f"Collected {len(response_pool)} unique assistant responses from dialogues.")
    except Exception as e:
        logger.error(f"Failed to collect responses: {e}")
        sys.exit(1)
    # Compute and add response embeddings to FAISS index
    try:
        logger.info("Computing and adding response embeddings to FAISS index...")
        data_pipeline._compute_and_index_response_embeddings()
        logger.info("Response embeddings computed and added to FAISS index.")
    except Exception as e:
        logger.error(f"Failed to compute or add response embeddings: {e}")
        sys.exit(1)
    # Save FAISS index
    try:
        logger.info(f"Saving FAISS index to {FAISS_INDEX_PATH}...")
        faiss.write_index(data_pipeline.index, FAISS_INDEX_PATH)
        logger.info("FAISS index saved successfully.")
    except Exception as e:
        logger.error(f"Failed to save FAISS index: {e}")
        sys.exit(1)
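
    # At retrieval time the saved index can be reloaded and queried, e.g. (illustrative):
    #   index = faiss.read_index(FAISS_INDEX_PATH)
    #   distances, ids = index.search(query_embeddings, k)  # query_embeddings: float32 array [n, d]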
    # Prepare and save training data as TFRecords
    try:
        logger.info("Starting data preparation and saving as TFRecord...")
        data_pipeline.prepare_and_save_data(dialogues, TF_RECORD_PATH)
        logger.info(f"Data saved as TFRecord at {TF_RECORD_PATH}.")
    except Exception as e:
        logger.error(f"Failed during data preparation and saving: {e}")
        sys.exit(1)
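
    # The TFRecord file can be consumed at training time with tf.data, e.g. (illustrative):
    #   dataset = tf.data.TFRecordDataset(TF_RECORD_PATH)
    # with example parsing presumably handled alongside TFDataPipeline's serialization logic.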
    # Save query embeddings cache
    try:
        with open(CACHE_FILE, 'wb') as f:
            pickle.dump(data_pipeline.query_embeddings_cache, f)
        logger.info(f"Saved {len(data_pipeline.query_embeddings_cache)} query embeddings to {CACHE_FILE}.")
    except Exception as e:
        logger.error(f"Failed to save query embeddings cache: {e}")
        sys.exit(1)
    # Save tokenizer (including special tokens)
    try:
        tokenizer.save_pretrained(TOKENIZER_DIR)
        logger.info(f"Tokenizer saved to {TOKENIZER_DIR}.")
    except Exception as e:
        logger.error(f"Failed to save tokenizer: {e}")
        sys.exit(1)
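
    # The saved tokenizer (including '<EMPTY_NEGATIVE>') can later be restored with, e.g.:
    #   tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_DIR)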
logger.info("Data preparation pipeline completed successfully.")

if __name__ == "__main__":
    main()