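"""
Standalone validation entry point for the retrieval chatbot.

Loads the saved ChatbotConfig and SentenceTransformer encoder, restores a prebuilt
FAISS index together with its response pool, runs automated validation via
ResponseQualityChecker and ChatbotValidator, plots the resulting metrics, and then
starts an interactive chat session.
"""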
import os
import json
from sentence_transformers import SentenceTransformer
from chatbot_config import ChatbotConfig
from chatbot_model import RetrievalChatbot
from response_quality_checker import ResponseQualityChecker
from chatbot_validator import ChatbotValidator
from plotter import Plotter
from environment_setup import EnvironmentSetup
from logger_config import config_logger
from tf_data_pipeline import TFDataPipeline
logger = config_logger(__name__)


def run_chatbot_validation():
    # Initialize environment
    env = EnvironmentSetup()
    env.initialize()

    MODEL_DIR = "models"
    FAISS_INDICES_DIR = os.path.join(MODEL_DIR, "faiss_indices")
    FAISS_INDEX_PRODUCTION_PATH = os.path.join(FAISS_INDICES_DIR, "faiss_index_production.index")
    FAISS_INDEX_TEST_PATH = os.path.join(FAISS_INDICES_DIR, "faiss_index_test.index")

    # Toggle 'production' or 'test' env
    ENVIRONMENT = "production"
    if ENVIRONMENT == "test":
        FAISS_INDEX_PATH = FAISS_INDEX_TEST_PATH
        RESPONSE_POOL_PATH = FAISS_INDEX_TEST_PATH.replace(".index", "_responses.json")
    else:
        FAISS_INDEX_PATH = FAISS_INDEX_PRODUCTION_PATH
        RESPONSE_POOL_PATH = FAISS_INDEX_PRODUCTION_PATH.replace(".index", "_responses.json")
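    # With ENVIRONMENT = "production" these resolve to
    # models/faiss_indices/faiss_index_production.index and
    # models/faiss_indices/faiss_index_production_responses.json.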
    # Load the config
    config_path = os.path.join(MODEL_DIR, "config.json")
    if os.path.exists(config_path):
        with open(config_path, "r", encoding="utf-8") as f:
            config_dict = json.load(f)
        config = ChatbotConfig.from_dict(config_dict)
        logger.info(f"Loaded ChatbotConfig from {config_path}")
    else:
        config = ChatbotConfig()
        logger.warning("No config.json found. Using default ChatbotConfig.")
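    # Illustrative config.json shape (values here are hypothetical; the full field
    # set is defined by ChatbotConfig, and only `pretrained_model` is read directly
    # below):
    # {
    #     "pretrained_model": "sentence-transformers/all-MiniLM-L6-v2",
    #     ...
    # }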
    # Init SentenceTransformer
    try:
        encoder = SentenceTransformer(config.pretrained_model)
        logger.info(f"Loaded SentenceTransformer model: {config.pretrained_model}")
    except Exception as e:
        logger.error(f"Failed to load SentenceTransformer: {e}")
        return

    # Load FAISS index and response pool
    try:
        # Initialize TFDataPipeline
        data_pipeline = TFDataPipeline(
            config=config,
            tokenizer=encoder.tokenizer,
            encoder=encoder,
            response_pool=[],
            query_embeddings_cache={},
            index_type='IndexFlatIP',
            faiss_index_file_path=FAISS_INDEX_PATH
        )

        if not os.path.exists(FAISS_INDEX_PATH) or not os.path.exists(RESPONSE_POOL_PATH):
            logger.error("FAISS index or response pool file is missing.")
            return

        data_pipeline.load_faiss_index(FAISS_INDEX_PATH)
        logger.info(f"FAISS index loaded from {FAISS_INDEX_PATH}.")

        with open(RESPONSE_POOL_PATH, "r", encoding="utf-8") as f:
            data_pipeline.response_pool = json.load(f)
        logger.info(f"Response pool loaded from {RESPONSE_POOL_PATH}.")
        logger.info(f"Total responses in pool: {len(data_pipeline.response_pool)}")

        # Validate dimension consistency
        data_pipeline.validate_faiss_index()
        logger.info("FAISS index and response pool validated successfully.")
    except Exception as e:
        logger.error(f"Failed to load or validate FAISS index: {e}")
        return

    # Init QualityChecker and Validator
    try:
        chatbot = RetrievalChatbot.load_model(load_dir=MODEL_DIR, mode="inference")
        quality_checker = ResponseQualityChecker(data_pipeline=data_pipeline)
        validator = ChatbotValidator(chatbot=chatbot, quality_checker=quality_checker)
        logger.info("ResponseQualityChecker and ChatbotValidator initialized.")

        # Run validation
        validation_metrics = validator.run_validation(num_examples=5)
        logger.info(f"Validation Metrics: {validation_metrics}")
    except Exception as e:
        logger.error(f"Validation process failed: {e}")
        return

    # Plot metrics
    try:
        plotter = Plotter(save_dir=env.training_dirs["plots"])
        plotter.plot_validation_metrics(validation_metrics)
        logger.info("Validation metrics plotted successfully.")
    except Exception as e:
        logger.error(f"Failed to plot validation metrics: {e}")

    # Run interactive chat loop
    try:
        logger.info("\nStarting interactive chat session...")
        chatbot.run_interactive_chat(quality_checker=quality_checker, show_alternatives=True)
    except Exception as e:
        logger.error(f"Interactive chat session failed: {e}")
if __name__ == "__main__":
    run_chatbot_validation()