File size: 2,744 Bytes
9decf80
f7b283c
 
 
 
 
 
 
9b5daff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f7b283c
9b5daff
 
5b413d1
9b5daff
f7b283c
9decf80
f7b283c
 
 
9b5daff
 
f5346f7
f7b283c
 
5b413d1
9b5daff
5b413d1
 
f7b283c
9b5daff
 
 
5b413d1
 
 
 
 
 
 
 
 
 
 
f5346f7
f7b283c
 
9decf80
5b413d1
 
f7b283c
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import tensorflow as tf
from chatbot_model import RetrievalChatbot, ChatbotConfig
from environment_setup import EnvironmentSetup
from training_plotter import TrainingPlotter

from logger_config import config_logger
logger = config_logger(__name__)

def inspect_tfrecord(tfrecord_file_path, num_examples=3, max_length=512, neg_samples=3):
    """Print the first few parsed examples from a training TFRecord file.

    Useful as a quick sanity check that the serialized data matches the
    shapes the trainer expects.

    Args:
        tfrecord_file_path: Path to the TFRecord file to inspect.
        num_examples: How many examples to print before stopping.
        max_length: Token length of each query/positive sequence. Must match
            the max_length used when the records were written.
        neg_samples: Number of negative samples stored per example. Must match
            the value used when the records were written.
    """
    # Built once, outside the per-record parse function: the schema is
    # invariant across records.
    feature_description = {
        'query_ids': tf.io.FixedLenFeature([max_length], tf.int64),
        'positive_ids': tf.io.FixedLenFeature([max_length], tf.int64),
        'negative_ids': tf.io.FixedLenFeature([neg_samples * max_length], tf.int64),
    }

    def parse_example(example_proto):
        # Decode one serialized tf.train.Example into fixed-length int64 tensors.
        return tf.io.parse_single_example(example_proto, feature_description)

    dataset = tf.data.TFRecordDataset(tfrecord_file_path)
    dataset = dataset.map(parse_example)

    for i, example in enumerate(dataset.take(num_examples)):
        print(f"Example {i+1}:")
        print(f"Query IDs: {example['query_ids'].numpy()}")
        print(f"Positive IDs: {example['positive_ids'].numpy()}")
        print(f"Negative IDs: {example['negative_ids'].numpy()}")
        print("-" * 50)
        
def main():
    """Train the retrieval chatbot end-to-end: set up the environment,
    resume from the latest checkpoint if present, train, save, and plot."""
    # Quick test to inspect TFRecord:
    # inspect_tfrecord('training_data/training_data.tfrecord', num_examples=3)

    # Start from a clean TF graph, then prepare the training environment.
    tf.keras.backend.clear_session()
    environment = EnvironmentSetup()
    environment.initialize()

    # Training configuration.
    epochs = 20
    tfrecord_path = 'training_data/training_data.tfrecord'
    batch_size = 32  # env.optimize_batch_size(base_batch_size=16)

    # Build the chatbot in training mode.
    chatbot = RetrievalChatbot(ChatbotConfig(), mode='training')

    # Resume from the newest checkpoint, if any; the checkpoint number
    # doubles as the epoch to resume from.
    initial_epoch = 0
    latest_checkpoint = tf.train.latest_checkpoint('checkpoints/')
    if latest_checkpoint:
        initial_epoch = int(latest_checkpoint.split('ckpt-')[-1])
        logger.info(f"Found checkpoint {latest_checkpoint}, resuming from epoch {initial_epoch}")

    # Run training.
    chatbot.train_model(
        tfrecord_file_path=tfrecord_path,
        epochs=epochs,
        batch_size=batch_size,
        use_lr_schedule=True,
        test_mode=False,
        initial_epoch=initial_epoch
    )

    # Persist the final model.
    chatbot.save_models(environment.training_dirs['base'] / 'final_model')

    # Plot and save the training history.
    plotter = TrainingPlotter(save_dir=environment.training_dirs['plots'])
    plotter.plot_training_history(chatbot.history)

# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()