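"""Data preparation script for the chatbot retrieval pipeline.

Loads augmented dialogue JSON, collects the assistant response pool, computes
response embeddings into a FAISS index, writes training data as TFRecords, and
saves the tokenizer and the query-embedding cache.
"""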
import os
import sys
import faiss
import json
import pickle
from transformers import AutoTokenizer
from tqdm.auto import tqdm
from chatbot_model import ChatbotConfig, EncoderModel
from environment_setup import EnvironmentSetup
from tf_data_pipeline import TFDataPipeline
from logger_config import config_logger
logger = config_logger(__name__)
os.environ["TOKENIZERS_PARALLELISM"] = "false"
def cleanup_test_indices(faiss_dir, test_prefix='test_'):
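    """Delete files in `faiss_dir` whose names start with `test_prefix` (leftover test FAISS indices)."""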
    test_files = [f for f in os.listdir(faiss_dir) if f.startswith(test_prefix)]
    for file in test_files:
        file_path = os.path.join(faiss_dir, file)
        os.remove(file_path)
        logger.info(f"Removed test FAISS index file: {file_path}")
def main():
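    """Run the end-to-end data preparation pipeline; exits via sys.exit(1) on any fatal error."""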
    # Constants
    MODELS_DIR = 'models'
    PROCESSED_DATA_DIR = 'processed_outputs'
    CACHE_DIR = 'cache'
    TOKENIZER_DIR = os.path.join(MODELS_DIR, 'tokenizer')
    FAISS_INDICES_DIR = os.path.join(MODELS_DIR, 'faiss_indices')
    TF_RECORD_DIR = 'training_data'
    FAISS_INDEX_PRODUCTION_PATH = os.path.join(FAISS_INDICES_DIR, 'faiss_index_production.index')
    FAISS_INDEX_TEST_PATH = os.path.join(FAISS_INDICES_DIR, 'faiss_index_test.index')
    ENVIRONMENT = 'production'  # or 'test'
    if ENVIRONMENT == 'test':
        FAISS_INDEX_PATH = FAISS_INDEX_TEST_PATH
    else:
        FAISS_INDEX_PATH = FAISS_INDEX_PRODUCTION_PATH
    JSON_TRAINING_DATA_PATH = os.path.join(PROCESSED_DATA_DIR, 'augmented_dialogues.json')
    CACHE_FILE = os.path.join(CACHE_DIR, 'query_embeddings_cache.pkl')
    TF_RECORD_PATH = os.path.join(TF_RECORD_DIR, 'training_data.tfrecord')
    DEBUG_SAMPLES = None

    # Ensure output directories exist
    os.makedirs(MODELS_DIR, exist_ok=True)
    os.makedirs(PROCESSED_DATA_DIR, exist_ok=True)
    os.makedirs(CACHE_DIR, exist_ok=True)
    os.makedirs(TOKENIZER_DIR, exist_ok=True)
    os.makedirs(FAISS_INDICES_DIR, exist_ok=True)
    os.makedirs(TF_RECORD_DIR, exist_ok=True)
    # Initialize configuration
    config = ChatbotConfig()
    logger.info(f"Chatbot Configuration: {config}")

    # Initialize tokenizer and add special tokens
    try:
        tokenizer = AutoTokenizer.from_pretrained(config.pretrained_model)
        logger.info(f"Tokenizer '{config.pretrained_model}' loaded successfully.")
        tokenizer.add_special_tokens({'additional_special_tokens': ['<EMPTY_NEGATIVE>']})
        logger.info("Added special tokens to tokenizer.")
    except Exception as e:
        logger.error(f"Failed to load tokenizer: {e}")
        sys.exit(1)

    # Initialize encoder model and resize token embeddings
    try:
        encoder = EncoderModel(config=config)
        logger.info("EncoderModel initialized successfully.")
        encoder.pretrained.resize_token_embeddings(len(tokenizer))
        logger.info(f"Token embeddings resized to: {len(tokenizer)}")
    except Exception as e:
        logger.error(f"Failed to initialize EncoderModel: {e}")
        sys.exit(1)
    # Load JSON dialogues
    try:
        dialogues = TFDataPipeline.load_json_training_data(JSON_TRAINING_DATA_PATH, DEBUG_SAMPLES)
        logger.info(f"Loaded {len(dialogues)} dialogues from {JSON_TRAINING_DATA_PATH}.")
    except Exception as e:
        logger.error(f"Failed to load dialogues: {e}")
        sys.exit(1)

    # Load or initialize query_embeddings_cache
    try:
        if os.path.exists(CACHE_FILE):
            with open(CACHE_FILE, 'rb') as f:
                query_embeddings_cache = pickle.load(f)
            logger.info(f"Loaded {len(query_embeddings_cache)} query embeddings from {CACHE_FILE}.")
        else:
            query_embeddings_cache = {}
            logger.info("Initialized empty query embeddings cache.")
    except Exception as e:
        logger.error(f"Failed to load or initialize query embeddings cache: {e}")
        sys.exit(1)

    # Initialize TFDataPipeline
    try:
        data_pipeline = TFDataPipeline(
            config=config,
            tokenizer=tokenizer,
            encoder=encoder,
            index_file_path=FAISS_INDEX_PATH,
            response_pool=[],
            max_length=config.max_context_token_limit,
            neg_samples=config.neg_samples,
            query_embeddings_cache=query_embeddings_cache,
            index_type='IndexFlatIP',
            nlist=100,
            max_retries=config.max_retries
        )
        logger.info("TFDataPipeline initialized successfully.")
    except Exception as e:
        logger.error(f"Failed to initialize TFDataPipeline: {e}")
        sys.exit(1)
    # Collect unique assistant responses from dialogues
    try:
        response_pool = data_pipeline.collect_responses(dialogues)
        data_pipeline.response_pool = response_pool
        logger.info(f"Collected {len(response_pool)} unique assistant responses from dialogues.")
    except Exception as e:
        logger.error(f"Failed to collect responses: {e}")
        sys.exit(1)

    # Compute and add response embeddings to FAISS index
    try:
        logger.info("Computing and adding response embeddings to FAISS index...")
        data_pipeline.compute_and_index_response_embeddings()
        logger.info("Response embeddings computed and added to FAISS index.")
    except Exception as e:
        logger.error(f"Failed to compute or add response embeddings: {e}")
        sys.exit(1)

    # Save FAISS index and response pool
    try:
        logger.info(f"Saving FAISS index to {FAISS_INDEX_PATH}...")
        faiss.write_index(data_pipeline.index, FAISS_INDEX_PATH)
        logger.info("FAISS index saved successfully.")

        response_pool_path = FAISS_INDEX_PATH.replace('.index', '_responses.json')
        with open(response_pool_path, 'w', encoding='utf-8') as f:
            json.dump(data_pipeline.response_pool, f, indent=2)
        logger.info(f"Response pool saved to {response_pool_path}.")
    except Exception as e:
        logger.error(f"Failed to save FAISS index or response pool: {e}")
        sys.exit(1)

    # Prepare and save training data as TFRecords
    try:
        logger.info("Starting data preparation and saving as TFRecord...")
        data_pipeline.prepare_and_save_data(dialogues, TF_RECORD_PATH)
        logger.info(f"Data saved as TFRecord at {TF_RECORD_PATH}.")
    except Exception as e:
        logger.error(f"Failed during data preparation and saving: {e}")
        sys.exit(1)
    # Save query embeddings cache
    try:
        with open(CACHE_FILE, 'wb') as f:
            pickle.dump(data_pipeline.query_embeddings_cache, f)
        logger.info(f"Saved {len(data_pipeline.query_embeddings_cache)} query embeddings to {CACHE_FILE}.")
    except Exception as e:
        logger.error(f"Failed to save query embeddings cache: {e}")
        sys.exit(1)

    # Save Tokenizer (including special tokens)
    try:
        tokenizer.save_pretrained(TOKENIZER_DIR)
        logger.info(f"Tokenizer saved to {TOKENIZER_DIR}.")
    except Exception as e:
        logger.error(f"Failed to save tokenizer: {e}")
        sys.exit(1)

    logger.info("Data preparation pipeline completed successfully.")
if __name__ == "__main__":
    main()