File size: 4,555 Bytes
5b413d1
 
64e7c31
 
 
5b413d1
 
fc5f33b
5b413d1
 
64e7c31
7a0020b
5b413d1
64e7c31
d7fc7a7
5b413d1
 
 
64e7c31
4aec49f
7a0020b
 
 
64e7c31
7a0020b
 
 
5b413d1
7a0020b
5b413d1
 
7a0020b
64e7c31
7a0020b
 
 
 
 
 
 
 
 
 
64e7c31
 
5b413d1
c7c1b4e
 
5b413d1
64e7c31
5b413d1
64e7c31
cc2577d
5b413d1
64e7c31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5b413d1
64e7c31
7a0020b
64e7c31
cc2577d
64e7c31
 
7a0020b
64e7c31
5b413d1
 
7a0020b
5b413d1
64e7c31
7a0020b
5b413d1
64e7c31
 
 
 
 
 
5b413d1
 
 
 
 
64e7c31
7a0020b
a763857
 
 
 
 
 
64e7c31
7a0020b
64e7c31
 
c7c1b4e
 
 
 
 
7a0020b
64e7c31
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import os
import json
from sentence_transformers import SentenceTransformer
from chatbot_config import ChatbotConfig
from chatbot_model import RetrievalChatbot
from response_quality_checker import ResponseQualityChecker
from chatbot_validator import ChatbotValidator
from plotter import Plotter
from environment_setup import EnvironmentSetup
from logger_config import config_logger
from tf_data_pipeline import TFDataPipeline

# Module-level logger obtained from the project's logger_config helper,
# named after this module.
logger = config_logger(__name__)

def _resolve_artifact_paths(model_dir: str, environment: str) -> "tuple[str, str]":
    """Return ``(faiss_index_path, response_pool_path)`` for *environment*.

    Args:
        model_dir: Root directory that contains the ``faiss_indices/`` folder.
        environment: Either ``"production"`` or ``"test"``.

    Raises:
        ValueError: If *environment* is not a supported value.
    """
    if environment not in ("production", "test"):
        raise ValueError(
            f"Unknown environment {environment!r}; expected 'production' or 'test'."
        )
    faiss_indices_dir = os.path.join(model_dir, "faiss_indices")
    faiss_index_path = os.path.join(faiss_indices_dir, f"faiss_index_{environment}.index")
    # The responses file sits next to the index and shares its stem. splitext only
    # strips the final extension, whereas str.replace(".index", ...) would also
    # rewrite any ".index" that happens to appear earlier in the path.
    response_pool_path = os.path.splitext(faiss_index_path)[0] + "_responses.json"
    return faiss_index_path, response_pool_path


def _load_config(model_dir: str) -> "ChatbotConfig":
    """Load ``<model_dir>/config.json`` as a ChatbotConfig, or return defaults if absent."""
    config_path = os.path.join(model_dir, "config.json")
    if not os.path.exists(config_path):
        logger.warning("No config.json found. Using default ChatbotConfig.")
        return ChatbotConfig()
    with open(config_path, "r", encoding="utf-8") as f:
        config_dict = json.load(f)
    config = ChatbotConfig.from_dict(config_dict)
    logger.info(f"Loaded ChatbotConfig from {config_path}")
    return config


def _load_data_pipeline(config, encoder, faiss_index_path: str, response_pool_path: str) -> "TFDataPipeline":
    """Build a TFDataPipeline and populate it from the saved FAISS index and response pool."""
    data_pipeline = TFDataPipeline(
        config=config,
        tokenizer=encoder.tokenizer,
        encoder=encoder,
        response_pool=[],
        query_embeddings_cache={},
        index_type='IndexFlatIP',
        faiss_index_file_path=faiss_index_path,
    )

    data_pipeline.load_faiss_index(faiss_index_path)
    logger.info(f"FAISS index loaded from {faiss_index_path}.")

    with open(response_pool_path, "r", encoding="utf-8") as f:
        data_pipeline.response_pool = json.load(f)
    logger.info(f"Response pool loaded from {response_pool_path}.")
    logger.info(f"Total responses in pool: {len(data_pipeline.response_pool)}")

    # Validate dimension consistency
    data_pipeline.validate_faiss_index()
    logger.info("FAISS index and response pool validated successfully.")
    return data_pipeline


def run_chatbot_validation(
    environment: str = "production",
    model_dir: str = "models",
    num_examples: int = 5,
):
    """Load saved chatbot artifacts, validate the chatbot, plot metrics, then start a chat.

    Stages, in order:
      1. Initialize the runtime environment (EnvironmentSetup).
      2. Check that the FAISS index and response-pool files exist.
      3. Load ChatbotConfig (defaults when config.json is absent).
      4. Load the SentenceTransformer named by ``config.pretrained_model``.
      5. Build the TFDataPipeline from the saved index and responses.
      6. Load the RetrievalChatbot and run ChatbotValidator.
      7. Plot the validation metrics.
      8. Run the interactive chat loop.

    A failure in stages 2-6 is logged (with traceback) and aborts the run.
    Failures in stages 7-8 are logged and do not abort.

    Args:
        environment: ``"production"`` or ``"test"``; selects which FAISS index to use.
        model_dir: Directory holding ``config.json``, the saved model and ``faiss_indices/``.
        num_examples: Passed through to ``ChatbotValidator.run_validation``.

    Raises:
        ValueError: If *environment* is not ``"production"`` or ``"test"``.
    """
    faiss_index_path, response_pool_path = _resolve_artifact_paths(model_dir, environment)

    # Initialize environment
    env = EnvironmentSetup()
    env.initialize()

    # Fail fast, before the expensive encoder load and pipeline construction.
    missing = [p for p in (faiss_index_path, response_pool_path) if not os.path.exists(p)]
    if missing:
        logger.error(f"FAISS index or response pool file is missing: {', '.join(missing)}")
        return

    # Load the config. A malformed file is reported like every other stage
    # instead of escaping as an uncaught exception.
    try:
        config = _load_config(model_dir)
    except Exception as e:
        logger.exception(f"Failed to load ChatbotConfig: {e}")
        return

    # Init SentenceTransformer
    try:
        encoder = SentenceTransformer(config.pretrained_model)
        logger.info(f"Loaded SentenceTransformer model: {config.pretrained_model}")
    except Exception as e:
        logger.exception(f"Failed to load SentenceTransformer: {e}")
        return

    # Load FAISS index and response pool
    try:
        data_pipeline = _load_data_pipeline(config, encoder, faiss_index_path, response_pool_path)
    except Exception as e:
        logger.exception(f"Failed to load or validate FAISS index: {e}")
        return

    # Init QualityChecker and Validator, then run validation
    try:
        chatbot = RetrievalChatbot.load_model(load_dir=model_dir, mode="inference")
        quality_checker = ResponseQualityChecker(data_pipeline=data_pipeline)
        validator = ChatbotValidator(chatbot=chatbot, quality_checker=quality_checker)
        logger.info("ResponseQualityChecker and ChatbotValidator initialized.")

        validation_metrics = validator.run_validation(num_examples=num_examples)
        logger.info(f"Validation Metrics: {validation_metrics}")
    except Exception as e:
        logger.exception(f"Validation process failed: {e}")
        return

    # Plot metrics (best-effort: a plotting failure does not stop the chat loop)
    try:
        plotter = Plotter(save_dir=env.training_dirs["plots"])
        plotter.plot_validation_metrics(validation_metrics)
        logger.info("Validation metrics plotted successfully.")
    except Exception as e:
        logger.exception(f"Failed to plot validation metrics: {e}")

    # Run interactive chat loop
    try:
        logger.info("\nStarting interactive chat session...")
        chatbot.run_interactive_chat(quality_checker=quality_checker, show_alternatives=True)
    except Exception as e:
        logger.exception(f"Interactive chat session failed: {e}")
    
if __name__ == "__main__":
    # Run the validation pipeline only when executed as a script, not on import.
    run_chatbot_validation()