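"""Entry point for training the RetrievalChatbot on TFRecord data.

Also includes a small helper for inspecting the contents of a training TFRecord file.
"""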
import tensorflow as tf

from chatbot_model import RetrievalChatbot, ChatbotConfig
from environment_setup import EnvironmentSetup
from plotter import Plotter
from logger_config import config_logger

logger = config_logger(__name__)

def inspect_tfrecord(tfrecord_file_path, num_examples=3):
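    """Print the parsed contents of the first num_examples records in a TFRecord file."""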
    def parse_example(example_proto):
        # The feature spec must match the shapes used when the TFRecord was written:
        # [max_length] for queries/positives, [neg_samples * max_length] for negatives.
        feature_description = {
            'query_ids': tf.io.FixedLenFeature([512], tf.int64),  # Adjust max_length if different
            'positive_ids': tf.io.FixedLenFeature([512], tf.int64),
            'negative_ids': tf.io.FixedLenFeature([3 * 512], tf.int64),  # Adjust neg_samples if different
        }
        return tf.io.parse_single_example(example_proto, feature_description)

    dataset = tf.data.TFRecordDataset(tfrecord_file_path)
    dataset = dataset.map(parse_example)

    for i, example in enumerate(dataset.take(num_examples)):
        print(f"Example {i+1}:")
        print(f"Query IDs: {example['query_ids'].numpy()}")
        print(f"Positive IDs: {example['positive_ids'].numpy()}")
        print(f"Negative IDs: {example['negative_ids'].numpy()}")
        print("-" * 50)
def main():
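    """Set up the environment, train the chatbot, and save the final model and training plots."""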
    # Quick test to inspect the TFRecord before training:
    # inspect_tfrecord('training_data/training_data.tfrecord', num_examples=3)

    # Initialize environment
    tf.keras.backend.clear_session()
    env = EnvironmentSetup()
    env.initialize()

    # Training configuration
    EPOCHS = 20
    TF_RECORD_FILE_PATH = 'training_data/training_data.tfrecord'
    CHECKPOINT_DIR = 'checkpoints/'

    # Batch size for Colab; hardcoded instead of env.optimize_batch_size(base_batch_size=16)
    batch_size = 32

    # Initialize config and chatbot
    config = ChatbotConfig()
    chatbot = RetrievalChatbot(config, mode='training')

    # Check for an existing checkpoint and derive the epoch to resume from.
    # Checkpoint files are named 'ckpt-<number>', where <number> tracks the epoch.
    latest_checkpoint = tf.train.latest_checkpoint(CHECKPOINT_DIR)
    initial_epoch = 0
    if latest_checkpoint:
        try:
            initial_epoch = int(latest_checkpoint.split('ckpt-')[-1])
            logger.info(f"Found checkpoint {latest_checkpoint}, resuming from epoch {initial_epoch}")
        except (IndexError, ValueError):
            logger.error(f"Failed to parse checkpoint number from {latest_checkpoint}")
            initial_epoch = 0

    # Train the model
    chatbot.train_model(
        tfrecord_file_path=TF_RECORD_FILE_PATH,
        epochs=EPOCHS,
        batch_size=batch_size,
        use_lr_schedule=True,
        test_mode=True,
        checkpoint_dir=CHECKPOINT_DIR,
        initial_epoch=initial_epoch
    )

    # Save the final model
    model_save_path = env.training_dirs['base'] / 'final_model'
    chatbot.save_models(model_save_path)

    # Plot and save the training history
    plotter = Plotter(save_dir=env.training_dirs['plots'])
    plotter.plot_training_history(chatbot.history)

if __name__ == "__main__":
    main()