import torch from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence import torch.nn.functional as F import torch.nn as nn import torch.optim as optim from .chat_dataset import ChatDataset from .attention import LuongAttention from .custom_types import Method from .constants import BOS_TOKEN from .vocab import Vocab from .searchers import GreedySearch import os import random from tqdm import tqdm class Seq2SeqEncoder(nn.Module): def __init__(self, input_size: int, hidden_size: int, num_layers: int, embedding: nn.Embedding): super().__init__() self.input_size = input_size self.hidden_size = hidden_size self.num_layers = num_layers self.embedding = embedding self.rnn = nn.GRU(input_size, hidden_size, num_layers=num_layers, bidirectional=True, batch_first=True) # batch_first is True, because I don't approve self-harm def forward(self, x, lengths): x = self.embedding(x) # Output shape: (batch_size, max_len_in_batch, hidden_size) packed_embedded = pack_padded_sequence(x, lengths.cpu(), batch_first=True, enforce_sorted=False) outputs, hidden = self.rnn(packed_embedded) outputs, _ = pad_packed_sequence(outputs, batch_first=True) return outputs[:, :, :self.hidden_size] + outputs[:, :, self.hidden_size:], hidden class Seq2SeqDecoder(nn.Module): def __init__(self, input_size: int, hidden_size: int, output_size: int, num_layers: int, attn, embedding: nn.Embedding, dropout: int = 0.1): super().__init__() self.input_size = input_size self.hidden_size = hidden_size self.output_size = output_size self.num_layers = num_layers self.attn = attn self.embedding = embedding self.embedding_dropout = nn.Dropout(dropout) self.rnn = nn.GRU(input_size, hidden_size, num_layers=num_layers, batch_first=True) self.concat = nn.Linear(hidden_size * 2, hidden_size) self.out = nn.Linear(hidden_size, output_size) def forward(self, x, last_hidden, encoder_outputs): embedded = self.embedding(x) embedded = self.embedding_dropout(embedded) decoder_outputs, hidden = self.rnn(embedded, last_hidden) attn_weights = self.attn(decoder_outputs, encoder_outputs) context = attn_weights.bmm(encoder_outputs).squeeze(1) concat_input = torch.cat((decoder_outputs.squeeze(1), context), 1) concat_output = torch.tanh(self.concat(concat_input)) output = self.out(concat_output) output = F.softmax(output, dim=1) return output, hidden class Seq2SeqChatbot(nn.Module): def __init__(self, hidden_size: int, vocab_size: int, encoder_num_layers: int, decoder_num_layers: int, decoder_embedding_dropout: float, device: torch.device): super().__init__() self.hidden_size = hidden_size self.encoder_num_layers = encoder_num_layers self.decoder_num_layers = decoder_num_layers self.decoder_embedding_dropout = decoder_embedding_dropout self.vocab_size = vocab_size self.epoch = 0 self.device = device self.vocab = Vocab([]) self.embedding = nn.Embedding(vocab_size, hidden_size) self.attn = LuongAttention(Method.DOT, hidden_size) self.encoder = Seq2SeqEncoder(hidden_size, hidden_size, encoder_num_layers, self.embedding) self.decoder = Seq2SeqDecoder(hidden_size, hidden_size, vocab_size, decoder_num_layers, self.attn, self.embedding, decoder_embedding_dropout) self.encoder_optimizer = optim.Adam(self.encoder.parameters()) self.decoder_optimizer = optim.Adam(self.decoder.parameters()) self.searcher = GreedySearch(self.encoder, self.decoder, self.embedding, device) self.to(device) self.eval_mode() def train(self, epochs, train_data, teacher_forcing_ratio, device, save_dir, model_name, clip, save_every): def maskNLLLoss(inp, target, mask): crossEntropy = -torch.log(torch.gather(inp, 1, target.view(-1, 1)).squeeze(1)) loss = crossEntropy.masked_select(mask).mean() loss = loss.to(device) return loss epoch_progress = tqdm(range(self.epoch, self.epoch + epochs), desc="Training", unit="epoch", leave=True) epoch_progress.set_description(f"maskNLLLoss: None") for epoch in epoch_progress: for x_train, y_train, x_lengths, y_mask in train_data: self.encoder_optimizer.zero_grad() self.decoder_optimizer.zero_grad() # Squeeze because batches are made in dataset and DataLoader is only for shuffling x_train = x_train.squeeze(0).to(device) y_train = y_train.squeeze(0).to(device) x_lengths = x_lengths.squeeze(0) # Lengths are computed on CPU y_mask = y_mask.squeeze(0).to(device) encoder_outputs, hidden = self.encoder(x_train, x_lengths) # Output shape: (batch_size, max_len_in_batch, hidden_size) hidden = hidden[:self.decoder_num_layers] loss = 0 decoder_input = torch.LongTensor([[BOS_TOKEN] for _ in range(y_train.shape[0])]) decoder_input = decoder_input.to(device) use_teacher_forcing = random.random() < teacher_forcing_ratio if use_teacher_forcing: for t in range(y_train.shape[1]): # Process words in all batches for timestep t decoder_outputs, hidden = self.decoder(decoder_input, hidden, encoder_outputs) decoder_input = y_train[:, t].unsqueeze(1) mask_loss = maskNLLLoss(decoder_outputs, y_train[:, t], y_mask[:, t]) loss += mask_loss else: for t in range(y_train.shape[1]): decoder_outputs, hidden = self.decoder(decoder_input, hidden, encoder_outputs) decoder_input = torch.argmax(decoder_outputs, dim=1).unsqueeze(1) mask_loss = maskNLLLoss(decoder_outputs, y_train[:, t], y_mask[:, t]) loss += mask_loss loss.backward() _ = nn.utils.clip_grad_norm_(self.encoder.parameters(), clip) _ = nn.utils.clip_grad_norm_(self.decoder.parameters(), clip) self.encoder_optimizer.step() self.decoder_optimizer.step() if (epoch % save_every == 0 and epoch != 0) or epoch == save_every - 1: directory = os.path.join(save_dir, model_name, '{}-{}'.format(self.encoder_num_layers, self.decoder_num_layers, self.hidden_size)) if not os.path.exists(directory): os.makedirs(directory) torch.save({ 'epoch': epoch + self.epoch, 'en': self.encoder.state_dict(), 'de': self.decoder.state_dict(), 'en_opt': self.encoder_optimizer.state_dict(), 'de_opt': self.decoder_optimizer.state_dict(), 'loss': loss, 'voc_dict': self.vocab.__dict__, 'embedding': self.embedding.state_dict() }, os.path.join(directory, '{}_{}.tar'.format(epoch, 'checkpoint'))) epoch_progress.set_description(f"maskNLLLoss: {loss:.8f}") def to(self, device): self.encoder = self.encoder.to(device) self.decoder = self.decoder.to(device) self.embedding = self.embedding.to(device) self.attn = self.attn.to(device) def train_mode(self): self.encoder.train() self.decoder.train() self.embedding.train() self.attn.train() def eval_mode(self): self.encoder.eval() self.decoder.eval() self.embedding.eval() self.attn.eval() def load_checkpoint(self, checkpoint_path: str): checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False) encoder_sd = checkpoint["en"] decoder_sd = checkpoint["de"] embedding_sd = checkpoint["embedding"] self.vocab.__dict__ = checkpoint["voc_dict"] encoder_optimizer_sd = checkpoint["en_opt"] decoder_optimizer_sd = checkpoint["de_opt"] self.epoch = checkpoint["epoch"] self.encoder_optimizer.load_state_dict(encoder_optimizer_sd) self.decoder_optimizer.load_state_dict(decoder_optimizer_sd) self.embedding.load_state_dict(embedding_sd) self.encoder.load_state_dict(encoder_sd) self.decoder.load_state_dict(decoder_sd) def forward(self, input_seq: str): input_seq = ChatDataset._ChatDataset__normalize(input_seq) input_seq = self.vocab.sentence_indices(input_seq + [""]).unsqueeze(0).to(self.device) output, _ = self.searcher(input_seq, torch.tensor(input_seq.shape[1]).view(1), 10) output = [self.vocab.index2word[i.item()] for i in output] output = [word for word in output if word not in ("", "", "")] return " ".join(output) if __name__ == "__main__": # Run as module from .chat_dataset import ChatDataset import torch.utils.data as data CHAT_HISTORY_PATH = "models/seq2seq/data/train/chat_history.json" batch_size = 20 chat_dataset = ChatDataset(CHAT_HISTORY_PATH, max_message_count=10_000, batch_size=batch_size) train_data = data.DataLoader(chat_dataset, batch_size=1, shuffle=True) device = torch.device("cpu") chatbot = Seq2SeqChatbot(500, chat_dataset.vocab.size, 2, 2, 0.1, device) chatbot.load_checkpoint("models/seq2seq/checkpoint/150_checkpoint.tar") chatbot.train_mode() chatbot.train(3, train_data, 0.5, device, "./checkpoint/temp/", "frantics_fox", 50.0, 100)