"""Validation script for the RetrievalChatbot.

Loads a saved model in inference mode, attaches a specific FAISS index and
response pool, validates them, and runs the automated validation suite.
An optional interactive chat loop is included (currently disabled in main).
"""

import os
import json

from chatbot_model import ChatbotConfig, RetrievalChatbot
from response_quality_checker import ResponseQualityChecker
from chatbot_validator import ChatbotValidator
from plotter import Plotter
from environment_setup import EnvironmentSetup
from logger_config import config_logger

logger = config_logger(__name__)


def run_interactive_chat(chatbot, quality_checker):
    """Run an interactive chat loop against the loaded chatbot."""
    while True:
        try:
            user_input = input("You: ")
        except (KeyboardInterrupt, EOFError):
            print("\nAssistant: Goodbye!")
            break

        if user_input.lower() in ["quit", "exit", "bye"]:
            print("Assistant: Goodbye!")
            break

        response, candidates, metrics = chatbot.chat(
            query=user_input,
            conversation_history=None,
            quality_checker=quality_checker,
            top_k=10
        )
        print(f"Assistant: {response}")

        # Show alternative responses only when the model is confident
        if metrics.get("is_confident", False):
            print("\nAlternative responses:")
            for resp, score in candidates[1:4]:
                print(f"Score: {score:.4f} - {resp}")
        else:
            print("\n[Low Confidence]: Consider rephrasing your query for better assistance.")


def validate_chatbot():
    # Initialize environment
    env = EnvironmentSetup()
    env.initialize()

    MODEL_DIR = "new_iteration/data_prep_iterative_models"
    FAISS_INDICES_DIR = os.path.join(MODEL_DIR, "faiss_indices")
    FAISS_INDEX_PRODUCTION_PATH = os.path.join(FAISS_INDICES_DIR, "faiss_index_production.index")
    FAISS_INDEX_TEST_PATH = os.path.join(FAISS_INDICES_DIR, "faiss_index_test.index")

    # Toggle between the 'production' and 'test' environments
    ENVIRONMENT = "production"
    if ENVIRONMENT == "test":
        FAISS_INDEX_PATH = FAISS_INDEX_TEST_PATH
        RESPONSE_POOL_PATH = FAISS_INDEX_TEST_PATH.replace(".index", "_responses.json")
    else:
        FAISS_INDEX_PATH = FAISS_INDEX_PRODUCTION_PATH
        RESPONSE_POOL_PATH = FAISS_INDEX_PRODUCTION_PATH.replace(".index", "_responses.json")

    # Load the config
    config_path = os.path.join(MODEL_DIR, "config.json")
    if os.path.exists(config_path):
        with open(config_path, "r", encoding="utf-8") as f:
            config_dict = json.load(f)
        config = ChatbotConfig.from_dict(config_dict)
        logger.info(f"Loaded ChatbotConfig from {config_path}")
    else:
        config = ChatbotConfig()
        logger.warning("No config.json found. Using default ChatbotConfig.")
    # Load RetrievalChatbot in 'inference' mode using the classmethod
    try:
        chatbot = RetrievalChatbot.load_model(load_dir=MODEL_DIR, mode="inference")
        logger.info("RetrievalChatbot loaded in 'inference' mode successfully.")
    except Exception as e:
        logger.error(f"Failed to load RetrievalChatbot: {e}")
        return

    # Confirm FAISS index & response pool exist
    if not os.path.exists(FAISS_INDEX_PATH) or not os.path.exists(RESPONSE_POOL_PATH):
        logger.error("FAISS index or response pool file is missing.")
        return

    # Load specific FAISS index and response pool
    try:
        # Even though load_model might auto-load an index, we override here with the specific file
        chatbot.data_pipeline.load_faiss_index(FAISS_INDEX_PATH)
        logger.info(f"FAISS index loaded from {FAISS_INDEX_PATH}.")
        print("FAISS dimensions:", chatbot.data_pipeline.index.d)
        print("FAISS index type:", type(chatbot.data_pipeline.index))
        print("FAISS index total vectors:", chatbot.data_pipeline.index.ntotal)
        print("FAISS is_trained:", chatbot.data_pipeline.index.is_trained)

        with open(RESPONSE_POOL_PATH, "r", encoding="utf-8") as f:
            chatbot.data_pipeline.response_pool = json.load(f)
        logger.info(f"Response pool loaded from {RESPONSE_POOL_PATH}.")
        print("\nTotal responses in pool:", len(chatbot.data_pipeline.response_pool))

        # Validate dimension consistency
        chatbot.data_pipeline.validate_faiss_index()
        logger.info("FAISS index and response pool validated successfully.")
    except Exception as e:
        logger.error(f"Failed to load or validate FAISS index: {e}")
        return

    # Init QualityChecker and Validator
    quality_checker = ResponseQualityChecker(data_pipeline=chatbot.data_pipeline)
    validator = ChatbotValidator(chatbot=chatbot, quality_checker=quality_checker)
    logger.info("ResponseQualityChecker and ChatbotValidator initialized.")

    # Run validation
    try:
        validation_metrics = validator.run_validation(num_examples=5)
        logger.info(f"Validation Metrics: {validation_metrics}")
    except Exception as e:
        logger.error(f"Validation process failed: {e}")
        return

    # Plot metrics
    # try:
    #     plotter = Plotter(save_dir=env.training_dirs["plots"])
    #     plotter.plot_validation_metrics(validation_metrics)
    #     logger.info("Validation metrics plotted successfully.")
    # except Exception as e:
    #     logger.error(f"Failed to plot validation metrics: {e}")

    # Run interactive chat loop
    # logger.info("\nStarting interactive chat session...")
    # run_interactive_chat(chatbot, quality_checker)


if __name__ == "__main__":
    validate_chatbot()