|
|
|
import os |
|
from playsound3 import playsound |
|
import tensorflow |
|
from chatbotTrainer import ChatbotTrainer |
|
import time |
|
import numpy as np |
|
import random |
|
import pdb |
|
import sys |
|
import matplotlib.pyplot as plt |
|
|
|
|
|
from preprocessed_dialogs import dialog_data |
|
|
|
|
|
class CorpusTrainer:
    """Drive ``ChatbotTrainer`` over randomly chosen preprocessed dialogs.

    Speakers (conversation ids) that train without exceeding the trouble
    limit are appended to ``trained_speakers.txt`` and skipped afterwards;
    speakers whose running trouble exceeds the limit are appended to
    ``troubled_speakers.txt``. When the running failure percentage exceeds
    ``percent_reset`` the failure count is logged to ``failure_history.txt``,
    plotted, and a fresh pass over the dialogs is started.
    """

    def __init__(self):
        """Load speaker bookkeeping files, report GPUs and load the corpus.

        Side effects: truncates ``troubled_speakers.txt`` (its previous
        contents are kept in ``troubleListData``) and rewrites
        ``trained_speakers.txt`` without previously troubled speakers.
        """
        self.runningTrouble = []
        # Accepted "yes" answers (compared case-insensitively by _is_yes).
        self.choices_yes = ["yes", "ya", "yeah", "yessir", "yesir", "y", "ye", "yah"]
        # Answers that abort the program at the failure-report prompts.
        self.exit_commands = ["exit", "quit", "stop", "x", "q", ""]

        self.log_file = "failure_history.txt"
        # Per-run counters; cleared by _request_restart().
        self.counter = 0
        self.bad_count = 0
        self.top_num = 0
        # Failure percentage above which the run is reset.
        self.percent_reset = 10.0
        # Length (seconds) of the countdown between unsupervised conversations.
        self.time_sleep = 10
        self.processed_dialogs = dialog_data
        self.notification_sound = "AlienNotification.mp3"

        self.chatbot_trainer = None
        self.conversation_id = None
        self.all_input_texts = []
        self.all_target_texts = []
        self.failure_history = []
        self.speakerList = []
        self.speaker_input_texts = []
        self.speaker_target_texts = []
        self.speakerListData = None
        self.troubleListData = None
        self.troubleList = []
        self.allTogether = []
        self.failsafe_trigger = False
        # Set by _request_restart() so main() begins a fresh pass.
        self._restart_requested = False
        self._restart_top_convo = 0

        self.speakerListData = self._read_lines('trained_speakers.txt')
        # Snapshot taken before resetTroubled() empties the file.
        self.troubleListData = self._read_lines('troubled_speakers.txt')
        self.resetTroubled()
        self.speakerList = self.cleanupTrained(self.speakerListData)

        print("Num GPUs Available: ", len(tensorflow.config.experimental.list_physical_devices('GPU')))

        self.corpus_path = '/root/.convokit/saved-corpora/movie-corpus'
        self.chatbot_trainer = ChatbotTrainer()
        self.chatbot_trainer.load_corpus(self.corpus_path)

    @staticmethod
    def _read_lines(path):
        """Return the lines of *path* without newlines, or [] if it is missing."""
        try:
            with open(path, 'r') as file:
                return file.read().splitlines()
        except FileNotFoundError:
            return []

    def _is_yes(self, answer):
        """Return True if *answer* is a yes-string or an already-normalized 1/True flag."""
        if isinstance(answer, str):
            return answer.strip().lower() in self.choices_yes
        return answer == 1

    def main(self, chatbot_trainer, user_choice, dialog_data, topConvo=0, top_num=0, play_notification=0):
        """Train on randomly picked conversations, restarting when requested.

        Args:
            chatbot_trainer: Trainer whose ``train_model`` is called per pair.
            user_choice: Answer to "Run Supervised?"; a yes-answer routes
                pairs through :meth:`user_yes`, anything else through
                :meth:`user_no`.
            dialog_data: Only echoed back in the quit tuple.
            topConvo: Counter incremented once per conversation pick.
            top_num: Unused; kept for call compatibility (``self.top_num``
                tracks the highest conversation id instead).
            play_notification: Yes-answer string or 1/0 flag controlling the
                notification sound.

        Returns:
            ``(chatbot_trainer, user_choice, dialog_data, topConvo, top_num,
            failsafe_trigger)`` once the tokenizer exceeds ``max_vocabulary``;
            otherwise ``None`` after a pass finishes without a restart request.
        """
        play_notification = 1 if self._is_yes(play_notification) else 0
        self.chatbot_trainer = chatbot_trainer

        # Loop instead of recursing from user_yes/user_no: a reset starts a
        # new pass, and a quit (max vocabulary) really ends training.
        while True:
            self._restart_requested = False
            self.cleanupTroubled()
            result = self._train_pass(user_choice, dialog_data, topConvo, play_notification)
            if not self._restart_requested:
                return result
            topConvo = self._restart_top_convo

    def _train_pass(self, user_choice, dialog_data, topConvo, play_notification):
        """Run one pass of len(processed_dialogs) random picks (see :meth:`main`)."""
        supervised = self._is_yes(user_choice)
        bad_shapes = ((1, 64), (1, 63))
        num_dialogs = len(self.processed_dialogs.keys())

        for _ in range(num_dialogs):
            topConvo += 1
            self.counter += 1

            # Sampling is with replacement. Assumes keys are "1".."N"
            # (presumably how preprocessed_dialogs was built) -- verify.
            speaker = str(random.randint(1, num_dialogs))
            try:
                dialog_pairs = self.processed_dialogs[speaker]
            except KeyError:
                print(f"Conversation {speaker} skipped for NOT being in the dialog data... ")
                continue

            if len(dialog_pairs) < 3:
                print(f"Conversation {speaker} skipped for NOT providing enough data... ")
                continue

            for input_text, target_text in dialog_pairs:
                self.speaker_input_texts = []
                self.speaker_target_texts = []

                input_shape = np.array(input_text).shape
                target_shape = np.array(target_text).shape
                if input_shape in bad_shapes or target_shape in bad_shapes:
                    print(f"Conversation {speaker} skipped for NOT providing properly shaped data... ")
                    continue

                if len(input_text) < 3 or len(target_text) < 3:
                    print(f"Conversation {speaker} skipped for NOT providing enough data... ")
                    continue

                # Both texts are non-empty here (length checked above).
                self.speaker_input_texts.append(input_text.strip())
                self.all_input_texts.append(input_text.strip())
                self.speaker_target_texts.append(target_text.strip())
                self.all_target_texts.append(target_text.strip())

                # Only speakers not yet marked as trained are trained.
                if self.failsafe_trigger or speaker in self.speakerList:
                    continue

                self.conversation_id = int(speaker)
                self.top_num = max(self.top_num, self.conversation_id)
                print(f"Conversation: {self.conversation_id}")

                limit = self.chatbot_trainer.early_patience - 3

                if self.chatbot_trainer.tokenizer.num_words > self.chatbot_trainer.max_vocabulary:
                    print("MAXIMUM Vocabulary Reached! Quitting Now... ")
                    if play_notification == 1:
                        playsound(self.notification_sound)
                    return self.chatbot_trainer, user_choice, dialog_data, topConvo, self.top_num, self.failsafe_trigger

                data = [input_text, target_text]
                try:
                    if supervised:
                        self.user_yes(speaker=speaker, data=data, limit=limit, play_notification=play_notification)
                    else:
                        self.user_no(speaker=speaker, data=data, limit=limit, play_notification=play_notification)
                except ValueError:
                    print(f"Skipped Conversation {speaker}... Trying again...")
                    continue
                except Exception as e:
                    # Best-effort: one failing conversation must not stop the run.
                    print(f"Error while training conversation {speaker}: {e}")

                if self._restart_requested:
                    return None
        return None

    def _record_result(self, speaker, limit):
        """Persist *speaker* as trained or troubled; return the combined speaker count.

        A speaker is trained when ``len(runningTrouble) < limit`` and
        troubled when it is ``> limit``.
        """
        trouble = len(self.runningTrouble)
        if speaker not in self.speakerList and trouble < limit:
            self.speakerList.append(speaker)
            with open("trained_speakers.txt", 'a') as f:
                f.write(f"{speaker}\n")
        elif trouble > limit:
            self.bad_count += 1
            self.troubleList.append(speaker)
            with open("troubled_speakers.txt", 'a') as f:
                f.write(f"{speaker}\n")
        # NOTE(review): trouble == limit records the speaker as neither -- confirm intended.

        self.allTogether = self.resetTogether()
        self.bad_count = len(self.troubleList)
        return len(self.allTogether)

    def _request_restart(self, top_convo):
        """Ask main() to start a fresh pass and clear the per-run failure counters."""
        self._restart_requested = True
        self._restart_top_convo = top_convo
        # The failure percentage is computed from these; without clearing
        # them the new run would immediately trip the reset threshold again.
        self.troubleList = []
        self.bad_count = 0
        self.counter = 0

    def user_yes(self, data, speaker, limit, play_notification):
        """Train one pair in supervised mode, pausing for the operator.

        Args:
            data: ``[input_text, target_text]`` pair to train on.
            speaker: Conversation id (string) being trained.
            limit: Trouble threshold separating trained from troubled.
            play_notification: 1 to play the notification sound, else 0.

        If the failure percentage (troubled / conversations this run)
        exceeds ``percent_reset``, the operator may show/save a failure plot
        and clear ``trained_speakers.txt``; a restart is then requested.
        Exit commands at those prompts terminate the program.
        """
        self.chatbot_trainer.train_model(data[0], data[1], str(self.conversation_id), speaker)
        self.runningTrouble = self.chatbot_trainer.running_trouble
        top_convo = self._record_result(speaker, limit)

        percent_running = self.runningPercent(len(self.troubleList), self.counter)
        self.failure_history.append(len(self.troubleList))
        if percent_running is None:
            percent_running = 0.0
        self.chatbot_trainer.logger.info(f"Running Percentage Failure: {percent_running}%")

        print("Now is the time to quit if need be... ")
        if play_notification == 1:
            playsound(self.notification_sound)

        if percent_running > self.percent_reset:
            print(f"Logging Failures... Resetting... Failure Rate is Greater than {self.percent_reset}%...")
            answer_1 = input("Show Failures for this Run? \n>")
            if answer_1 in self.exit_commands:
                sys.exit()
            show_file = self._is_yes(answer_1)

            answer_2 = input("Save Failures for this Run? \n>")
            if answer_2 in self.exit_commands:
                sys.exit()
            save_file = self._is_yes(answer_2)

            self.log_failures(len(self.troubleList), self.log_file)
            self.plot_failures(self.log_file, show_file=show_file, save_file=save_file)
            if save_file:
                print("Plotting Failures... See failures_plot.png for more information... ")

            delete_speakers = input("Would you like to clear trained_speakers.txt? \nThis is useful for touching on successful conversations... \n>")
            if self._is_yes(delete_speakers):
                with open('trained_speakers.txt', 'w') as f:
                    f.write("")
                # Forget them in memory too so they are retrained this session.
                self.speakerList = []

            input('Enter to Continue... (This will reset the run) ')
            self._request_restart(top_convo)
            return

        input("\nEnter to Continue... ")

    def user_no(self, data, speaker, limit, play_notification):
        """Train one pair in unsupervised mode with a countdown between pairs.

        Args are as for :meth:`user_yes`. The failure percentage here is
        troubled / combined speaker count; when it exceeds ``percent_reset``
        the failures are logged, the plot is saved, and a restart is requested.
        """
        self.chatbot_trainer.train_model(data[0], data[1], str(self.conversation_id), speaker)
        # Refresh like user_yes does; previously this stayed [] so unsupervised
        # runs never classified a speaker as troubled.
        self.runningTrouble = self.chatbot_trainer.running_trouble
        top_convo = self._record_result(speaker, limit)

        print(f"Trouble List: {len(self.troubleList)}")
        print(f"Bad Count: {self.bad_count}")
        print(f"Number of Conversations(This Run): {self.counter}")
        print(f"Number of Conversations Combined: {top_convo}")
        print(f"Running Trouble: {len(self.runningTrouble)}")

        percent_running = self.runningPercent(len(self.troubleList), top_convo)
        self.failure_history.append(len(self.troubleList))
        if percent_running is None:
            percent_running = 0.0
        self.chatbot_trainer.logger.info(f"Running Percentage Failure: {percent_running}%")

        print("Now is the time to quit if need be... ")
        if play_notification == 1:
            playsound(self.notification_sound)
        # Actually wait, so the operator has a window to interrupt.
        for remaining in range(self.time_sleep, 0, -1):
            print(f"Next convo in:{remaining}")
            time.sleep(1)

        if percent_running > self.percent_reset:
            self.log_failures(len(self.troubleList), self.log_file)
            print("Plotting Failures... See failures_plot.png for more information... ")
            # Nobody is prompted in unsupervised mode: save the plot the
            # message above refers to.
            self.plot_failures(self.log_file, save_file=True)
            if play_notification == 1:
                playsound(self.notification_sound)
            print(f"Resetting... Failure Rate is Greater than {self.percent_reset}%... For this run.")
            self._request_restart(top_convo)

    def resetTogether(self):
        """Merge trained speakers and the startup troubled snapshot into allTogether.

        Returns a sorted (string-ordered) copy of ``self.allTogether``.
        """
        seen = set(self.allTogether)
        for speaker in self.speakerList + list(self.troubleListData):
            key = str(speaker)
            if key not in seen:
                seen.add(key)
                self.allTogether.append(key)
        return sorted(self.allTogether)

    def cleanupTrained(self, speakerList):
        """Merge *speakerList* into self.speakerList and rewrite trained_speakers.txt.

        Blank entries and speakers in ``troubleListData`` are dropped. The
        previous version iterated ``self.speakerList`` (empty at startup)
        instead of the argument, wiping the trained file on every launch.

        Returns the sorted ``self.speakerList``.
        """
        troubled = set(self.troubleListData or [])
        for entry in list(speakerList):
            entry = entry.strip()
            if entry and entry not in self.speakerList and entry not in troubled:
                self.speakerList.append(entry)

        self.speakerList = sorted(self.speakerList)
        with open('trained_speakers.txt', 'w') as f:
            f.writelines(f"{speaker}\n" for speaker in self.speakerList)
        return self.speakerList

    def resetTroubled(self):
        """Empty (or create) troubled_speakers.txt."""
        with open('troubled_speakers.txt', 'w'):
            pass

    def cleanupTroubled(self):
        """Rewrite troubled_speakers.txt sorted, de-duplicated and without blank lines."""
        entries = {line.strip() for line in self._read_lines('troubled_speakers.txt')}
        entries.discard("")
        with open('troubled_speakers.txt', 'w') as fw:
            fw.writelines(f"{entry}\n" for entry in sorted(entries))

    def runningPercent(self, list1, list2):
        """Return list1 / list2 as a percentage rounded to 2 decimals.

        Returns 0.0 when *list1* is 0, and None when the ratio is undefined
        (``list2 <= 0`` with ``list1 > 0``) or *list1* is negative.
        """
        if list1 > 0 and list2 > 0:
            return round(list1 / list2 * 100, 2)
        if list1 == 0:
            return 0.0
        return None

    def plot_failures(self, log_file, show_file=False, save_file=False):
        """Plot the per-run failure counts stored in *log_file*.

        Args:
            log_file: Text file with one integer per line.
            show_file: Display the figure interactively.
            save_file: Save the figure to ``failures_plot.png``.
        """
        if not os.path.exists(log_file):
            print("No failure data found.")
            return

        with open(log_file, "r") as f:
            self.failure_history = [int(stripped) for stripped in (line.strip() for line in f) if stripped]

        if len(self.failure_history) == 0:
            print("No failure data to plot.")
            return

        plt.figure(figsize=(10, 6))
        plt.plot(self.failure_history, marker='o', linestyle='-', color='red', label='Failures Per Run')
        plt.xlabel("Run Iteration")
        plt.ylabel("Number of Failures")
        plt.title("Failures Before Restart Over Time")
        plt.legend()
        plt.grid(True)

        if save_file:
            plt.savefig("failures_plot.png")
        if show_file:
            plt.show()
        # Release the figure so repeated resets don't accumulate open figures.
        plt.close()

    def log_failures(self, num_failures, log_file):
        """Append *num_failures* as a new line to *log_file*."""
        with open(log_file, "a") as f:
            f.write(f"{num_failures}\n")
        print(f"Logged {num_failures} failures.")
|
|
|
def run():
    """Build a CorpusTrainer, ask for the run options, then start training."""
    trainer_app = CorpusTrainer()
    supervised_answer = input(f"Run Supervised?({trainer_app.chatbot_trainer.model_filename})\n>")
    notify_answer = input("Would you like to play a notification after each training?\nHelps with manual stopping before max_vocabulary reached... \n>")
    trainer_app.main(
        chatbot_trainer=trainer_app.chatbot_trainer,
        user_choice=supervised_answer,
        dialog_data=dialog_data,
        play_notification=notify_answer,
    )


if __name__ == "__main__":
    # Each iteration rebuilds the trainer from scratch; there is no break,
    # so the loop only ends when an exception (e.g. SystemExit) propagates.
    while True:
        run()