|
import tensorflow as tf |
|
from chatbot_model import RetrievalChatbot, ChatbotConfig |
|
from environment_setup import EnvironmentSetup |
|
from plotter import Plotter |
|
|
|
from logger_config import config_logger |
|
# Module-level logger obtained from the project's config_logger helper
# (its handler/format configuration lives in logger_config, not shown here).
logger = config_logger(__name__)
|
|
|
def inspect_tfrecord(tfrecord_file_path, num_examples=3, *, max_length=512, neg_samples=3):
    """Print the first few parsed examples of a TFRecord file for manual inspection.

    Each record is parsed with a fixed-length schema of int64 features:
    ``query_ids`` and ``positive_ids`` hold ``max_length`` ids each, and
    ``negative_ids`` holds ``neg_samples * max_length`` ids stored as one
    flat vector.

    Args:
        tfrecord_file_path: Path to the TFRecord file to read.
        num_examples: Number of examples to print. This is passed to
            ``Dataset.take``, which treats -1 as "all records".
        max_length: Ids per sequence (default 512). This must match the
            length the records were written with, because FixedLenFeature
            parsing fails on a size mismatch.
        neg_samples: Number of negative sequences packed into
            ``negative_ids`` (default 3).
    """
    # Built once and shared by every parse call instead of being rebuilt per record.
    feature_description = {
        'query_ids': tf.io.FixedLenFeature([max_length], tf.int64),
        'positive_ids': tf.io.FixedLenFeature([max_length], tf.int64),
        # Negatives are serialized flat: neg_samples sequences back to back.
        'negative_ids': tf.io.FixedLenFeature([neg_samples * max_length], tf.int64),
    }

    def parse_example(example_proto):
        """Decode one serialized tf.train.Example with the schema above."""
        return tf.io.parse_single_example(example_proto, feature_description)

    # Apply take() before map() so only the records we print get parsed.
    dataset = tf.data.TFRecordDataset(tfrecord_file_path).take(num_examples).map(parse_example)

    for i, example in enumerate(dataset):
        print(f"Example {i+1}:")
        print(f"Query IDs: {example['query_ids'].numpy()}")
        print(f"Positive IDs: {example['positive_ids'].numpy()}")
        print(f"Negative IDs: {example['negative_ids'].numpy()}")
        print("-" * 50)
|
|
|
def _resolve_initial_epoch(checkpoint_dir):
    """Return the epoch to resume from, based on the latest checkpoint in checkpoint_dir.

    The epoch is read as the integer following the last ``ckpt-`` in the
    checkpoint path. The function returns 0 (start from scratch) when no
    checkpoint exists, and also when that suffix is not an integer; the
    latter case is logged as an error.
    """
    # tf.train.latest_checkpoint returns None when the directory has no checkpoint.
    latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
    if not latest_checkpoint:
        return 0

    # str.split always yields at least one element, so [-1] cannot raise
    # IndexError; only int() can fail here.
    try:
        ckpt_number = int(latest_checkpoint.split('ckpt-')[-1])
    except ValueError:
        logger.error("Failed to parse checkpoint number from %s", latest_checkpoint)
        return 0

    logger.info("Found checkpoint %s, resuming from epoch %s", latest_checkpoint, ckpt_number)
    return ckpt_number


def main():
    """Train the retrieval chatbot end to end.

    Steps:
    1. Reset the Keras session and initialize the environment.
    2. Resume from the latest checkpoint in ``checkpoints/``, if one exists.
    3. Train, then save the final models and plot the training history into
       the directories provided by ``EnvironmentSetup``.
    """
    tf.keras.backend.clear_session()

    env = EnvironmentSetup()
    env.initialize()

    # Training hyperparameters and paths.
    EPOCHS = 20
    TF_RECORD_FILE_PATH = 'training_data/training_data.tfrecord'
    CHECKPOINT_DIR = 'checkpoints/'
    batch_size = 32

    config = ChatbotConfig()
    chatbot = RetrievalChatbot(config, mode='training')

    initial_epoch = _resolve_initial_epoch(CHECKPOINT_DIR)

    # NOTE(review): test_mode=True is hard-coded for what looks like a full
    # 20-epoch run; confirm this is intended.
    chatbot.train_model(
        tfrecord_file_path=TF_RECORD_FILE_PATH,
        epochs=EPOCHS,
        batch_size=batch_size,
        use_lr_schedule=True,
        test_mode=True,
        checkpoint_dir=CHECKPOINT_DIR,
        initial_epoch=initial_epoch,
    )

    # training_dirs values support '/', presumably pathlib.Path objects; verify in EnvironmentSetup.
    model_save_path = env.training_dirs['base'] / 'final_model'
    chatbot.save_models(model_save_path)

    plotter = Plotter(save_dir=env.training_dirs['plots'])
    plotter.plot_training_history(chatbot.history)
|
|
|
# Run the training pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()