Spaces:
Running
Running
import torch | |
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence | |
import torch.nn.functional as F | |
import torch.nn as nn | |
import torch.optim as optim | |
from .chat_dataset import ChatDataset | |
from .attention import LuongAttention | |
from .custom_types import Method | |
from .constants import BOS_TOKEN | |
from .vocab import Vocab | |
from .searchers import GreedySearch | |
import os | |
import random | |
from tqdm import tqdm | |
class Seq2SeqEncoder(nn.Module):
    """Bidirectional GRU encoder for padded, batch-first token sequences.

    The outputs of the forward and backward directions are summed, so the
    per-timestep features returned by ``forward`` have ``hidden_size``
    channels instead of ``2 * hidden_size``.
    """

    def __init__(self, input_size: int, hidden_size: int, num_layers: int, embedding: nn.Embedding):
        """Store the configuration and build the bidirectional GRU.

        Args:
            input_size: Feature size fed to the GRU (the embedding width).
            hidden_size: Hidden size of each GRU direction.
            num_layers: Number of stacked GRU layers.
            embedding: Embedding module applied to the input indices.
        """
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.embedding = embedding
        # Batch-first layout is used throughout this module.
        self.rnn = nn.GRU(
            input_size,
            hidden_size,
            num_layers=num_layers,
            bidirectional=True,
            batch_first=True,
        )

    def forward(self, x, lengths):
        """Encode a padded batch.

        Args:
            x: Batch-first tensor of token indices (passed to the embedding).
            lengths: Tensor with the true length of every sequence in ``x``;
                it is moved to the CPU before packing.

        Returns:
            A ``(summed_outputs, hidden)`` tuple: the direction-summed GRU
            outputs (batch first, padded to the longest sequence in the
            batch) and the GRU's final hidden state.
        """
        embedded = self.embedding(x)
        # Packing lets the GRU skip padded positions; the batch does not
        # have to be sorted by length.
        packed_input = pack_padded_sequence(
            embedded, lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        packed_output, hidden = self.rnn(packed_input)
        padded_output, _ = pad_packed_sequence(packed_output, batch_first=True)

        # The last dimension holds [forward | backward] features; add them.
        forward_part = padded_output[:, :, :self.hidden_size]
        backward_part = padded_output[:, :, self.hidden_size:]
        return forward_part + backward_part, hidden
class Seq2SeqDecoder(nn.Module):
    """Unidirectional GRU decoder that attends over the encoder outputs.

    Each call to ``forward`` runs the GRU on the given input tokens, combines
    the GRU output with an attention-weighted context vector and returns a
    probability distribution over the output vocabulary.
    """

    def __init__(self, input_size: int, hidden_size: int, output_size: int, num_layers: int, attn, embedding: nn.Embedding, dropout: float = 0.1):
        """Build the decoder layers.

        Args:
            input_size: Feature size fed to the GRU (the embedding width).
            hidden_size: Hidden size of the GRU and of the attention output.
            output_size: Size of the final projection (number of classes).
            num_layers: Number of stacked GRU layers.
            attn: Callable invoked as ``attn(decoder_outputs, encoder_outputs)``;
                its result is batch-multiplied with ``encoder_outputs``.
            embedding: Embedding module applied to the input indices.
            dropout: Dropout probability applied to the embedded input.
        """
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.num_layers = num_layers
        self.attn = attn
        self.embedding = embedding
        self.embedding_dropout = nn.Dropout(dropout)
        self.rnn = nn.GRU(input_size, hidden_size, num_layers=num_layers, batch_first=True)
        # Projects [GRU output ; context] back down to hidden_size.
        self.concat = nn.Linear(hidden_size * 2, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)

    def forward(self, x, last_hidden, encoder_outputs):
        """Run one decoding step.

        Args:
            x: Batch-first tensor of input token indices. The ``squeeze(1)``
                calls below assume a single timestep per call -- TODO confirm
                against callers.
            last_hidden: Previous GRU hidden state.
            encoder_outputs: Batch-first encoder outputs to attend over.

        Returns:
            A ``(output, hidden)`` tuple: softmax probabilities taken over
            dimension 1 of the projected output, and the new GRU hidden state.
        """
        embedded = self.embedding_dropout(self.embedding(x))
        decoder_outputs, hidden = self.rnn(embedded, last_hidden)

        # Weighted sum of the encoder outputs using the attention weights.
        attn_weights = self.attn(decoder_outputs, encoder_outputs)
        context = attn_weights.bmm(encoder_outputs).squeeze(1)

        concat_input = torch.cat((decoder_outputs.squeeze(1), context), 1)
        concat_output = torch.tanh(self.concat(concat_input))
        output = self.out(concat_output)
        # Probabilities (not logits) are returned.
        output = F.softmax(output, dim=1)
        return output, hidden
class Seq2SeqChatbot(nn.Module):
    """Sequence-to-sequence chatbot: encoder, attention decoder and searcher.

    The class owns the shared embedding, the encoder/decoder pair, one Adam
    optimizer for each of them, the vocabulary and the search module used
    for inference.

    NOTE: ``train`` and ``to`` intentionally keep their existing signatures,
    which differ from ``nn.Module.train(mode)`` / ``nn.Module.to(...)``.
    Because ``nn.Module.eval()`` is implemented via ``self.train(False)``,
    use ``train_mode()`` / ``eval_mode()`` to switch modes on this object.

    NOTE: the same embedding module is handed to both the encoder and the
    decoder, so its parameters are part of both optimizers.
    """

    def __init__(self, hidden_size: int, vocab_size: int, encoder_num_layers: int, decoder_num_layers: int, decoder_embedding_dropout: float, device: torch.device):
        """Build all sub-modules, move them to ``device`` and enter eval mode.

        Args:
            hidden_size: Embedding width and GRU hidden size.
            vocab_size: Number of embedding rows and decoder output classes.
            encoder_num_layers: Number of encoder GRU layers.
            decoder_num_layers: Number of decoder GRU layers.
            decoder_embedding_dropout: Dropout probability on decoder embeddings.
            device: Device the sub-modules are moved to.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.encoder_num_layers = encoder_num_layers
        self.decoder_num_layers = decoder_num_layers
        self.decoder_embedding_dropout = decoder_embedding_dropout
        self.vocab_size = vocab_size
        self.epoch = 0
        self.device = device
        # Empty vocabulary; load_checkpoint() replaces its contents.
        self.vocab = Vocab([])
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        self.attn = LuongAttention(Method.DOT, hidden_size)
        self.encoder = Seq2SeqEncoder(hidden_size, hidden_size, encoder_num_layers, self.embedding)
        self.decoder = Seq2SeqDecoder(hidden_size, hidden_size, vocab_size, decoder_num_layers, self.attn, self.embedding, decoder_embedding_dropout)
        self.encoder_optimizer = optim.Adam(self.encoder.parameters())
        self.decoder_optimizer = optim.Adam(self.decoder.parameters())
        self.searcher = GreedySearch(self.encoder, self.decoder, self.embedding, device)
        self.to(device)
        self.eval_mode()

    def _checkpoint_dir(self, save_dir, model_name):
        """Return the directory checkpoints of this configuration are saved in."""
        # The original format string had two placeholders for three values,
        # so hidden_size was silently dropped from the directory name.
        return os.path.join(
            save_dir,
            model_name,
            '{}-{}_{}'.format(self.encoder_num_layers, self.decoder_num_layers, self.hidden_size),
        )

    def train(self, epochs, train_data, teacher_forcing_ratio, device, save_dir, model_name, clip, save_every):
        """Run the training loop and periodically write checkpoints.

        Args:
            epochs: Number of epochs to run, counted from ``self.epoch``.
            train_data: Iterable yielding ``(x, y, x_lengths, y_mask)``
                batches; a leading dimension of size 1 is squeezed away from
                each element.
            teacher_forcing_ratio: Probability, drawn once per batch, of
                feeding the ground-truth token as the next decoder input
                instead of the decoder's own argmax prediction.
            device: Device the batch tensors and the loss are moved to.
            save_dir: Root directory for checkpoints.
            model_name: Sub-directory of ``save_dir`` for this model.
            clip: Maximum gradient norm, applied to encoder and decoder.
            save_every: Checkpoint period in epochs (see the save condition).
        """
        def maskNLLLoss(inp, target, mask):
            # Negative log-probability of the target token, averaged over the
            # positions selected by the mask.
            crossEntropy = -torch.log(torch.gather(inp, 1, target.view(-1, 1)).squeeze(1))
            loss = crossEntropy.masked_select(mask).mean()
            loss = loss.to(device)
            return loss

        epoch_progress = tqdm(range(self.epoch, self.epoch + epochs), desc="Training", unit="epoch", leave=True)
        epoch_progress.set_description("maskNLLLoss: None")
        for epoch in epoch_progress:
            for x_train, y_train, x_lengths, y_mask in train_data:
                self.encoder_optimizer.zero_grad()
                self.decoder_optimizer.zero_grad()

                # Squeeze because batches are made in dataset and DataLoader is only for shuffling
                x_train = x_train.squeeze(0).to(device)
                y_train = y_train.squeeze(0).to(device)
                x_lengths = x_lengths.squeeze(0)  # Lengths are computed on CPU
                y_mask = y_mask.squeeze(0).to(device)

                encoder_outputs, hidden = self.encoder(x_train, x_lengths)
                # Seed the decoder with the first decoder_num_layers entries
                # of the encoder's final hidden state.
                hidden = hidden[:self.decoder_num_layers]

                loss = 0
                decoder_input = torch.LongTensor([[BOS_TOKEN] for _ in range(y_train.shape[0])])
                decoder_input = decoder_input.to(device)

                use_teacher_forcing = random.random() < teacher_forcing_ratio
                for t in range(y_train.shape[1]):  # One target timestep across the whole batch
                    decoder_outputs, hidden = self.decoder(decoder_input, hidden, encoder_outputs)
                    if use_teacher_forcing:
                        decoder_input = y_train[:, t].unsqueeze(1)
                    else:
                        decoder_input = torch.argmax(decoder_outputs, dim=1).unsqueeze(1)
                    loss += maskNLLLoss(decoder_outputs, y_train[:, t], y_mask[:, t])

                loss.backward()
                _ = nn.utils.clip_grad_norm_(self.encoder.parameters(), clip)
                _ = nn.utils.clip_grad_norm_(self.decoder.parameters(), clip)
                self.encoder_optimizer.step()
                self.decoder_optimizer.step()

            # Save on non-zero multiples of save_every, and at epoch save_every - 1.
            if (epoch % save_every == 0 and epoch != 0) or epoch == save_every - 1:
                directory = self._checkpoint_dir(save_dir, model_name)
                os.makedirs(directory, exist_ok=True)
                torch.save({
                    # `epoch` already starts at self.epoch, so it is stored
                    # as-is; adding self.epoch again would double-count.
                    'epoch': epoch,
                    'en': self.encoder.state_dict(),
                    'de': self.decoder.state_dict(),
                    'en_opt': self.encoder_optimizer.state_dict(),
                    'de_opt': self.decoder_optimizer.state_dict(),
                    'loss': loss,
                    'voc_dict': self.vocab.__dict__,
                    'embedding': self.embedding.state_dict()
                }, os.path.join(directory, '{}_{}.tar'.format(epoch, 'checkpoint')))
            # Shows the loss of the last batch of the epoch.
            epoch_progress.set_description(f"maskNLLLoss: {loss:.8f}")

    def to(self, device):
        """Move encoder, decoder, embedding and attention to ``device``.

        Returns:
            ``self``, matching the ``nn.Module.to`` convention so calls can
            be chained.
        """
        self.encoder = self.encoder.to(device)
        self.decoder = self.decoder.to(device)
        self.embedding = self.embedding.to(device)
        self.attn = self.attn.to(device)
        return self

    def train_mode(self):
        """Put every sub-module into training mode."""
        self.encoder.train()
        self.decoder.train()
        self.embedding.train()
        self.attn.train()

    def eval_mode(self):
        """Put every sub-module into evaluation mode."""
        self.encoder.eval()
        self.decoder.eval()
        self.embedding.eval()
        self.attn.eval()

    def load_checkpoint(self, checkpoint_path: str):
        """Restore weights, optimizer states, vocabulary and epoch from a file.

        Args:
            checkpoint_path: Path of a checkpoint written by ``train``.
        """
        # SECURITY: weights_only=False unpickles arbitrary objects (needed
        # for the stored vocabulary dict). Only load trusted checkpoints.
        checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False)
        encoder_sd = checkpoint["en"]
        decoder_sd = checkpoint["de"]
        embedding_sd = checkpoint["embedding"]
        self.vocab.__dict__ = checkpoint["voc_dict"]
        encoder_optimizer_sd = checkpoint["en_opt"]
        decoder_optimizer_sd = checkpoint["de_opt"]
        self.epoch = checkpoint["epoch"]
        self.encoder_optimizer.load_state_dict(encoder_optimizer_sd)
        self.decoder_optimizer.load_state_dict(decoder_optimizer_sd)
        self.embedding.load_state_dict(embedding_sd)
        self.encoder.load_state_dict(encoder_sd)
        self.decoder.load_state_dict(decoder_sd)

    def forward(self, input_seq: str):
        """Generate a reply to ``input_seq``.

        Args:
            input_seq: Raw user message.

        Returns:
            The generated words joined by single spaces, with the
            ``<bos>``, ``<eos>`` and ``<pad>`` tokens removed.
        """
        # Explicitly name-mangled access to ChatDataset's private normalizer.
        input_seq = ChatDataset._ChatDataset__normalize(input_seq)
        input_seq = self.vocab.sentence_indices(input_seq + ["<eos>"]).unsqueeze(0).to(self.device)
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            # The third argument is presumably the maximum reply length -- confirm against GreedySearch.
            output, _ = self.searcher(input_seq, torch.tensor(input_seq.shape[1]).view(1), 10)
        output = [self.vocab.index2word[i.item()] for i in output]
        output = [word for word in output if word not in ("<bos>", "<eos>", "<pad>")]
        return " ".join(output)
if __name__ == "__main__":  # Run as module
    # Script entry point: resume from a saved checkpoint and train a few
    # more epochs on the chat history.
    import torch.utils.data as data

    from .chat_dataset import ChatDataset

    CHAT_HISTORY_PATH = "models/seq2seq/data/train/chat_history.json"
    batch_size = 20

    # The dataset builds the batches itself; the DataLoader (batch_size=1)
    # is only used to shuffle them.
    chat_dataset = ChatDataset(CHAT_HISTORY_PATH, max_message_count=10_000, batch_size=batch_size)
    train_data = data.DataLoader(chat_dataset, batch_size=1, shuffle=True)

    device = torch.device("cpu")
    chatbot = Seq2SeqChatbot(
        hidden_size=500,
        vocab_size=chat_dataset.vocab.size,
        encoder_num_layers=2,
        decoder_num_layers=2,
        decoder_embedding_dropout=0.1,
        device=device,
    )
    chatbot.load_checkpoint("models/seq2seq/checkpoint/150_checkpoint.tar")

    # The constructor leaves the model in eval mode; switch before training.
    chatbot.train_mode()
    chatbot.train(
        epochs=3,
        train_data=train_data,
        teacher_forcing_ratio=0.5,
        device=device,
        save_dir="./checkpoint/temp/",
        model_name="frantics_fox",
        clip=50.0,
        save_every=100,
    )