|
import tensorflow as tf |
|
from chatbot_model import RetrievalChatbot, ChatbotConfig |
|
from environment_setup import EnvironmentSetup |
|
from response_quality_checker import ResponseQualityChecker |
|
from chatbot_validator import ChatbotValidator |
|
from training_plotter import TrainingPlotter |
|
|
|
|
|
from logger_config import config_logger |
|
# Module-level logger, configured through the project's shared logging setup.
logger = config_logger(__name__)
|
|
|
def run_interactive_chat(chatbot, quality_checker):
    """Run a terminal chat loop until the user types an exit word.

    Args:
        chatbot: model exposing ``chat(query, conversation_history,
            quality_checker, top_k)`` returning ``(response, candidates, metrics)``.
        quality_checker: passed through to ``chatbot.chat`` for scoring.
    """
    exit_words = {'quit', 'exit', 'bye'}
    while True:
        user_input = input("You: ")
        if user_input.lower() in exit_words:
            print("Assistant: Goodbye!")
            break

        # Each turn is stateless: no conversation history is carried over.
        response, candidates, metrics = chatbot.chat(
            query=user_input,
            conversation_history=None,
            quality_checker=quality_checker,
            top_k=5,
        )
        print(f"Assistant: {response}")

        # NOTE(review): alternatives are shown only when the model IS
        # confident — confirm this condition is not inverted.
        if metrics.get('is_confident', False):
            print("\nAlternative responses:")
            # Skip the top candidate (already printed) and show the next three.
            for alt_text, alt_score in candidates[1:4]:
                print(f"Score: {alt_score:.4f} - {alt_text}")
|
|
|
def main():
    """End-to-end driver: train the retrieval chatbot, validate it, plot
    metrics, then drop into an interactive chat session.

    Side effects: clears the Keras session, reads training data from disk,
    writes model artifacts and plots under the environment's training dirs.
    """
    tf.keras.backend.clear_session()
    env = EnvironmentSetup()
    env.initialize()

    # Limit the dataset for quick debug runs; set to None for a full run.
    DEBUG_SAMPLES = 5
    # Explicit None check: a debug sample count of 0 should still be treated
    # as a debug run, not silently fall through to the 20-epoch full run.
    EPOCHS = 5 if DEBUG_SAMPLES is not None else 20
    TRAINING_DATA_PATH = 'processed_outputs/batch_group_0010.json'

    # Let the environment pick a batch size suited to available memory/hardware.
    batch_size = env.optimize_batch_size(base_batch_size=16)

    config = ChatbotConfig(
        embedding_dim=768,
        max_context_token_limit=512,
        freeze_embeddings=False,
    )

    dialogues = RetrievalChatbot.load_training_data(
        data_path=TRAINING_DATA_PATH,
        debug_samples=DEBUG_SAMPLES,
    )
    # Log a summary instead of dumping the whole dataset to stdout
    # (the previous debug `print(dialogues)` flooded the console).
    logger.info("Loaded %d dialogues from %s", len(dialogues), TRAINING_DATA_PATH)

    chatbot = RetrievalChatbot(config, dialogues)
    chatbot.build_models()
    chatbot.verify_faiss_index()

    chatbot.train_streaming(
        dialogues=dialogues,
        epochs=EPOCHS,
        batch_size=batch_size,
        use_lr_schedule=True,
    )

    # Persist the trained model under the environment's base training dir.
    model_save_path = env.training_dirs['base'] / 'final_model'
    chatbot.save_models(model_save_path)

    # Validate on a handful of examples and record the metrics.
    quality_checker = ResponseQualityChecker(chatbot=chatbot)
    validator = ChatbotValidator(chatbot, quality_checker)
    validation_metrics = validator.run_validation(num_examples=5)
    logger.info("Validation Metrics: %s", validation_metrics)

    # Plot the training history and validation metrics for inspection.
    plotter = TrainingPlotter(save_dir=env.training_dirs['plots'])
    plotter.plot_training_history(chatbot.history)
    plotter.plot_validation_metrics(validation_metrics)

    logger.info("\nStarting interactive chat session...")
    run_interactive_chat(chatbot, quality_checker)
|
|
|
# Script entry point: run the full train/validate/chat pipeline.
if __name__ == "__main__":

    main()