# (removed non-Python rendering artifacts: file-size header, commit-hash gutter, line-number ruler)
import tensorflow as tf
from chatbot_model import RetrievalChatbot, ChatbotConfig
from environment_setup import EnvironmentSetup
from response_quality_checker import ResponseQualityChecker
from chatbot_validator import ChatbotValidator
from training_plotter import TrainingPlotter
# Configure logging
from logger_config import config_logger
logger = config_logger(__name__)
def run_interactive_chat(chatbot, quality_checker):
    """Run a REPL-style chat loop until the user quits.

    Args:
        chatbot: Object exposing ``chat(query, conversation_history,
            quality_checker, top_k)`` returning ``(response, candidates,
            metrics)``; ``candidates`` is assumed to be ``(response, score)``
            pairs sorted best-first — TODO confirm against chatbot_model.
        quality_checker: Passed through unchanged to ``chatbot.chat``.
    """
    while True:
        try:
            user_input = input("You: ")
        except (EOFError, KeyboardInterrupt):
            # Treat Ctrl-D / Ctrl-C like an explicit quit instead of
            # crashing out of the session with a traceback.
            print("\nAssistant: Goodbye!")
            break
        if user_input.lower() in ['quit', 'exit', 'bye']:
            print("Assistant: Goodbye!")
            break
        response, candidates, metrics = chatbot.chat(
            query=user_input,
            conversation_history=None,
            quality_checker=quality_checker,
            top_k=5
        )
        print(f"Assistant: {response}")
        # Alternatives are only surfaced when the checker marked the top
        # answer as confident (absent key defaults to not-confident).
        if metrics.get('is_confident', False):
            print("\nAlternative responses:")
            # Skip candidates[0] (already printed); show up to 3 runners-up.
            for resp, score in candidates[1:4]:
                print(f"Score: {score:.4f} - {resp}")
def inspect_tfrecord(tfrecord_file_path, num_examples=3, max_length=512, neg_samples=3):
    """Print the first few parsed examples from a training TFRecord file.

    Args:
        tfrecord_file_path: Path to the TFRecord file to inspect.
        num_examples: How many examples to dump.
        max_length: Fixed token length used when the records were written
            (previously hard-coded to 512).
        neg_samples: Negative samples per example (previously hard-coded to 3).
    """
    def parse_example(example_proto):
        # Schema must match the writer exactly: all features are fixed-length
        # int64 vectors, with negatives flattened to neg_samples * max_length.
        feature_description = {
            'query_ids': tf.io.FixedLenFeature([max_length], tf.int64),
            'positive_ids': tf.io.FixedLenFeature([max_length], tf.int64),
            'negative_ids': tf.io.FixedLenFeature([neg_samples * max_length], tf.int64),
        }
        return tf.io.parse_single_example(example_proto, feature_description)

    dataset = tf.data.TFRecordDataset(tfrecord_file_path)
    dataset = dataset.map(parse_example)
    for i, example in enumerate(dataset.take(num_examples)):
        print(f"Example {i+1}:")
        print(f"Query IDs: {example['query_ids'].numpy()}")
        print(f"Positive IDs: {example['positive_ids'].numpy()}")
        print(f"Negative IDs: {example['negative_ids'].numpy()}")
        print("-" * 50)
def main():
    """Train the retrieval chatbot end-to-end, validate it, plot metrics,
    then hand control to an interactive chat session."""
    # Quick sanity check of the training data (disabled by default):
    #inspect_tfrecord('training_data/training_data.tfrecord', num_examples=3)

    # Start from clean TF session state, then set up the environment.
    tf.keras.backend.clear_session()
    env = EnvironmentSetup()
    env.initialize()

    # Training hyperparameters.
    num_epochs = 20
    tfrecord_path = 'training_data/training_data.tfrecord'
    batch_size = env.optimize_batch_size(base_batch_size=16)  # sized for Colab

    config = ChatbotConfig(
        embedding_dim=768,  # DistilBERT hidden size
        max_context_token_limit=512,
        freeze_embeddings=False,
    )

    # NOTE(review): distribution-strategy scope intentionally left disabled.
    #with env.strategy.scope():
    chatbot = RetrievalChatbot(config, mode='training')
    chatbot.build_models()
    if chatbot.mode == 'preparation':
        chatbot.verify_faiss_index()

    chatbot.train_streaming(
        tfrecord_file_path=tfrecord_path,
        epochs=num_epochs,
        batch_size=batch_size,
        use_lr_schedule=True,
    )

    # Persist the trained weights.
    chatbot.save_models(env.training_dirs['base'] / 'final_model')

    # Automatic validation pass over a handful of examples.
    quality_checker = ResponseQualityChecker(chatbot=chatbot)
    validator = ChatbotValidator(chatbot, quality_checker)
    validation_metrics = validator.run_validation(num_examples=5)
    logger.info(f"Validation Metrics: {validation_metrics}")

    # Plot and save training history plus validation results.
    plotter = TrainingPlotter(save_dir=env.training_dirs['plots'])
    plotter.plot_training_history(chatbot.history)
    plotter.plot_validation_metrics(validation_metrics)

    # Interactive session runs last, once everything above succeeded.
    logger.info("\nStarting interactive chat session...")
    run_interactive_chat(chatbot, quality_checker)
# Script entry point. (Removed a stray trailing "|" rendering artifact
# that made the original line a syntax error.)
if __name__ == "__main__":
    main()