File size: 4,454 Bytes
5b413d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import os
import json
from chatbot_model import ChatbotConfig, RetrievalChatbot
from response_quality_checker import ResponseQualityChecker
from chatbot_validator import ChatbotValidator
from training_plotter import TrainingPlotter
from environment_setup import EnvironmentSetup

from logger_config import config_logger
logger = config_logger(__name__)

def run_interactive_chat(chatbot, quality_checker, *, top_k=5, max_alternatives=3):
    """Run a blocking chat loop against *chatbot* using stdin/stdout.

    Each non-blank line the user types (with surrounding whitespace stripped)
    is sent to ``chatbot.chat`` without conversation history, and the returned
    response is printed. When the returned metrics mark the answer as
    confident, up to *max_alternatives* runner-up candidates are printed with
    their scores; otherwise a low-confidence hint is shown. The loop ends on
    ``quit``/``exit``/``bye`` (case-insensitive), Ctrl-C, or end of input.

    Args:
        chatbot: Object exposing ``chat(query=..., conversation_history=...,
            quality_checker=..., top_k=...)`` that returns
            ``(response, candidates, metrics)``; *candidates* is iterated as
            ``(response, score)`` pairs and *metrics* is read with ``.get``.
        quality_checker: Passed through unchanged to ``chatbot.chat``.
        top_k: Number of candidates requested from ``chatbot.chat``.
        max_alternatives: Maximum number of runner-up candidates to print.
    """
    exit_commands = {'quit', 'exit', 'bye'}
    while True:
        try:
            user_input = input("You: ")
        except (KeyboardInterrupt, EOFError):
            print("\nAssistant: Goodbye!")
            break

        query = user_input.strip()
        if not query:
            # Blank line: nothing to retrieve against, so prompt again instead
            # of sending an empty query to the model.
            continue

        if query.lower() in exit_commands:
            print("Assistant: Goodbye!")
            break

        response, candidates, metrics = chatbot.chat(
            query=query,
            conversation_history=None,
            quality_checker=quality_checker,
            top_k=top_k,
        )

        print(f"Assistant: {response}")

        if metrics.get('is_confident', False):
            # candidates[0] is presumably the response just printed, so the
            # alternatives start at index 1.
            alternatives = candidates[1:1 + max(max_alternatives, 0)]
            if alternatives:
                print("\nAlternative responses:")
                for resp, score in alternatives:
                    print(f"Score: {score:.4f} - {resp}")
        else:
            print("\n[Low Confidence]: Consider rephrasing your query for better assistance.")

def _resolve_index_paths(model_dir, environment):
    """Return ``(faiss_index_path, response_pool_path)`` for *environment*.

    The index path is ``<model_dir>/faiss_indices/faiss_index_<environment>.index``;
    the response pool sits beside it, named after the index with the ``.index``
    suffix swapped for ``_responses.json``.
    """
    faiss_indices_dir = os.path.join(model_dir, 'faiss_indices')
    faiss_index_path = os.path.join(faiss_indices_dir, f'faiss_index_{environment}.index')
    # splitext strips only the final suffix; str.replace('.index', ...) would also
    # rewrite any earlier '.index' that happens to occur inside model_dir.
    index_stem, _ = os.path.splitext(faiss_index_path)
    return faiss_index_path, f'{index_stem}_responses.json'


def validate_chatbot(*, environment='production', model_dir='models', num_examples=5):
    """Load the retrieval chatbot, validate it, plot the metrics, then start a chat.

    Stages run in order. A failure in any stage except plotting is logged
    (with traceback) and ends the function early, returning ``None``.
    Plotting is best-effort: its failure is logged and the chat still starts.

    1. Initialize the runtime environment (``EnvironmentSetup.initialize``).
    2. Check that the FAISS index and response-pool files for *environment* exist.
    3. Build a ``RetrievalChatbot`` in ``'inference'`` mode.
    4. Load the FAISS index and response pool into ``chatbot.data_pipeline``
       and call its ``validate_faiss_index``.
    5. Run ``ChatbotValidator.run_validation`` on *num_examples* examples.
    6. Plot the validation metrics into ``env.training_dirs['plots']``.
    7. Hand control to ``run_interactive_chat``.

    Args:
        environment: Which index/response pool to load: ``'production'`` or
            ``'test'``.
        model_dir: Directory containing the ``faiss_indices`` sub-directory.
        num_examples: Value passed as ``num_examples`` to ``run_validation``.

    Raises:
        ValueError: If *environment* is neither ``'production'`` nor ``'test'``.
    """
    if environment not in ('production', 'test'):
        raise ValueError(
            f"environment must be 'production' or 'test', got {environment!r}"
        )

    env = EnvironmentSetup()
    env.initialize()

    faiss_index_path, response_pool_path = _resolve_index_paths(model_dir, environment)

    # Check the inputs before building the model so a missing file fails fast
    # and the log names exactly which path is absent.
    missing = [p for p in (faiss_index_path, response_pool_path) if not os.path.exists(p)]
    if missing:
        logger.error(f"FAISS index or response pool file is missing: {', '.join(missing)}")
        return

    config = ChatbotConfig()
    try:
        chatbot = RetrievalChatbot(config=config, mode='inference')
    except Exception as e:
        # logger.exception also records the traceback, which logger.error dropped.
        logger.exception(f"Failed to initialize RetrievalChatbot: {e}")
        return
    logger.info("RetrievalChatbot initialized in 'inference' mode.")

    try:
        chatbot.data_pipeline.load_faiss_index(faiss_index_path)
        logger.info(f"FAISS index loaded from {faiss_index_path}.")

        with open(response_pool_path, 'r', encoding='utf-8') as f:
            chatbot.data_pipeline.response_pool = json.load(f)
        logger.info(f"Response pool loaded from {response_pool_path}.")

        chatbot.data_pipeline.validate_faiss_index()
        logger.info("FAISS index and response pool validated successfully.")
    except Exception as e:
        logger.exception(f"Failed to load or validate FAISS index / response pool: {e}")
        return

    quality_checker = ResponseQualityChecker(data_pipeline=chatbot.data_pipeline)
    validator = ChatbotValidator(chatbot=chatbot, quality_checker=quality_checker)
    logger.info("ResponseQualityChecker and ChatbotValidator initialized.")

    try:
        validation_metrics = validator.run_validation(num_examples=num_examples)
    except Exception as e:
        logger.exception(f"Validation process failed: {e}")
        return
    logger.info(f"Validation Metrics: {validation_metrics}")

    # Plotting is best-effort: a failure here should not block the chat session.
    try:
        plotter = TrainingPlotter(save_dir=env.training_dirs['plots'])
        plotter.plot_validation_metrics(validation_metrics)
        logger.info("Validation metrics plotted successfully.")
    except Exception as e:
        logger.exception(f"Failed to plot validation metrics: {e}")

    logger.info("\nStarting interactive chat session...")
    run_interactive_chat(chatbot, quality_checker)

if __name__ == "__main__":
    # Script entry point: only runs when this file is executed directly, so
    # importing the module does not start validation or the chat loop.
    validate_chatbot()