DHRUV SHEKHAWAT committed
Commit 1dd09ef
Parent(s): 52f9f0f
Upload 2 files
models.py
ADDED
@@ -0,0 +1,162 @@
import torch
import torch.nn as nn
import math
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class Embeddings(nn.Module):
    """
    Implements embeddings of the words and adds their positional encodings.
    """
    def __init__(self, vocab_size, d_model, max_len = 50):
        super(Embeddings, self).__init__()
        self.d_model = d_model
        self.embed = nn.Embedding(vocab_size, d_model)
        self.pe = self.create_positional_encoding(max_len, self.d_model)
        self.dropout = nn.Dropout(0.1)

    def create_positional_encoding(self, max_len, d_model):
        pe = torch.zeros(max_len, d_model).to(device)
        for pos in range(max_len):            # for each position of the word
            for i in range(0, d_model, 2):    # for each dimension of the position
                pe[pos, i] = math.sin(pos / (10000 ** ((2 * i) / d_model)))
                pe[pos, i + 1] = math.cos(pos / (10000 ** ((2 * (i + 1)) / d_model)))
        pe = pe.unsqueeze(0)                  # include the batch size
        return pe

    def forward(self, encoded_words):
        embedding = self.embed(encoded_words) * math.sqrt(self.d_model)
        embedding += self.pe[:, :embedding.size(1)]   # pe is broadcast across the batch dimension of encoded_words
        embedding = self.dropout(embedding)
        return embedding


class MultiHeadAttention(nn.Module):

    def __init__(self, heads, d_model):
        super(MultiHeadAttention, self).__init__()
        assert d_model % heads == 0
        self.d_k = d_model // heads
        self.heads = heads
        self.dropout = nn.Dropout(0.1)
        self.query = nn.Linear(d_model, d_model)
        self.key = nn.Linear(d_model, d_model)
        self.value = nn.Linear(d_model, d_model)
        self.concat = nn.Linear(d_model, d_model)

    def forward(self, query, key, value, mask):
        """
        query, key, value of shape: (batch_size, max_len, 512)
        mask of shape: (batch_size, 1, 1, max_words)
        """
        # (batch_size, max_len, 512)
        query = self.query(query)
        key = self.key(key)
        value = self.value(value)

        # (batch_size, max_len, 512) --> (batch_size, max_len, h, d_k) --> (batch_size, h, max_len, d_k)
        query = query.view(query.shape[0], -1, self.heads, self.d_k).permute(0, 2, 1, 3)
        key = key.view(key.shape[0], -1, self.heads, self.d_k).permute(0, 2, 1, 3)
        value = value.view(value.shape[0], -1, self.heads, self.d_k).permute(0, 2, 1, 3)

        # (batch_size, h, max_len, d_k) matmul (batch_size, h, d_k, max_len) --> (batch_size, h, max_len, max_len)
        scores = torch.matmul(query, key.permute(0, 1, 3, 2)) / math.sqrt(query.size(-1))
        scores = scores.masked_fill(mask == 0, -1e9)   # (batch_size, h, max_len, max_len)
        weights = F.softmax(scores, dim = -1)          # (batch_size, h, max_len, max_len)
        weights = self.dropout(weights)
        # (batch_size, h, max_len, max_len) matmul (batch_size, h, max_len, d_k) --> (batch_size, h, max_len, d_k)
        context = torch.matmul(weights, value)
        # (batch_size, h, max_len, d_k) --> (batch_size, max_len, h, d_k) --> (batch_size, max_len, h * d_k)
        context = context.permute(0, 2, 1, 3).contiguous().view(context.shape[0], -1, self.heads * self.d_k)
        # (batch_size, max_len, h * d_k)
        interacted = self.concat(context)
        return interacted


class FeedForward(nn.Module):

    def __init__(self, d_model, middle_dim = 2048):
        super(FeedForward, self).__init__()
        self.fc1 = nn.Linear(d_model, middle_dim)
        self.fc2 = nn.Linear(middle_dim, d_model)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        out = F.relu(self.fc1(x))
        out = self.fc2(self.dropout(out))
        return out


class EncoderLayer(nn.Module):

    def __init__(self, d_model, heads):
        super(EncoderLayer, self).__init__()
        self.layernorm = nn.LayerNorm(d_model)
        self.self_multihead = MultiHeadAttention(heads, d_model)
        self.feed_forward = FeedForward(d_model)
        self.dropout = nn.Dropout(0.1)

    def forward(self, embeddings, mask):
        interacted = self.dropout(self.self_multihead(embeddings, embeddings, embeddings, mask))
        interacted = self.layernorm(interacted + embeddings)
        feed_forward_out = self.dropout(self.feed_forward(interacted))
        encoded = self.layernorm(feed_forward_out + interacted)
        return encoded


class DecoderLayer(nn.Module):

    def __init__(self, d_model, heads):
        super(DecoderLayer, self).__init__()
        self.layernorm = nn.LayerNorm(d_model)
        self.self_multihead = MultiHeadAttention(heads, d_model)
        self.src_multihead = MultiHeadAttention(heads, d_model)
        self.feed_forward = FeedForward(d_model)
        self.dropout = nn.Dropout(0.1)

    def forward(self, embeddings, encoded, src_mask, target_mask):
        query = self.dropout(self.self_multihead(embeddings, embeddings, embeddings, target_mask))
        query = self.layernorm(query + embeddings)
        interacted = self.dropout(self.src_multihead(query, encoded, encoded, src_mask))
        interacted = self.layernorm(interacted + query)
        feed_forward_out = self.dropout(self.feed_forward(interacted))
        decoded = self.layernorm(feed_forward_out + interacted)
        return decoded


class Transformer(nn.Module):

    def __init__(self, d_model, heads, num_layers, word_map):
        super(Transformer, self).__init__()
        self.d_model = d_model
        self.vocab_size = len(word_map)
        self.embed = Embeddings(self.vocab_size, d_model)
        self.encoder = nn.ModuleList([EncoderLayer(d_model, heads) for _ in range(num_layers)])
        self.decoder = nn.ModuleList([DecoderLayer(d_model, heads) for _ in range(num_layers)])
        self.logit = nn.Linear(d_model, self.vocab_size)

    def encode(self, src_words, src_mask):
        src_embeddings = self.embed(src_words)
        for layer in self.encoder:
            src_embeddings = layer(src_embeddings, src_mask)
        return src_embeddings

    def decode(self, target_words, target_mask, src_embeddings, src_mask):
        tgt_embeddings = self.embed(target_words)
        for layer in self.decoder:
            tgt_embeddings = layer(tgt_embeddings, src_embeddings, src_mask, target_mask)
        return tgt_embeddings

    def forward(self, src_words, src_mask, target_words, target_mask):
        encoded = self.encode(src_words, src_mask)
        decoded = self.decode(target_words, target_mask, encoded, src_mask)
        out = F.log_softmax(self.logit(decoded), dim = 2)
        return out
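Below is a short usage sketch (not part of the uploaded files) showing how the pieces of models.py fit together for a single forward pass. It builds masks with create_masks from the utils.py added in this same commit. The toy word_map, token ids, and the hyperparameters (d_model=512, heads=8, num_layers=3) are illustrative assumptions, not values fixed by the commit.

# Usage sketch (illustrative only).
import torch

from models import Transformer
from utils import create_masks

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical toy vocabulary; the real word_map comes from the preprocessing step.
word_map = {'<pad>': 0, '<start>': 1, '<end>': 2, 'hello': 3, 'there': 4}

model = Transformer(d_model=512, heads=8, num_layers=3, word_map=word_map).to(device)

# One dummy question/reply pair, already encoded and padded with id 0.
question = torch.LongTensor([[3, 4, 2, 0, 0]]).to(device)
reply = torch.LongTensor([[1, 3, 4, 2, 0]]).to(device)
reply_input, reply_target = reply[:, :-1], reply[:, 1:]   # teacher forcing: target is the input shifted by one

question_mask, reply_input_mask, reply_target_mask = create_masks(question, reply_input, reply_target)

out = model(question, question_mask, reply_input, reply_input_mask)
print(out.shape)   # torch.Size([1, 4, 5]): log-probabilities over the vocabulary

The reply_target and reply_target_mask are not consumed by the model itself; they are used by the label-smoothing loss in utils.py below.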
utils.py
ADDED
@@ -0,0 +1,91 @@
import torch
import torch.nn as nn
from torch.utils.data import Dataset
import torch.utils.data
import json

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class Dataset(Dataset):

    def __init__(self):
        self.pairs = json.load(open('pairs_encoded.json'))
        self.dataset_size = len(self.pairs)

    def __getitem__(self, i):
        question = torch.LongTensor(self.pairs[i][0])
        reply = torch.LongTensor(self.pairs[i][1])
        return question, reply

    def __len__(self):
        return self.dataset_size


def create_masks(question, reply_input, reply_target):

    def subsequent_mask(size):
        mask = torch.triu(torch.ones(size, size)).transpose(0, 1).type(dtype=torch.uint8)
        return mask.unsqueeze(0)

    question_mask = (question != 0).to(device)
    question_mask = question_mask.unsqueeze(1).unsqueeze(1)   # (batch_size, 1, 1, max_words)

    reply_input_mask = reply_input != 0
    reply_input_mask = reply_input_mask.unsqueeze(1)          # (batch_size, 1, max_words)
    reply_input_mask = reply_input_mask & subsequent_mask(reply_input.size(-1)).type_as(reply_input_mask.data)
    reply_input_mask = reply_input_mask.unsqueeze(1)          # (batch_size, 1, max_words, max_words)
    reply_target_mask = reply_target != 0                     # (batch_size, max_words)

    return question_mask, reply_input_mask, reply_target_mask


class AdamWarmup:

    def __init__(self, model_size, warmup_steps, optimizer):
        self.model_size = model_size
        self.warmup_steps = warmup_steps
        self.optimizer = optimizer
        self.current_step = 0
        self.lr = 0

    def get_lr(self):
        return self.model_size ** (-0.5) * min(self.current_step ** (-0.5), self.current_step * self.warmup_steps ** (-1.5))

    def step(self):
        # Increment the number of steps each time we call the step function
        self.current_step += 1
        lr = self.get_lr()
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
        # update the learning rate
        self.lr = lr
        self.optimizer.step()


class LossWithLS(nn.Module):

    def __init__(self, size, smooth):
        super(LossWithLS, self).__init__()
        self.criterion = nn.KLDivLoss(reduction='none')   # equivalent to the deprecated size_average=False, reduce=False
        self.confidence = 1.0 - smooth
        self.smooth = smooth
        self.size = size

    def forward(self, prediction, target, mask):
        """
        prediction of shape: (batch_size, max_words, vocab_size)
        target and mask of shape: (batch_size, max_words)
        """
        prediction = prediction.view(-1, prediction.size(-1))   # (batch_size * max_words, vocab_size)
        target = target.contiguous().view(-1)                   # (batch_size * max_words)
        mask = mask.float()
        mask = mask.view(-1)                                    # (batch_size * max_words)
        labels = prediction.data.clone()
        labels.fill_(self.smooth / (self.size - 1))
        labels.scatter_(1, target.data.unsqueeze(1), self.confidence)
        loss = self.criterion(prediction, labels)               # (batch_size * max_words, vocab_size)
        loss = (loss.sum(1) * mask).sum() / mask.sum()
        return loss
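To show how utils.py is meant to be wired up with models.py, here is a hedged single-training-step sketch. It assumes the preprocessed pairs_encoded.json read by Dataset exists and that its pairs are already padded to a fixed length so the default DataLoader collate can stack them; the WORDMAP.json filename, batch size, warmup steps, and smoothing value are illustrative assumptions only.

# Training-step sketch (illustrative only, not part of the uploaded files).
import json

import torch
import torch.utils.data

from models import Transformer
from utils import Dataset, create_masks, AdamWarmup, LossWithLS

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

d_model, heads, num_layers = 512, 8, 3

word_map = json.load(open('WORDMAP.json'))   # hypothetical filename for the vocabulary saved during preprocessing
train_loader = torch.utils.data.DataLoader(Dataset(), batch_size=100, shuffle=True)

model = Transformer(d_model=d_model, heads=heads, num_layers=num_layers, word_map=word_map).to(device)
adam = torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)
optimizer = AdamWarmup(model_size=d_model, warmup_steps=4000, optimizer=adam)
criterion = LossWithLS(size=len(word_map), smooth=0.1)

for question, reply in train_loader:
    question, reply = question.to(device), reply.to(device)
    reply_input, reply_target = reply[:, :-1], reply[:, 1:]   # decoder input vs. prediction target
    question_mask, reply_input_mask, reply_target_mask = create_masks(question, reply_input, reply_target)

    out = model(question, question_mask, reply_input, reply_input_mask)   # log-probabilities
    loss = criterion(out, reply_target, reply_target_mask)

    optimizer.optimizer.zero_grad()   # AdamWarmup wraps Adam; zero the grads on the inner optimizer
    loss.backward()
    optimizer.step()                  # sets the warmup learning rate, then steps Adam

    break   # a single step, for illustration

AdamWarmup implements the warmup-then-decay schedule from the Transformer paper, which is why the inner Adam is created with lr=0: the wrapper overwrites the learning rate on every step.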