import os
import sys
import faiss
import json
import pickle

from transformers import AutoTokenizer
from tqdm.auto import tqdm

from chatbot_model import ChatbotConfig, EncoderModel
from environment_setup import EnvironmentSetup
from tf_data_pipeline import TFDataPipeline
from logger_config import config_logger

logger = config_logger(__name__)

os.environ["TOKENIZERS_PARALLELISM"] = "false"


def cleanup_test_indices(faiss_dir, test_prefix='test_'):
    test_files = [f for f in os.listdir(faiss_dir) if f.startswith(test_prefix)]
    for file in test_files:
        file_path = os.path.join(faiss_dir, file)
        os.remove(file_path)
        logger.info(f"Removed test FAISS index file: {file_path}")


def main():
    # Constants
    MODELS_DIR = 'models'
    PROCESSED_DATA_DIR = 'processed_outputs'
    CACHE_DIR = 'cache'
    TOKENIZER_DIR = os.path.join(MODELS_DIR, 'tokenizer')
    FAISS_INDICES_DIR = os.path.join(MODELS_DIR, 'faiss_indices')
    TF_RECORD_DIR = 'training_data'
    FAISS_INDEX_PRODUCTION_PATH = os.path.join(FAISS_INDICES_DIR, 'faiss_index_production.index')
    FAISS_INDEX_TEST_PATH = os.path.join(FAISS_INDICES_DIR, 'faiss_index_test.index')

    ENVIRONMENT = 'production'  # or 'test'
    if ENVIRONMENT == 'test':
        FAISS_INDEX_PATH = FAISS_INDEX_TEST_PATH
    else:
        FAISS_INDEX_PATH = FAISS_INDEX_PRODUCTION_PATH

    JSON_TRAINING_DATA_PATH = os.path.join(PROCESSED_DATA_DIR, 'augmented_dialogues.json')
    CACHE_FILE = os.path.join(CACHE_DIR, 'query_embeddings_cache.pkl')
    TF_RECORD_PATH = os.path.join(TF_RECORD_DIR, 'training_data.tfrecord')
    DEBUG_SAMPLES = None

    # Ensure output directories exist
    os.makedirs(MODELS_DIR, exist_ok=True)
    os.makedirs(PROCESSED_DATA_DIR, exist_ok=True)
    os.makedirs(CACHE_DIR, exist_ok=True)
    os.makedirs(TOKENIZER_DIR, exist_ok=True)
    os.makedirs(FAISS_INDICES_DIR, exist_ok=True)
    os.makedirs(TF_RECORD_DIR, exist_ok=True)

    # Initialize configuration
    config = ChatbotConfig()
    logger.info(f"Chatbot Configuration: {config}")

    # Initialize tokenizer and add special tokens
    try:
        tokenizer = AutoTokenizer.from_pretrained(config.pretrained_model)
        logger.info(f"Tokenizer '{config.pretrained_model}' loaded successfully.")
        tokenizer.add_special_tokens({'additional_special_tokens': ['']})
        logger.info("Added special tokens to tokenizer.")
    except Exception as e:
        logger.error(f"Failed to load tokenizer: {e}")
        sys.exit(1)

    # Initialize encoder model and resize token embeddings
    try:
        encoder = EncoderModel(config=config)
        logger.info("EncoderModel initialized successfully.")
        encoder.pretrained.resize_token_embeddings(len(tokenizer))
        logger.info(f"Token embeddings resized to: {len(tokenizer)}")
    except Exception as e:
        logger.error(f"Failed to initialize EncoderModel: {e}")
        sys.exit(1)

    # Load JSON dialogues
    try:
        dialogues = TFDataPipeline.load_json_training_data(JSON_TRAINING_DATA_PATH, DEBUG_SAMPLES)
        logger.info(f"Loaded {len(dialogues)} dialogues from {JSON_TRAINING_DATA_PATH}.")
    except Exception as e:
        logger.error(f"Failed to load dialogues: {e}")
        sys.exit(1)

    # Load or initialize query_embeddings_cache
    try:
        if os.path.exists(CACHE_FILE):
            with open(CACHE_FILE, 'rb') as f:
                query_embeddings_cache = pickle.load(f)
            logger.info(f"Loaded {len(query_embeddings_cache)} query embeddings from {CACHE_FILE}.")
        else:
            query_embeddings_cache = {}
            logger.info("Initialized empty query embeddings cache.")
    except Exception as e:
        logger.error(f"Failed to load or initialize query embeddings cache: {e}")
        sys.exit(1)

    # Initialize TFDataPipeline
    try:
        data_pipeline = TFDataPipeline(
            config=config,
            tokenizer=tokenizer,
            encoder=encoder,
            index_file_path=FAISS_INDEX_PATH,
            response_pool=[],
            max_length=config.max_context_token_limit,
            neg_samples=config.neg_samples,
            query_embeddings_cache=query_embeddings_cache,
            index_type='IndexFlatIP',
            nlist=100,
            max_retries=config.max_retries
        )
        logger.info("TFDataPipeline initialized successfully.")
    except Exception as e:
        logger.error(f"Failed to initialize TFDataPipeline: {e}")
        sys.exit(1)

    # Collect unique assistant responses from dialogues
    try:
        response_pool = data_pipeline.collect_responses(dialogues)
        data_pipeline.response_pool = response_pool
        logger.info(f"Collected {len(response_pool)} unique assistant responses from dialogues.")
    except Exception as e:
        logger.error(f"Failed to collect responses: {e}")
        sys.exit(1)

    # Compute and add response embeddings to FAISS index
    try:
        logger.info("Computing and adding response embeddings to FAISS index...")
        data_pipeline.compute_and_index_response_embeddings()
        logger.info("Response embeddings computed and added to FAISS index.")
    except Exception as e:
        logger.error(f"Failed to compute or add response embeddings: {e}")
        sys.exit(1)

    # Save FAISS index and response pool
    try:
        logger.info(f"Saving FAISS index to {FAISS_INDEX_PATH}...")
        faiss.write_index(data_pipeline.index, FAISS_INDEX_PATH)
        logger.info("FAISS index saved successfully.")

        response_pool_path = FAISS_INDEX_PATH.replace('.index', '_responses.json')
        with open(response_pool_path, 'w', encoding='utf-8') as f:
            json.dump(data_pipeline.response_pool, f, indent=2)
        logger.info(f"Response pool saved to {response_pool_path}.")
    except Exception as e:
        logger.error(f"Failed to save FAISS index: {e}")
        sys.exit(1)

    # Prepare and save training data as TFRecords
    try:
        logger.info("Starting data preparation and saving as TFRecord...")
        data_pipeline.prepare_and_save_data(dialogues, TF_RECORD_PATH)
        logger.info(f"Data saved as TFRecord at {TF_RECORD_PATH}.")
    except Exception as e:
        logger.error(f"Failed during data preparation and saving: {e}")
        sys.exit(1)

    # Save query embeddings cache
    try:
        with open(CACHE_FILE, 'wb') as f:
            pickle.dump(data_pipeline.query_embeddings_cache, f)
        logger.info(f"Saved {len(data_pipeline.query_embeddings_cache)} query embeddings to {CACHE_FILE}.")
    except Exception as e:
        logger.error(f"Failed to save query embeddings cache: {e}")
        sys.exit(1)

    # Save Tokenizer (including special tokens)
    try:
        tokenizer.save_pretrained(TOKENIZER_DIR)
        logger.info(f"Tokenizer saved to {TOKENIZER_DIR}.")
    except Exception as e:
        logger.error(f"Failed to save tokenizer: {e}")
        sys.exit(1)

    logger.info("Data preparation pipeline completed successfully.")


if __name__ == "__main__":
    main()
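
# --- Hypothetical downstream usage (comment-only sketch, not executed by this script) ---
# The artifacts written above could be reloaded by a retrieval component roughly as
# follows; the paths mirror the constants in main(), and the consumer code itself is
# an assumption, not part of this repository's confirmed API:
#
#   import json
#   import faiss
#   from transformers import AutoTokenizer
#
#   index = faiss.read_index('models/faiss_indices/faiss_index_production.index')
#   with open('models/faiss_indices/faiss_index_production_responses.json', encoding='utf-8') as f:
#       response_pool = json.load(f)
#   tokenizer = AutoTokenizer.from_pretrained('models/tokenizer')
#
# A retrieval step would then encode a user query with the same EncoderModel and call
# index.search(query_embedding, k) to map the top-k hits back into response_pool.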