import tensorflow as tf

from chatbot_model import RetrievalChatbot, ChatbotConfig
from environment_setup import EnvironmentSetup
from response_quality_checker import ResponseQualityChecker
from chatbot_validator import ChatbotValidator
from training_plotter import TrainingPlotter

# Configure logging
from logger_config import config_logger

logger = config_logger(__name__)


def run_interactive_chat(chatbot, quality_checker):
    """Run the interactive chat loop until the user quits."""
    while True:
        user_input = input("You: ")
        if user_input.lower() in ['quit', 'exit', 'bye']:
            print("Assistant: Goodbye!")
            break

        response, candidates, metrics = chatbot.chat(
            query=user_input,
            conversation_history=None,
            quality_checker=quality_checker,
            top_k=5
        )

        print(f"Assistant: {response}")

        # When retrieval is confident, also surface the next-best candidates.
        if metrics.get('is_confident', False):
            print("\nAlternative responses:")
            for resp, score in candidates[1:4]:
                print(f"Score: {score:.4f} - {resp}")


def main():
    # Initialize environment
    tf.keras.backend.clear_session()
    env = EnvironmentSetup()
    env.initialize()

    # Train on a small debug subset; set DEBUG_SAMPLES = 0 for a full run.
    DEBUG_SAMPLES = 5
    EPOCHS = 5 if DEBUG_SAMPLES else 20
    TRAINING_DATA_PATH = 'processed_outputs/batch_group_0010.json'

    # Optimize batch size for Colab
    batch_size = env.optimize_batch_size(base_batch_size=16)

    # Initialize configuration
    config = ChatbotConfig(
        embedding_dim=768,  # DistilBERT hidden size
        max_context_token_limit=512,
        freeze_embeddings=False,
    )

    # Load training data
    dialogues = RetrievalChatbot.load_training_data(
        data_path=TRAINING_DATA_PATH,
        debug_samples=DEBUG_SAMPLES
    )
    logger.info(f"Loaded {len(dialogues)} dialogues.")

    # Initialize chatbot and verify FAISS index.
    # Optional: build under env.strategy.scope() for multi-device training.
    # with env.strategy.scope():
    chatbot = RetrievalChatbot(config, dialogues)
    chatbot.build_models()
    chatbot.verify_faiss_index()

    chatbot.train_streaming(
        dialogues=dialogues,
        epochs=EPOCHS,
        batch_size=batch_size,
        use_lr_schedule=True,
    )

    # Save final model
    model_save_path = env.training_dirs['base'] / 'final_model'
    chatbot.save_models(model_save_path)

    # Run automatic validation
    quality_checker = ResponseQualityChecker(chatbot=chatbot)
    validator = ChatbotValidator(chatbot, quality_checker)
    validation_metrics = validator.run_validation(num_examples=5)
    logger.info(f"Validation Metrics: {validation_metrics}")

    # Plot and save training history
    plotter = TrainingPlotter(save_dir=env.training_dirs['plots'])
    plotter.plot_training_history(chatbot.history)
    plotter.plot_validation_metrics(validation_metrics)

    # Run interactive chat
    logger.info("\nStarting interactive chat session...")
    run_interactive_chat(chatbot, quality_checker)


if __name__ == "__main__":
    main()
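
# A minimal sketch of a stateful variant of run_interactive_chat(). The loop
# above passes conversation_history=None, so each turn is answered in
# isolation; the version below threads the accumulated history through
# instead. That chatbot.chat() accepts the history as a list of
# (speaker, text) tuples is an assumption, not something this script confirms.
def run_interactive_chat_with_history(chatbot, quality_checker):
    """Like run_interactive_chat, but accumulates conversation history."""
    history = []  # assumed format: [(speaker, text), ...]
    while True:
        user_input = input("You: ")
        if user_input.lower() in ['quit', 'exit', 'bye']:
            print("Assistant: Goodbye!")
            break

        response, _, _ = chatbot.chat(
            query=user_input,
            conversation_history=history,
            quality_checker=quality_checker,
            top_k=5
        )
        print(f"Assistant: {response}")

        # Record both sides of the turn so the next retrieval sees them.
        history.append(('user', user_input))
        history.append(('assistant', response))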