csc525_retrieval_based_chatbot / run_model_train.py
JoeArmani
FAISS and streaming updates
9decf80
raw
history blame
2.98 kB
import tensorflow as tf
from chatbot_model import RetrievalChatbot, ChatbotConfig
from environment_setup import EnvironmentSetup
from response_quality_checker import ResponseQualityChecker
from chatbot_validator import ChatbotValidator
from training_plotter import TrainingPlotter
# Configure logging: build the module-level logger via the project's
# config_logger helper, keyed by this module's name.
from logger_config import config_logger
logger = config_logger(__name__)
def run_interactive_chat(chatbot, quality_checker):
    """Run a blocking console chat loop against the given chatbot.

    Each turn reads one line from standard input, passes it to
    ``chatbot.chat`` (with no conversation history and ``top_k=5``) and
    prints the selected response. When the returned metrics mark the answer
    as confident, up to three runner-up candidates are printed with their
    scores.

    The loop ends when the user types ``quit``, ``exit`` or ``bye``
    (case-insensitive, surrounding whitespace ignored), or when input ends
    (EOF on stdin or Ctrl-C at the prompt). In both cases a goodbye line is
    printed and the function returns normally instead of raising.

    Args:
        chatbot: Object exposing ``chat(query=..., conversation_history=...,
            quality_checker=..., top_k=...)`` that returns a
            ``(response, candidates, metrics)`` triple, where ``candidates``
            is a sequence of ``(text, score)`` pairs and ``metrics`` is a
            dict-like object.
        quality_checker: Passed through unchanged to ``chatbot.chat``.
    """
    exit_commands = {'quit', 'exit', 'bye'}

    while True:
        try:
            user_input = input("You: ")
        except (EOFError, KeyboardInterrupt):
            # stdin was closed/piped dry, or the user pressed Ctrl-C at the
            # prompt: finish the session cleanly rather than with a traceback.
            print("\nAssistant: Goodbye!")
            break

        # Tolerate stray whitespace around the exit command (e.g. "quit ").
        if user_input.strip().lower() in exit_commands:
            print("Assistant: Goodbye!")
            break

        response, candidates, metrics = chatbot.chat(
            query=user_input,
            conversation_history=None,
            quality_checker=quality_checker,
            top_k=5,
        )
        print(f"Assistant: {response}")

        if metrics.get('is_confident', False):
            print("\nAlternative responses:")
            # candidates[0] is skipped here; show at most the next three.
            for resp, score in candidates[1:4]:
                print(f"Score: {score:.4f} - {resp}")
def main():
    """Train, save, validate and plot the retrieval chatbot, then start a chat.

    Steps, in order:
      1. Reset Keras state and initialize the runtime environment.
      2. Ask the environment for a batch size and build the chatbot config.
      3. Load dialogues from the hard-coded training file (passing the debug
         sample count) and construct/build the chatbot models.
      4. Verify the FAISS index and run streaming training.
      5. Save the models under the environment's base training directory.
      6. Run automatic validation and log the resulting metrics.
      7. Plot the training history and validation metrics.
      8. Hand control to the interactive console chat loop.
    """
    # --- Environment -------------------------------------------------------
    tf.keras.backend.clear_session()
    env = EnvironmentSetup()
    env.initialize()

    # --- Run parameters ----------------------------------------------------
    debug_samples = 15
    # Fewer epochs whenever a (truthy) debug sample count is in effect.
    if debug_samples:
        epochs = 5
    else:
        epochs = 20
    training_data_path = 'processed_outputs/batch_group_0010.json'

    # Let the environment adjust the batch size from a base of 16.
    batch_size = env.optimize_batch_size(base_batch_size=16)

    # --- Model configuration -----------------------------------------------
    config = ChatbotConfig(
        embedding_dim=512,  # alternative noted by the author: 768, to match DistilBERT's dimension
        max_context_token_limit=512,
        freeze_embeddings=False,
    )

    # --- Data and model construction ---------------------------------------
    dialogues = RetrievalChatbot.load_training_data(
        data_path=training_data_path,
        debug_samples=debug_samples,
    )

    # NOTE: the chatbot is built outside any env.strategy.scope(); that
    # wrapper was left disabled in the original script.
    chatbot = RetrievalChatbot(config, dialogues)
    chatbot.build_models()
    chatbot.verify_faiss_index()

    # --- Training ----------------------------------------------------------
    chatbot.train_streaming(
        dialogues=dialogues,
        epochs=epochs,
        batch_size=batch_size,
        use_lr_schedule=True,
    )

    # --- Persist the trained models ----------------------------------------
    chatbot.save_models(env.training_dirs['base'] / 'final_model')

    # --- Automatic validation ----------------------------------------------
    quality_checker = ResponseQualityChecker(chatbot=chatbot)
    validation_metrics = ChatbotValidator(chatbot, quality_checker).run_validation(
        num_examples=5
    )
    logger.info(f"Validation Metrics: {validation_metrics}")

    # --- Plots ---------------------------------------------------------------
    plotter = TrainingPlotter(save_dir=env.training_dirs['plots'])
    plotter.plot_training_history(chatbot.history)
    plotter.plot_validation_metrics(validation_metrics)

    # --- Interactive session -----------------------------------------------
    logger.info("\nStarting interactive chat session...")
    run_interactive_chat(chatbot, quality_checker)
# Script entry point: run the pipeline only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()