import os
import json

from sentence_transformers import SentenceTransformer

from chatbot_config import ChatbotConfig
from chatbot_model import RetrievalChatbot
from response_quality_checker import ResponseQualityChecker
from chatbot_validator import ChatbotValidator
from plotter import Plotter
from environment_setup import EnvironmentSetup
from logger_config import config_logger
from tf_data_pipeline import TFDataPipeline

logger = config_logger(__name__)


def _resolve_index_paths(model_dir, environment):
    """Return ``(faiss_index_path, response_pool_path)`` for *environment*.

    ``"test"`` selects the test index; any other value selects the
    production index. Both paths live under ``<model_dir>/faiss_indices``
    and share the same file stem.
    """
    index_name = "test" if environment == "test" else "production"
    stem = os.path.join(model_dir, "faiss_indices", f"faiss_index_{index_name}")
    return f"{stem}.index", f"{stem}_responses.json"


def run_chatbot_validation(environment="production"):
    """Validate the retrieval chatbot, plot the metrics, then start a chat.

    Steps, in order:

    1. Initialize the environment via ``EnvironmentSetup``.
    2. Load ``models/config.json`` into a ``ChatbotConfig`` (falling back to
       a default ``ChatbotConfig`` when the file does not exist).
    3. Load the ``SentenceTransformer`` named by ``config.pretrained_model``.
    4. Load the FAISS index and its JSON response pool into a
       ``TFDataPipeline`` and call ``validate_faiss_index()``.
    5. Load the chatbot, run ``ChatbotValidator.run_validation`` on
       5 examples and log the metrics.
    6. Plot the metrics (a plotting failure is logged but not fatal).
    7. Run the interactive chat loop.

    A failure in steps 3-5 is logged with its traceback and the function
    returns early.

    Args:
        environment: ``"test"`` to use the test FAISS index; any other value
            (default ``"production"``) uses the production index.

    Returns:
        None.
    """
    # Initialize environment
    env = EnvironmentSetup()
    env.initialize()

    MODEL_DIR = "models"
    faiss_index_path, response_pool_path = _resolve_index_paths(MODEL_DIR, environment)

    # Load the config
    config_path = os.path.join(MODEL_DIR, "config.json")
    if os.path.exists(config_path):
        with open(config_path, "r", encoding="utf-8") as f:
            config_dict = json.load(f)
        config = ChatbotConfig.from_dict(config_dict)
        logger.info(f"Loaded ChatbotConfig from {config_path}")
    else:
        config = ChatbotConfig()
        logger.warning("No config.json found. Using default ChatbotConfig.")

    # Init SentenceTransformer
    try:
        encoder = SentenceTransformer(config.pretrained_model)
        logger.info(f"Loaded SentenceTransformer model: {config.pretrained_model}")
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error dropped.
        logger.exception(f"Failed to load SentenceTransformer: {e}")
        return

    # Fail fast: check the required files before the pipeline is constructed
    # with an index path that may not exist.
    if not os.path.exists(faiss_index_path) or not os.path.exists(response_pool_path):
        logger.error("FAISS index or response pool file is missing.")
        return

    # Load FAISS index and response pool
    try:
        # Initialize TFDataPipeline
        data_pipeline = TFDataPipeline(
            config=config,
            tokenizer=encoder.tokenizer,
            encoder=encoder,
            response_pool=[],
            query_embeddings_cache={},
            index_type='IndexFlatIP',
            faiss_index_file_path=faiss_index_path,
        )

        data_pipeline.load_faiss_index(faiss_index_path)
        logger.info(f"FAISS index loaded from {faiss_index_path}.")

        with open(response_pool_path, "r", encoding="utf-8") as f:
            data_pipeline.response_pool = json.load(f)
        logger.info(f"Response pool loaded from {response_pool_path}.")
        logger.info(f"Total responses in pool: {len(data_pipeline.response_pool)}")

        # Validate dimension consistency
        data_pipeline.validate_faiss_index()
        logger.info("FAISS index and response pool validated successfully.")
    except Exception as e:
        logger.exception(f"Failed to load or validate FAISS index: {e}")
        return

    # Init QualityChecker and Validator
    try:
        chatbot = RetrievalChatbot.load_model(load_dir=MODEL_DIR, mode="inference")
        quality_checker = ResponseQualityChecker(data_pipeline=data_pipeline)
        validator = ChatbotValidator(chatbot=chatbot, quality_checker=quality_checker)
        logger.info("ResponseQualityChecker and ChatbotValidator initialized.")

        # Run validation
        validation_metrics = validator.run_validation(num_examples=5)
        logger.info(f"Validation Metrics: {validation_metrics}")
    except Exception as e:
        logger.exception(f"Validation process failed: {e}")
        return

    # Plot metrics (best effort: a plotting failure must not block the chat)
    try:
        plotter = Plotter(save_dir=env.training_dirs["plots"])
        plotter.plot_validation_metrics(validation_metrics)
        logger.info("Validation metrics plotted successfully.")
    except Exception as e:
        logger.exception(f"Failed to plot validation metrics: {e}")

    # Run interactive chat loop
    try:
        logger.info("\nStarting interactive chat session...")
        chatbot.run_interactive_chat(quality_checker=quality_checker, show_alternatives=True)
    except Exception as e:
        logger.exception(f"Interactive chat session failed: {e}")


if __name__ == "__main__":
    run_chatbot_validation()