# Source: csc525_retrieval_based_chatbot / validate_model.py
# Uploaded by JoeArmani - commit 5b413d1 ("training and inference updates"), 4.45 kB
import os
import json
from chatbot_model import ChatbotConfig, RetrievalChatbot
from response_quality_checker import ResponseQualityChecker
from chatbot_validator import ChatbotValidator
from training_plotter import TrainingPlotter
from environment_setup import EnvironmentSetup
from logger_config import config_logger
logger = config_logger(__name__)
def run_interactive_chat(chatbot, quality_checker):
    """Run a console chat loop against the chatbot until the user leaves.

    Reads one line at a time from stdin, sends it to ``chatbot.chat`` and
    prints the returned response. When the returned metrics flag the answer
    as confident (``metrics['is_confident']``), up to three further
    candidates are printed with their scores; otherwise a low-confidence
    hint is printed.

    The loop ends on Ctrl-C / end-of-input, or when the user types ``quit``,
    ``exit`` or ``bye`` (case-insensitive, surrounding whitespace ignored).
    Blank lines are skipped instead of being sent to the model.

    Args:
        chatbot: Object exposing ``chat(query, conversation_history,
            quality_checker, top_k)`` that returns a
            ``(response, candidates, metrics)`` triple, where ``candidates``
            is a sequence of ``(text, score)`` pairs and ``metrics`` is a
            dict-like object.
        quality_checker: Passed through unchanged to ``chatbot.chat``.

    Returns:
        None. All output goes to stdout.
    """
    exit_commands = {'quit', 'exit', 'bye'}

    while True:
        try:
            user_input = input("You: ")
        except (KeyboardInterrupt, EOFError):
            # Ctrl-C or closed stdin: leave politely instead of dumping a traceback.
            print("\nAssistant: Goodbye!")
            break

        # Normalize once so "quit " / " Exit" are recognized as exit commands
        # and stray whitespace never reaches the retriever.
        query = user_input.strip()
        if not query:
            # Nothing to search for; prompt again.
            continue

        if query.lower() in exit_commands:
            print("Assistant: Goodbye!")
            break

        response, candidates, metrics = chatbot.chat(
            query=query,
            conversation_history=None,
            quality_checker=quality_checker,
            top_k=5
        )
        print(f"Assistant: {response}")

        if metrics.get('is_confident', False):
            print("\nAlternative responses:")
            # Skip candidates[0] (presumably the response already shown) and
            # list at most the next three.
            for resp, score in candidates[1:4]:
                print(f"Score: {score:.4f} - {resp}")
        else:
            print("\n[Low Confidence]: Consider rephrasing your query for better assistance.")
# TODO:
def validate_chatbot(environment='production'):
    """Load the inference chatbot, validate it, plot metrics, then start a chat.

    Steps, in order:
      1. Initialize the environment via ``EnvironmentSetup`` (errors here
         propagate to the caller).
      2. Build a ``RetrievalChatbot`` in 'inference' mode.
      3. Load the FAISS index and its JSON response pool for the selected
         environment, then call ``validate_faiss_index``.
      4. Run ``ChatbotValidator.run_validation`` on 5 examples.
      5. Plot the validation metrics (a failure is logged but not fatal).
      6. Start the interactive console chat.

    A failure in steps 2-4 is logged and makes the function return early.

    Args:
        environment: ``'test'`` selects ``faiss_index_test.index``; any other
            value selects ``faiss_index_production.index``. Both live under
            ``models/faiss_indices`` next to a matching ``*_responses.json``
            file. Defaults to ``'production'``.

    Returns:
        None.
    """
    # Initialize environment
    env = EnvironmentSetup()
    env.initialize()

    # Anything other than 'test' falls back to production, matching the
    # behavior from when this selector was a hard-coded local.
    if environment == 'test':
        index_name = 'faiss_index_test.index'
    else:
        index_name = 'faiss_index_production.index'
    faiss_index_path = os.path.join('models', 'faiss_indices', index_name)
    response_pool_path = faiss_index_path.replace('.index', '_responses.json')

    # Load config
    config = ChatbotConfig()

    # Initialize RetrievalChatbot in 'inference' mode
    try:
        chatbot = RetrievalChatbot(config=config, mode='inference')
        logger.info("RetrievalChatbot initialized in 'inference' mode.")
    except Exception as e:
        logger.error(f"Failed to initialize RetrievalChatbot: {e}")
        return

    # Ensure FAISS index and response pool are accessible, then load
    if not os.path.exists(faiss_index_path) or not os.path.exists(response_pool_path):
        logger.error("FAISS index or response pool file is missing.")
        return

    try:
        chatbot.data_pipeline.load_faiss_index(faiss_index_path)
        logger.info(f"FAISS index loaded from {faiss_index_path}.")

        with open(response_pool_path, 'r', encoding='utf-8') as f:
            chatbot.data_pipeline.response_pool = json.load(f)
        logger.info(f"Response pool loaded from {response_pool_path}.")

        chatbot.data_pipeline.validate_faiss_index()
        logger.info("FAISS index and response pool validated successfully.")
    except Exception as e:
        # This handler covers index loading, response-pool parsing and
        # validation, so the message must not blame the index alone.
        logger.error(f"Failed to load or validate FAISS index and response pool: {e}")
        return

    # Initialize ResponseQualityChecker and ChatbotValidator
    quality_checker = ResponseQualityChecker(data_pipeline=chatbot.data_pipeline)
    validator = ChatbotValidator(chatbot=chatbot, quality_checker=quality_checker)
    logger.info("ResponseQualityChecker and ChatbotValidator initialized.")

    # Run validation
    try:
        validation_metrics = validator.run_validation(num_examples=5)
        logger.info(f"Validation Metrics: {validation_metrics}")
    except Exception as e:
        logger.error(f"Validation process failed: {e}")
        return

    # Plot validation_metrics; plotting is best-effort, so a failure here
    # does not prevent the interactive session from starting.
    try:
        plotter = TrainingPlotter(save_dir=env.training_dirs['plots'])
        plotter.plot_validation_metrics(validation_metrics)
        logger.info("Validation metrics plotted successfully.")
    except Exception as e:
        logger.error(f"Failed to plot validation metrics: {e}")

    # Run interactive chat
    logger.info("\nStarting interactive chat session...")
    run_interactive_chat(chatbot, quality_checker)
# Script entry point: run the validation pipeline only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    validate_chatbot()