|
import os |
|
import json |
|
from chatbot_model import ChatbotConfig, RetrievalChatbot |
|
from response_quality_checker import ResponseQualityChecker |
|
from chatbot_validator import ChatbotValidator |
|
from training_plotter import TrainingPlotter |
|
from environment_setup import EnvironmentSetup |
|
|
|
from logger_config import config_logger |
|
# Module-level logger from the project's logging helper; the validation
# pipeline below reports progress and failures through it.
logger = config_logger(__name__)
|
|
|
def run_interactive_chat(chatbot, quality_checker, top_k=5):
    """Run an interactive console chat loop against ``chatbot`` until the user quits.

    Each non-blank line typed by the user is stripped of surrounding whitespace
    and sent to ``chatbot.chat`` as a standalone query (no conversation history
    is passed). The reply is printed. If the returned metrics mark the answer as
    confident, up to three alternative candidates are listed. Otherwise a
    low-confidence hint is shown.

    The loop ends when the user types ``quit``, ``exit`` or ``bye`` (case- and
    whitespace-insensitive), presses Ctrl-C, or input reaches end-of-file.

    Args:
        chatbot: Object exposing ``chat(query=..., conversation_history=...,
            quality_checker=..., top_k=...)`` that returns a
            ``(response, candidates, metrics)`` tuple. ``candidates`` is iterated
            as ``(response, score)`` pairs and ``metrics`` must support
            ``.get('is_confident', ...)``.
        quality_checker: Forwarded unchanged to ``chatbot.chat``.
        top_k: Number of candidates requested from ``chatbot.chat``
            (defaults to 5, the previously hard-coded value).
    """
    exit_commands = {'quit', 'exit', 'bye'}

    while True:
        try:
            user_input = input("You: ")
        except (KeyboardInterrupt, EOFError):
            print("\nAssistant: Goodbye!")
            break

        query = user_input.strip()
        # Blank lines would otherwise be sent to the model as empty queries.
        if not query:
            continue

        if query.lower() in exit_commands:
            print("Assistant: Goodbye!")
            break

        response, candidates, metrics = chatbot.chat(
            query=query,
            conversation_history=None,
            quality_checker=quality_checker,
            top_k=top_k,
        )

        print(f"Assistant: {response}")

        if metrics.get('is_confident', False):
            # The first candidate is presumably the response already printed
            # (TODO confirm against chatbot.chat); list the next few as alternatives.
            alternatives = candidates[1:4]
            if alternatives:
                print("\nAlternative responses:")
                for resp, score in alternatives:
                    print(f"Score: {score:.4f} - {resp}")
        else:
            print("\n[Low Confidence]: Consider rephrasing your query for better assistance.")
|
|
|
|
|
def _resolve_artifact_paths(model_dir, environment):
    """Return ``(faiss_index_path, response_pool_path)`` for the given environment.

    Layout: ``<model_dir>/faiss_indices/faiss_index_<environment>.index``, with the
    response pool stored beside it as ``faiss_index_<environment>_responses.json``.

    Raises:
        ValueError: if ``environment`` is not ``'production'`` or ``'test'``.
    """
    if environment not in ('production', 'test'):
        raise ValueError(
            f"environment must be 'production' or 'test', got {environment!r}"
        )
    faiss_indices_dir = os.path.join(model_dir, 'faiss_indices')
    faiss_index_path = os.path.join(faiss_indices_dir, f'faiss_index_{environment}.index')
    # Swap only the file extension; str.replace('.index', ...) would also rewrite
    # any '.index' appearing earlier in the path (e.g. in model_dir).
    response_pool_path = os.path.splitext(faiss_index_path)[0] + '_responses.json'
    return faiss_index_path, response_pool_path


def _load_artifacts(chatbot, faiss_index_path, response_pool_path):
    """Load the FAISS index and response pool into ``chatbot.data_pipeline`` and validate them.

    Returns:
        True on success; False (after logging the reason) if a file is missing or
        loading/validation raises.
    """
    missing = [path for path in (faiss_index_path, response_pool_path) if not os.path.exists(path)]
    if missing:
        logger.error(f"FAISS index or response pool file is missing: {', '.join(missing)}")
        return False

    try:
        chatbot.data_pipeline.load_faiss_index(faiss_index_path)
        logger.info(f"FAISS index loaded from {faiss_index_path}.")

        with open(response_pool_path, 'r', encoding='utf-8') as f:
            chatbot.data_pipeline.response_pool = json.load(f)
        logger.info(f"Response pool loaded from {response_pool_path}.")

        chatbot.data_pipeline.validate_faiss_index()
        logger.info("FAISS index and response pool validated successfully.")
    except Exception as e:
        # Script boundary: any failure here (I/O, bad JSON, index mismatch) aborts the run.
        logger.error(f"Failed to load or validate FAISS index / response pool: {e}")
        return False
    return True


def validate_chatbot(*, model_dir='models', environment='production', num_examples=5):
    """Load the retrieval chatbot, validate it, plot the metrics, then start a chat session.

    Steps:
      1. Initialise the environment via ``EnvironmentSetup().initialize()``.
      2. Build a ``RetrievalChatbot`` in ``'inference'`` mode.
      3. Load the FAISS index and response pool for ``environment`` into its data pipeline.
      4. Run ``ChatbotValidator.run_validation`` and log the metrics.
      5. Plot the metrics (best-effort; a failure is logged and ignored).
      6. Hand over to ``run_interactive_chat``.
    Any failure in steps 2-4 is logged and the function returns ``None`` early.

    Args:
        model_dir: Root directory containing the ``faiss_indices`` sub-directory.
        environment: ``'production'`` or ``'test'``; selects which index/pool pair to load.
        num_examples: Number of examples passed to ``run_validation``.

    Raises:
        ValueError: if ``environment`` is not ``'production'`` or ``'test'``.
    """
    # Pure path computation, done first so a bad argument fails before any heavy setup.
    faiss_index_path, response_pool_path = _resolve_artifact_paths(model_dir, environment)

    env = EnvironmentSetup()
    env.initialize()

    config = ChatbotConfig()

    try:
        chatbot = RetrievalChatbot(config=config, mode='inference')
        logger.info("RetrievalChatbot initialized in 'inference' mode.")
    except Exception as e:
        logger.error(f"Failed to initialize RetrievalChatbot: {e}")
        return

    if not _load_artifacts(chatbot, faiss_index_path, response_pool_path):
        return

    quality_checker = ResponseQualityChecker(data_pipeline=chatbot.data_pipeline)
    validator = ChatbotValidator(chatbot=chatbot, quality_checker=quality_checker)
    logger.info("ResponseQualityChecker and ChatbotValidator initialized.")

    try:
        validation_metrics = validator.run_validation(num_examples=num_examples)
        logger.info(f"Validation Metrics: {validation_metrics}")
    except Exception as e:
        logger.error(f"Validation process failed: {e}")
        return

    # Plotting is best-effort: a failure here must not block the chat session.
    try:
        plotter = TrainingPlotter(save_dir=env.training_dirs['plots'])
        plotter.plot_validation_metrics(validation_metrics)
        logger.info("Validation metrics plotted successfully.")
    except Exception as e:
        logger.error(f"Failed to plot validation metrics: {e}")

    logger.info("\nStarting interactive chat session...")
    run_interactive_chat(chatbot, quality_checker)
|
|
|
# Script entry point: run the validation pipeline, which ends in an
# interactive chat session when every required step succeeds.
if __name__ == '__main__':
    validate_chatbot()