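"""Training entry point for the retrieval chatbot.

Builds a RetrievalChatbot, resumes from the latest checkpoint if one exists,
trains on a TFRecord dataset, then saves the final model and plots the
training history.
"""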
import tensorflow as tf
from chatbot_model import RetrievalChatbot, ChatbotConfig
from environment_setup import EnvironmentSetup
from plotter import Plotter

from logger_config import config_logger
logger = config_logger(__name__)

def inspect_tfrecord(tfrecord_file_path, num_examples=3, max_length=512, neg_samples=3):
    """Print the first few parsed examples from a TFRecord file as a sanity check.

    max_length and neg_samples must match the values used when the records were written.
    """
    def parse_example(example_proto):
        feature_description = {
            'query_ids': tf.io.FixedLenFeature([max_length], tf.int64),
            'positive_ids': tf.io.FixedLenFeature([max_length], tf.int64),
            'negative_ids': tf.io.FixedLenFeature([neg_samples * max_length], tf.int64),
        }
        return tf.io.parse_single_example(example_proto, feature_description)

    dataset = tf.data.TFRecordDataset(tfrecord_file_path)
    dataset = dataset.map(parse_example)

    for i, example in enumerate(dataset.take(num_examples)):
        print(f"Example {i+1}:")
        print(f"Query IDs: {example['query_ids'].numpy()}")
        print(f"Positive IDs: {example['positive_ids'].numpy()}")
        print(f"Negative IDs: {example['negative_ids'].numpy()}")
        print("-" * 50)
        
def main():
    
    # Quick test to inspect TFRecord
    # inspect_tfrecord('training_data/training_data.tfrecord', num_examples=3)
    
    # Initialize environment
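    # clear any leftover Keras graph/session state so repeated runs start fresh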
    tf.keras.backend.clear_session()
    env = EnvironmentSetup()
    env.initialize()
    
    # Training configuration
    EPOCHS = 20
    TF_RECORD_FILE_PATH = 'training_data/training_data.tfrecord'
    CHECKPOINT_DIR = 'checkpoints/'
    # Batch size is fixed for now; re-enable the helper below to auto-tune it for Colab memory
    batch_size = 32  # env.optimize_batch_size(base_batch_size=16)
    
    # Initialize config
    config = ChatbotConfig()
    
    # Initialize chatbot
    chatbot = RetrievalChatbot(config, mode='training')
    
    # Check for existing checkpoint and get initial epoch
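    # NOTE: assumes the checkpoint suffix (ckpt-N) corresponds to the number of completed epochs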
    latest_checkpoint = tf.train.latest_checkpoint(CHECKPOINT_DIR)
    initial_epoch = 0
    if latest_checkpoint:
        try:
            ckpt_number = int(latest_checkpoint.split('ckpt-')[-1])
            initial_epoch = ckpt_number
            logger.info(f"Found checkpoint {latest_checkpoint}, resuming from epoch {initial_epoch}")
        except ValueError:  # only int() can fail here; split()[-1] always returns a string
            logger.error(f"Failed to parse checkpoint number from {latest_checkpoint}")
            initial_epoch = 0
    
    # Train the model
    chatbot.train_model(
        tfrecord_file_path=TF_RECORD_FILE_PATH,
        epochs=EPOCHS,
        batch_size=batch_size,
        use_lr_schedule=True,
        test_mode=True,
        checkpoint_dir=CHECKPOINT_DIR,
        initial_epoch=initial_epoch
    )
    
    # Save final model
    model_save_path = env.training_dirs['base'] / 'final_model'
    chatbot.save_models(model_save_path)
    
    # Plot and save training history
    plotter = Plotter(save_dir=env.training_dirs['plots'])
    plotter.plot_training_history(chatbot.history)

if __name__ == "__main__":
    main()