File size: 2,974 Bytes
9decf80
f7b283c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9decf80
f7b283c
 
 
2183656
9decf80
f7b283c
 
 
 
 
 
 
2183656
f7b283c
 
 
 
 
 
2183656
f7b283c
 
9decf80
 
 
f7b283c
9decf80
 
 
f7b283c
 
9decf80
f7b283c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import tensorflow as tf
from chatbot_model import RetrievalChatbot, ChatbotConfig
from environment_setup import EnvironmentSetup
from response_quality_checker import ResponseQualityChecker
from chatbot_validator import ChatbotValidator
from training_plotter import TrainingPlotter

# Configure logging
from logger_config import config_logger
logger = config_logger(__name__)

def run_interactive_chat(chatbot, quality_checker):
    """Run a blocking REPL-style chat session against a trained chatbot.

    Args:
        chatbot: Object exposing ``chat(query, conversation_history,
            quality_checker, top_k)`` returning ``(response, candidates,
            metrics)``.
        quality_checker: Passed through to ``chatbot.chat`` for response
            confidence scoring.

    The loop exits on 'quit'/'exit'/'bye' (case-insensitive), or on
    Ctrl-C / Ctrl-D instead of crashing with a traceback.
    """
    while True:
        try:
            user_input = input("You: ").strip()
        except (KeyboardInterrupt, EOFError):
            # Treat Ctrl-C / Ctrl-D as a polite exit rather than a crash.
            print("\nAssistant: Goodbye!")
            break

        if not user_input:
            # Skip empty lines instead of querying the model with "".
            continue

        if user_input.lower() in ['quit', 'exit', 'bye']:
            print("Assistant: Goodbye!")
            break

        response, candidates, metrics = chatbot.chat(
            query=user_input,
            conversation_history=None,
            quality_checker=quality_checker,
            top_k=5
        )

        print(f"Assistant: {response}")

        # Only surface alternatives when the checker marked the top
        # response as confident (mirrors the original behavior).
        if metrics.get('is_confident', False):
            print("\nAlternative responses:")
            for resp, score in candidates[1:4]:
                print(f"Score: {score:.4f} - {resp}")

def main():
    """End-to-end training pipeline: environment setup, training,
    validation, plotting, then an interactive chat session.

    Side effects: clears the Keras session, trains and saves a model under
    ``env.training_dirs['base'] / 'final_model'``, writes plots, and blocks
    on interactive chat until the user quits.
    """
    # Initialize environment (clears any prior TF graph/session state).
    tf.keras.backend.clear_session()
    env = EnvironmentSetup()
    env.initialize()

    # DEBUG_SAMPLES truthy => short debug run; None/0 => full training.
    DEBUG_SAMPLES = 5
    EPOCHS = 5 if DEBUG_SAMPLES else 20
    TRAINING_DATA_PATH = 'processed_outputs/batch_group_0010.json'

    # Optimize batch size for Colab hardware constraints.
    batch_size = env.optimize_batch_size(base_batch_size=16)

    # Initialize configuration
    config = ChatbotConfig(
        embedding_dim=768,  # DistilBERT hidden size
        max_context_token_limit=512,
        freeze_embeddings=False,
    )

    # Load training data; log a summary instead of dumping the full corpus.
    dialogues = RetrievalChatbot.load_training_data(
        data_path=TRAINING_DATA_PATH,
        debug_samples=DEBUG_SAMPLES,
    )
    logger.info(f"Loaded {len(dialogues)} dialogues from {TRAINING_DATA_PATH}")

    # Initialize chatbot and verify FAISS index
    chatbot = RetrievalChatbot(config, dialogues)
    chatbot.build_models()
    chatbot.verify_faiss_index()

    chatbot.train_streaming(
        dialogues=dialogues,
        epochs=EPOCHS,
        batch_size=batch_size,
        use_lr_schedule=True,
    )

    # Save final model
    model_save_path = env.training_dirs['base'] / 'final_model'
    chatbot.save_models(model_save_path)

    # Run automatic validation
    quality_checker = ResponseQualityChecker(chatbot=chatbot)
    validator = ChatbotValidator(chatbot, quality_checker)
    validation_metrics = validator.run_validation(num_examples=5)
    logger.info(f"Validation Metrics: {validation_metrics}")

    # Plot and save training history
    plotter = TrainingPlotter(save_dir=env.training_dirs['plots'])
    plotter.plot_training_history(chatbot.history)
    plotter.plot_validation_metrics(validation_metrics)

    # Run interactive chat
    logger.info("\nStarting interactive chat session...")
    run_interactive_chat(chatbot, quality_checker)

if __name__ == "__main__":
    main()