import tensorflow as tf
from chatbot_model import RetrievalChatbot, ChatbotConfig
from environment_setup import EnvironmentSetup
from training_plotter import TrainingPlotter
from logger_config import config_logger
logger = config_logger(__name__)
def inspect_tfrecord(tfrecord_file_path, num_examples=3):
    """Print the first `num_examples` parsed records from a TFRecord file."""
    def parse_example(example_proto):
        feature_description = {
            'query_ids': tf.io.FixedLenFeature([512], tf.int64),         # Adjust max_length if different
            'positive_ids': tf.io.FixedLenFeature([512], tf.int64),
            'negative_ids': tf.io.FixedLenFeature([3 * 512], tf.int64),  # Adjust neg_samples if different
        }
        return tf.io.parse_single_example(example_proto, feature_description)

    dataset = tf.data.TFRecordDataset(tfrecord_file_path)
    dataset = dataset.map(parse_example)
    for i, example in enumerate(dataset.take(num_examples)):
        print(f"Example {i + 1}:")
        print(f"Query IDs: {example['query_ids'].numpy()}")
        print(f"Positive IDs: {example['positive_ids'].numpy()}")
        print(f"Negative IDs: {example['negative_ids'].numpy()}")
        print("-" * 50)

def main():
    # Quick test to inspect the TFRecord before training:
    # inspect_tfrecord('training_data/training_data.tfrecord', num_examples=3)

    # Initialize the environment (clear any stale Keras state first)
    tf.keras.backend.clear_session()
    env = EnvironmentSetup()
    env.initialize()

    # Training configuration
    EPOCHS = 20
    TF_RECORD_FILE_PATH = 'training_data/training_data.tfrecord'

    # Fixed batch size for Colab; the dynamic sizing call is left commented out
    batch_size = 32  # env.optimize_batch_size(base_batch_size=16)

    # Initialize config and chatbot
    config = ChatbotConfig()
    chatbot = RetrievalChatbot(config, mode='training')

    # Resume from the latest checkpoint, if one exists; the checkpoint number
    # is taken as the number of completed epochs
    checkpoint_dir = 'checkpoints/'
    latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
    initial_epoch = 0
    if latest_checkpoint:
        ckpt_number = int(latest_checkpoint.split('ckpt-')[-1])
        initial_epoch = ckpt_number
        logger.info(f"Found checkpoint {latest_checkpoint}, resuming from epoch {initial_epoch}")
    # Train the model
    chatbot.train_model(
        tfrecord_file_path=TF_RECORD_FILE_PATH,
        epochs=EPOCHS,
        batch_size=batch_size,
        use_lr_schedule=True,
        test_mode=False,
        initial_epoch=initial_epoch
    )
    # Save the final model
    model_save_path = env.training_dirs['base'] / 'final_model'
    chatbot.save_models(model_save_path)

    # Plot and save the training history
    plotter = TrainingPlotter(save_dir=env.training_dirs['plots'])
    plotter.plot_training_history(chatbot.history)

if __name__ == "__main__":
    main()
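
# Usage (hedged; the script filename below is illustrative):
#   python train.py
# Resuming is automatic: if 'checkpoints/' holds a ckpt-<N> file, training
# restarts from epoch N with the same TFRecord and batch size.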