WonderWaffle4
GPT-2 bot with transformers and bot server
59c6d5c
import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
from .chat_dataset import ChatDataset
from .attention import LuongAttention
from .custom_types import Method
from .constants import BOS_TOKEN
from .vocab import Vocab
from .searchers import GreedySearch
import os
import random
from tqdm import tqdm
class Seq2SeqEncoder(nn.Module):
    """Bidirectional GRU encoder over a token embedding supplied by the caller.

    The forward and backward directions of the GRU are summed, so every
    returned output vector has width ``hidden_size`` (one direction's width).
    """

    def __init__(self, input_size: int, hidden_size: int, num_layers: int, embedding: nn.Embedding):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        # Passed in rather than created here so the same module can be shared.
        self.embedding = embedding
        # All tensors in this model are batch-major, hence batch_first=True.
        self.rnn = nn.GRU(
            input_size,
            hidden_size,
            num_layers=num_layers,
            bidirectional=True,
            batch_first=True,
        )

    def forward(self, x, lengths):
        """Encode a padded, batch-major batch of token ids.

        Args:
            x: token ids laid out as (batch, max_len); fed to ``self.embedding``.
            lengths: true length of each row; copied to CPU because
                ``pack_padded_sequence`` needs CPU lengths.

        Returns:
            ``(outputs, hidden)``. ``outputs`` is (batch, longest_length,
            hidden_size) with both directions summed, and padded positions are
            zero. ``hidden`` is the GRU's final state,
            (num_layers * 2, batch, hidden_size).
        """
        embedded = self.embedding(x)
        packed = pack_padded_sequence(embedded, lengths.cpu(), batch_first=True, enforce_sorted=False)
        packed_out, final_hidden = self.rnn(packed)
        padded_out, _ = pad_packed_sequence(packed_out, batch_first=True)
        # The last axis is [forward | backward]; fold the two halves together.
        forward_half, backward_half = padded_out.split(self.hidden_size, dim=2)
        return forward_half + backward_half, final_hidden
class Seq2SeqDecoder(nn.Module):
    """One-step GRU decoder with attention over encoder outputs.

    Each call consumes one token per batch row and returns a probability
    distribution over the output vocabulary together with the new GRU state.

    Args:
        input_size: GRU input width; must equal the embedding dimension,
            since embedded tokens are fed straight into the GRU.
        hidden_size: GRU hidden width.
        output_size: size of the output distribution (the vocabulary size).
        num_layers: number of stacked GRU layers.
        attn: callable ``attn(decoder_outputs, encoder_outputs)`` returning
            weights that can be ``bmm``-ed with ``encoder_outputs``.
        embedding: token embedding module (shared with the encoder by
            ``Seq2SeqChatbot``).
        dropout: dropout probability applied to the embedded input token.
    """

    def __init__(self, input_size: int, hidden_size: int, output_size: int, num_layers: int, attn, embedding: nn.Embedding, dropout: float = 0.1):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.num_layers = num_layers
        self.attn = attn
        self.embedding = embedding
        self.embedding_dropout = nn.Dropout(dropout)
        self.rnn = nn.GRU(input_size, hidden_size, num_layers=num_layers, batch_first=True)
        # Projects [rnn_output; context] back down to hidden_size.
        self.concat = nn.Linear(hidden_size * 2, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)

    def forward(self, x, last_hidden, encoder_outputs):
        """Decode a single time step.

        Args:
            x: token ids, assumed (batch, 1); the ``squeeze(1)`` calls below
                rely on a length-1 step.
            last_hidden: previous GRU hidden state.
            encoder_outputs: batch-major encoder outputs, presumably
                (batch, src_len, hidden_size) as produced by the encoder.

        Returns:
            ``(probs, hidden)``: softmax probabilities of shape
            (batch, output_size) and the updated GRU hidden state.
        """
        embedded = self.embedding(x)
        embedded = self.embedding_dropout(embedded)
        decoder_outputs, hidden = self.rnn(embedded, last_hidden)
        attn_weights = self.attn(decoder_outputs, encoder_outputs)
        # Attention-weighted sum of encoder outputs: one context vector per row.
        context = attn_weights.bmm(encoder_outputs).squeeze(1)
        # Luong-style attentional state: tanh(W_c [h_t; c_t]).
        concat_input = torch.cat((decoder_outputs.squeeze(1), context), 1)
        concat_output = torch.tanh(self.concat(concat_input))
        output = self.out(concat_output)
        # Probabilities, not logits: the chatbot's training loss takes log() of these.
        output = F.softmax(output, dim=1)
        return output, hidden
class Seq2SeqChatbot(nn.Module):
    """Attention-based GRU sequence-to-sequence chatbot.

    A bidirectional encoder and an attention decoder share one token
    embedding. Each half has its own Adam optimizer. ``forward`` answers a raw
    text message using the greedy searcher.
    """

    def __init__(self, hidden_size: int, vocab_size: int, encoder_num_layers: int, decoder_num_layers: int, decoder_embedding_dropout: float, device: torch.device):
        super().__init__()
        self.hidden_size = hidden_size
        self.encoder_num_layers = encoder_num_layers
        self.decoder_num_layers = decoder_num_layers
        self.decoder_embedding_dropout = decoder_embedding_dropout
        self.vocab_size = vocab_size
        # Index of the next epoch to run; advanced by train(), restored by load_checkpoint().
        self.epoch = 0
        self.device = device
        # Starts empty; load_checkpoint() replaces its contents with the saved vocab.
        self.vocab = Vocab([])
        # One embedding shared by the encoder, the decoder and the searcher.
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        self.attn = LuongAttention(Method.DOT, hidden_size)
        self.encoder = Seq2SeqEncoder(hidden_size, hidden_size, encoder_num_layers, self.embedding)
        self.decoder = Seq2SeqDecoder(hidden_size, hidden_size, vocab_size, decoder_num_layers, self.attn, self.embedding, decoder_embedding_dropout)
        # NOTE(review): the shared embedding is in both parameter lists, so each
        # batch it is clipped and stepped by both optimizers. It is left as is,
        # because regrouping the parameters would make saved optimizer state unloadable.
        self.encoder_optimizer = optim.Adam(self.encoder.parameters())
        self.decoder_optimizer = optim.Adam(self.decoder.parameters())
        self.searcher = GreedySearch(self.encoder, self.decoder, self.embedding, device)
        self.to(device)
        self.eval_mode()

    def train(self, epochs=True, train_data=None, teacher_forcing_ratio=None, device=None,
              save_dir=None, model_name=None, clip=None, save_every=None):
        """Train for ``epochs`` epochs, or set the nn.Module training flag.

        This shadows ``nn.Module.train(mode)``, and ``nn.Module.eval()`` calls
        ``self.train(False)``. So when called with only a boolean and no
        ``train_data``, this behaves like ``nn.Module.train`` and returns ``self``.

        Args:
            epochs: number of epochs to run, numbered from ``self.epoch``.
            train_data: iterable of ``(x, y, x_lengths, y_mask)`` batches, each
                with a leading dimension of 1. The dataset builds the batches
                and the DataLoader only shuffles them.
            teacher_forcing_ratio: probability that a batch is decoded by
                feeding the gold token (instead of the model's argmax) at every step.
            device: device the batch tensors are moved to.
            save_dir: root directory for checkpoints.
            model_name: sub-directory name for this model's checkpoints.
            clip: max gradient norm, applied to encoder and decoder separately.
            save_every: checkpoint period in epochs.

        Raises:
            TypeError: if a training argument is missing.
        """
        if train_data is None and isinstance(epochs, bool):
            return super().train(epochs)
        required = {
            "train_data": train_data,
            "teacher_forcing_ratio": teacher_forcing_ratio,
            "device": device,
            "save_dir": save_dir,
            "model_name": model_name,
            "clip": clip,
            "save_every": save_every,
        }
        missing = [name for name, value in required.items() if value is None]
        if missing:
            raise TypeError(f"train() missing required arguments: {', '.join(missing)}")

        # Dropout must be active while training; go back to inference mode afterwards.
        self.train_mode()
        last_loss = None  # stays None if train_data yields no batches
        epoch_progress = tqdm(range(self.epoch, self.epoch + epochs), desc="Training", unit="epoch", leave=True)
        epoch_progress.set_description("maskNLLLoss: None")
        try:
            for epoch in epoch_progress:
                for batch in train_data:
                    last_loss = self._train_batch(batch, teacher_forcing_ratio, device, clip)
                self.epoch = epoch + 1
                # NOTE: saves at epoch save_every - 1 and at every non-zero multiple
                # of save_every; the final epoch is not saved unless it matches.
                if (epoch % save_every == 0 and epoch != 0) or epoch == save_every - 1:
                    self._save_checkpoint(epoch, last_loss, save_dir, model_name)
                if last_loss is not None:
                    epoch_progress.set_description(f"maskNLLLoss: {last_loss:.8f}")
        finally:
            self.eval_mode()

    @staticmethod
    def _mask_nll_loss(probs, target, mask):
        """Mean negative log-likelihood of ``target`` under ``probs``, over unmasked rows.

        ``probs`` are the decoder's softmax outputs (batch, vocab). Probabilities
        are clamped before the log, so a zero probability gives a large finite
        loss instead of ``inf``. A fully masked step contributes 0 rather than
        the NaN that ``mean()`` of an empty selection would produce.
        """
        mask = mask.bool()
        if not mask.any():
            return probs.new_zeros(())
        picked = probs.gather(1, target.unsqueeze(1)).squeeze(1)
        nll = -torch.log(picked.clamp_min(1e-12))
        return nll.masked_select(mask).mean()

    def _train_batch(self, batch, teacher_forcing_ratio, device, clip):
        """Run one optimisation step on a single pre-built batch and return its loss as a float."""
        x_train, y_train, x_lengths, y_mask = batch
        self.encoder_optimizer.zero_grad()
        self.decoder_optimizer.zero_grad()
        # Squeeze because batches are made in the dataset and the DataLoader only shuffles them.
        x_train = x_train.squeeze(0).to(device)
        y_train = y_train.squeeze(0).to(device)
        x_lengths = x_lengths.squeeze(0)  # lengths stay on CPU for pack_padded_sequence
        y_mask = y_mask.squeeze(0).to(device)

        encoder_outputs, hidden = self.encoder(x_train, x_lengths)
        # NOTE(review): the encoder state is (encoder_num_layers * 2, batch, hidden),
        # ordered layer-major / direction-minor (PyTorch GRU layout). This slice
        # therefore takes the first decoder_num_layers direction-states, not the
        # top layers. Kept to stay compatible with trained checkpoints.
        hidden = hidden[:self.decoder_num_layers]

        batch_size, target_len = y_train.shape
        decoder_input = torch.full((batch_size, 1), BOS_TOKEN, dtype=torch.long, device=device)
        # One coin flip per batch decides the feeding strategy for every step.
        use_teacher_forcing = random.random() < teacher_forcing_ratio
        loss = 0
        for t in range(target_len):  # process all rows of the batch at timestep t
            decoder_outputs, hidden = self.decoder(decoder_input, hidden, encoder_outputs)
            loss = loss + self._mask_nll_loss(decoder_outputs, y_train[:, t], y_mask[:, t])
            if use_teacher_forcing:
                decoder_input = y_train[:, t].unsqueeze(1)
            else:
                decoder_input = decoder_outputs.argmax(dim=1).unsqueeze(1)

        loss.backward()
        nn.utils.clip_grad_norm_(self.encoder.parameters(), clip)
        nn.utils.clip_grad_norm_(self.decoder.parameters(), clip)
        self.encoder_optimizer.step()
        self.decoder_optimizer.step()
        return loss.item()

    def _save_checkpoint(self, epoch, loss, save_dir, model_name):
        """Write a resumable checkpoint for ``epoch`` under save_dir/model_name/<enc>-<dec>_<hidden>/."""
        directory = os.path.join(
            save_dir,
            model_name,
            f"{self.encoder_num_layers}-{self.decoder_num_layers}_{self.hidden_size}",
        )
        os.makedirs(directory, exist_ok=True)
        torch.save({
            # Next epoch to run: load_checkpoint() resumes right after this one.
            'epoch': epoch + 1,
            'en': self.encoder.state_dict(),
            'de': self.decoder.state_dict(),
            'en_opt': self.encoder_optimizer.state_dict(),
            'de_opt': self.decoder_optimizer.state_dict(),
            'loss': loss,  # plain float of the last batch's loss (or None)
            'voc_dict': self.vocab.__dict__,
            'embedding': self.embedding.state_dict()
        }, os.path.join(directory, f"{epoch}_checkpoint.tar"))

    def to(self, device):
        """Move every registered sub-module to ``device`` and return ``self``, like ``nn.Module.to``.

        NOTE: ``self.device`` and the searcher's device are not updated.
        """
        return super().to(device)

    def train_mode(self):
        """Put encoder, decoder, embedding and attention into training mode."""
        self.encoder.train()
        self.decoder.train()
        self.embedding.train()
        self.attn.train()

    def eval_mode(self):
        """Put encoder, decoder, embedding and attention into evaluation mode."""
        self.encoder.eval()
        self.decoder.eval()
        self.embedding.eval()
        self.attn.eval()

    def load_checkpoint(self, checkpoint_path: str):
        """Restore weights, optimizers, vocab and epoch counter from a checkpoint file."""
        # SECURITY NOTE(review): weights_only=False unpickles arbitrary objects.
        # Only load checkpoints from trusted sources.
        checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False)
        encoder_sd = checkpoint["en"]
        decoder_sd = checkpoint["de"]
        embedding_sd = checkpoint["embedding"]
        self.vocab.__dict__ = checkpoint["voc_dict"]
        encoder_optimizer_sd = checkpoint["en_opt"]
        decoder_optimizer_sd = checkpoint["de_opt"]
        self.epoch = checkpoint["epoch"]
        self.encoder_optimizer.load_state_dict(encoder_optimizer_sd)
        self.decoder_optimizer.load_state_dict(decoder_optimizer_sd)
        self.embedding.load_state_dict(embedding_sd)
        self.encoder.load_state_dict(encoder_sd)
        self.decoder.load_state_dict(decoder_sd)

    def forward(self, input_seq: str, max_length: int = 10):
        """Generate a reply to a raw text message with greedy decoding.

        Args:
            input_seq: the user's message.
            max_length: forwarded as the searcher's third argument (presumably
                the maximum reply length — confirm against GreedySearch).

        Returns:
            Decoded words joined by spaces, with <bos>/<eos>/<pad> removed.
        """
        words = ChatDataset._ChatDataset__normalize(input_seq)
        tokens = self.vocab.sentence_indices(words + ["<eos>"]).unsqueeze(0).to(self.device)
        lengths = torch.tensor([tokens.shape[1]])
        with torch.no_grad():  # inference only; no autograd graph needed
            output, _ = self.searcher(tokens, lengths, max_length)
        special_tokens = {"<bos>", "<eos>", "<pad>"}
        decoded = [self.vocab.index2word[i.item()] for i in output]
        return " ".join(word for word in decoded if word not in special_tokens)
if __name__ == "__main__":  # Run as a module (python -m ...) so the relative imports resolve
    import torch.utils.data as data

    CHAT_HISTORY_PATH = "models/seq2seq/data/train/chat_history.json"
    CHECKPOINT_PATH = "models/seq2seq/checkpoint/150_checkpoint.tar"
    batch_size = 20

    # The dataset assembles whole batches itself; the DataLoader (batch_size=1) only shuffles them.
    chat_dataset = ChatDataset(CHAT_HISTORY_PATH, max_message_count=10_000, batch_size=batch_size)
    train_data = data.DataLoader(chat_dataset, batch_size=1, shuffle=True)

    device = torch.device("cpu")
    chatbot = Seq2SeqChatbot(
        hidden_size=500,
        vocab_size=chat_dataset.vocab.size,
        encoder_num_layers=2,
        decoder_num_layers=2,
        decoder_embedding_dropout=0.1,
        device=device,
    )
    chatbot.load_checkpoint(CHECKPOINT_PATH)
    chatbot.train_mode()
    chatbot.train(
        epochs=3,
        train_data=train_data,
        teacher_forcing_ratio=0.5,
        device=device,
        save_dir="./checkpoint/temp/",
        model_name="frantics_fox",
        clip=50.0,
        save_every=100,
    )