# Source: JoeArmani — "training and inference updates" (commit 5b413d1, 7.16 kB).
import os
import sys
import faiss
import json
import pickle
from transformers import AutoTokenizer
from tqdm.auto import tqdm
from chatbot_model import ChatbotConfig, EncoderModel
from environment_setup import EnvironmentSetup
from tf_data_pipeline import TFDataPipeline
from logger_config import config_logger
logger = config_logger(__name__)
# Turn off the Hugging Face tokenizers library's internal parallelism for this
# process (presumably to avoid its fork/parallelism warnings — confirm).
os.environ.update({"TOKENIZERS_PARALLELISM": "false"})
def cleanup_test_indices(faiss_dir, test_prefix='test_'):
    """Delete the files in *faiss_dir* whose names start with *test_prefix*.

    The scan is not recursive. Only regular files are removed; a
    sub-directory whose name happens to match the prefix is left untouched
    (``os.remove`` cannot delete directories and would raise). A missing
    *faiss_dir* is treated as "nothing to clean up" rather than an error.

    NOTE(review): ``main()`` names its test index ``faiss_index_test.index``,
    which does not start with the default ``'test_'`` prefix — confirm the
    prefix callers should pass.

    Args:
        faiss_dir: Directory to scan.
        test_prefix: Filename prefix that marks a file for deletion.

    Returns:
        List of the paths that were removed (empty when nothing matched).
    """
    if not os.path.isdir(faiss_dir):
        return []

    # Collect targets first so the directory is not modified while scanning.
    with os.scandir(faiss_dir) as entries:
        targets = [
            entry.path
            for entry in entries
            if entry.name.startswith(test_prefix) and entry.is_file()
        ]

    removed = []
    for file_path in targets:
        os.remove(file_path)
        logger.info(f"Removed test FAISS index file: {file_path}")
        removed.append(file_path)
    return removed
def main():
    """Build the chatbot's retrieval/training artifacts and write them to disk.

    Steps, in order. Any failure in a step is logged and the process exits
    with status 1:

    1. Create the output directories.
    2. Load the tokenizer for ``config.pretrained_model`` and add the
       ``<EMPTY_NEGATIVE>`` special token.
    3. Build the ``EncoderModel`` and resize its token embeddings to the
       extended tokenizer's vocabulary size.
    4. Load the JSON dialogues and the pickled query-embedding cache (an
       empty dict when no cache file exists yet).
    5. Build a ``TFDataPipeline``, collect assistant responses with
       ``collect_responses``, then embed them and add them to the FAISS index.
    6. Write the FAISS index and, next to it, the response pool as JSON.
    7. Write the training data as a TFRecord file.
    8. Persist the query-embedding cache and the tokenizer.
    """
    # Constants
    MODELS_DIR = 'models'
    PROCESSED_DATA_DIR = 'processed_outputs'
    CACHE_DIR = 'cache'
    TOKENIZER_DIR = os.path.join(MODELS_DIR, 'tokenizer')
    FAISS_INDICES_DIR = os.path.join(MODELS_DIR, 'faiss_indices')
    TF_RECORD_DIR = 'training_data'
    FAISS_INDEX_PRODUCTION_PATH = os.path.join(FAISS_INDICES_DIR, 'faiss_index_production.index')
    FAISS_INDEX_TEST_PATH = os.path.join(FAISS_INDICES_DIR, 'faiss_index_test.index')
    ENVIRONMENT = 'production'  # or 'test'
    FAISS_INDEX_PATH = FAISS_INDEX_TEST_PATH if ENVIRONMENT == 'test' else FAISS_INDEX_PRODUCTION_PATH
    # Response texts live next to the index as "<index stem>_responses.json".
    # splitext strips only the final extension; str.replace('.index', ...)
    # would also rewrite any '.index' occurring earlier in the path.
    RESPONSE_POOL_PATH = os.path.splitext(FAISS_INDEX_PATH)[0] + '_responses.json'
    JSON_TRAINING_DATA_PATH = os.path.join(PROCESSED_DATA_DIR, 'augmented_dialogues.json')
    CACHE_FILE = os.path.join(CACHE_DIR, 'query_embeddings_cache.pkl')
    TF_RECORD_PATH = os.path.join(TF_RECORD_DIR, 'training_data.tfrecord')
    # Forwarded to load_json_training_data; None presumably means "no limit" — confirm.
    DEBUG_SAMPLES = None

    # Ensure output directories exist
    for directory in (MODELS_DIR, PROCESSED_DATA_DIR, CACHE_DIR,
                      TOKENIZER_DIR, FAISS_INDICES_DIR, TF_RECORD_DIR):
        os.makedirs(directory, exist_ok=True)

    # Initialize configuration
    config = ChatbotConfig()
    logger.info(f"Chatbot Configuration: {config}")

    # Initialize tokenizer and add special tokens
    try:
        tokenizer = AutoTokenizer.from_pretrained(config.pretrained_model)
        logger.info(f"Tokenizer '{config.pretrained_model}' loaded successfully.")
        tokenizer.add_special_tokens({'additional_special_tokens': ['<EMPTY_NEGATIVE>']})
        logger.info("Added special tokens to tokenizer.")
    except Exception as e:
        logger.error(f"Failed to load tokenizer: {e}")
        sys.exit(1)

    # Initialize encoder model and resize token embeddings so the new special
    # token added above has an embedding row.
    try:
        encoder = EncoderModel(config=config)
        logger.info("EncoderModel initialized successfully.")
        encoder.pretrained.resize_token_embeddings(len(tokenizer))
        logger.info(f"Token embeddings resized to: {len(tokenizer)}")
    except Exception as e:
        logger.error(f"Failed to initialize EncoderModel: {e}")
        sys.exit(1)

    # Load JSON dialogues
    try:
        dialogues = TFDataPipeline.load_json_training_data(JSON_TRAINING_DATA_PATH, DEBUG_SAMPLES)
        logger.info(f"Loaded {len(dialogues)} dialogues from {JSON_TRAINING_DATA_PATH}.")
    except Exception as e:
        logger.error(f"Failed to load dialogues: {e}")
        sys.exit(1)

    # Load or initialize query_embeddings_cache.
    # NOTE: pickle.load can execute arbitrary code embedded in the file; only
    # load caches written by this script, never files from untrusted sources.
    try:
        if os.path.exists(CACHE_FILE):
            with open(CACHE_FILE, 'rb') as f:
                query_embeddings_cache = pickle.load(f)
            logger.info(f"Loaded {len(query_embeddings_cache)} query embeddings from {CACHE_FILE}.")
        else:
            query_embeddings_cache = {}
            logger.info("Initialized empty query embeddings cache.")
    except Exception as e:
        logger.error(f"Failed to load or initialize query embeddings cache: {e}")
        sys.exit(1)

    # Initialize TFDataPipeline. The response pool starts empty and is filled
    # in the next step, once the pipeline can collect responses.
    # NOTE(review): nlist is presumably only meaningful for IVF index types and
    # may be ignored with 'IndexFlatIP' — confirm in TFDataPipeline.
    try:
        data_pipeline = TFDataPipeline(
            config=config,
            tokenizer=tokenizer,
            encoder=encoder,
            index_file_path=FAISS_INDEX_PATH,
            response_pool=[],
            max_length=config.max_context_token_limit,
            neg_samples=config.neg_samples,
            query_embeddings_cache=query_embeddings_cache,
            index_type='IndexFlatIP',
            nlist=100,
            max_retries=config.max_retries,
        )
        logger.info("TFDataPipeline initialized successfully.")
    except Exception as e:
        logger.error(f"Failed to initialize TFDataPipeline: {e}")
        sys.exit(1)

    # Collect unique assistant responses from dialogues
    try:
        response_pool = data_pipeline.collect_responses(dialogues)
        data_pipeline.response_pool = response_pool
        logger.info(f"Collected {len(response_pool)} unique assistant responses from dialogues.")
    except Exception as e:
        logger.error(f"Failed to collect responses: {e}")
        sys.exit(1)

    # Compute and add response embeddings to FAISS index
    try:
        logger.info("Computing and adding response embeddings to FAISS index...")
        data_pipeline.compute_and_index_response_embeddings()
        logger.info("Response embeddings computed and added to FAISS index.")
    except Exception as e:
        logger.error(f"Failed to compute or add response embeddings: {e}")
        sys.exit(1)

    # Save FAISS index
    try:
        logger.info(f"Saving FAISS index to {FAISS_INDEX_PATH}...")
        faiss.write_index(data_pipeline.index, FAISS_INDEX_PATH)
        logger.info("FAISS index saved successfully.")
    except Exception as e:
        logger.error(f"Failed to save FAISS index: {e}")
        sys.exit(1)

    # Save response pool next to the index (separate handler so a JSON write
    # failure is not misreported as a FAISS failure).
    try:
        with open(RESPONSE_POOL_PATH, 'w', encoding='utf-8') as f:
            json.dump(data_pipeline.response_pool, f, indent=2)
        logger.info(f"Response pool saved to {RESPONSE_POOL_PATH}.")
    except Exception as e:
        logger.error(f"Failed to save response pool: {e}")
        sys.exit(1)

    # Prepare and save training data as TFRecords
    try:
        logger.info("Starting data preparation and saving as TFRecord...")
        data_pipeline.prepare_and_save_data(dialogues, TF_RECORD_PATH)
        logger.info(f"Data saved as TFRecord at {TF_RECORD_PATH}.")
    except Exception as e:
        logger.error(f"Failed during data preparation and saving: {e}")
        sys.exit(1)

    # Save query embeddings cache. Write to a temporary file and rename it into
    # place: os.replace is atomic on POSIX, so an interrupted run can no longer
    # leave a truncated pickle that makes the next run's pickle.load fail.
    try:
        tmp_cache_file = CACHE_FILE + '.tmp'
        with open(tmp_cache_file, 'wb') as f:
            pickle.dump(data_pipeline.query_embeddings_cache, f)
        os.replace(tmp_cache_file, CACHE_FILE)
        logger.info(f"Saved {len(data_pipeline.query_embeddings_cache)} query embeddings to {CACHE_FILE}.")
    except Exception as e:
        logger.error(f"Failed to save query embeddings cache: {e}")
        sys.exit(1)

    # Save Tokenizer (including special tokens)
    try:
        tokenizer.save_pretrained(TOKENIZER_DIR)
        logger.info(f"Tokenizer saved to {TOKENIZER_DIR}.")
    except Exception as e:
        logger.error(f"Failed to save tokenizer: {e}")
        sys.exit(1)

    logger.info("Data preparation pipeline completed successfully.")


if __name__ == "__main__":
    main()