|
import torch |
|
import math |
|
import torch.nn as nn |
|
|
|
class Embeddings(nn.Module):
    '''
    Embedding layer that sums three learned look-up tables -- token,
    speaker role, and conversation turn -- into one vector per position,
    then applies layer normalization.

    char: number of unique tokens in the vocabulary
    dimension_for_model: dimensionality of every embedding vector
    num_of_roles: number of distinct speaker roles (default 2)
    max_turns: maximum number of conversation turns (default 16)
    '''

    def __init__(self, char, dimension_for_model, num_of_roles=2, max_turns=16):
        super(Embeddings, self).__init__()

        # Separate tables so tokens, roles, and turns each get their own
        # learned representation of the same dimensionality.
        self.lut = nn.Embedding(char, dimension_for_model)
        self.lut_roles = nn.Embedding(num_of_roles, dimension_for_model)
        self.lut_turns = nn.Embedding(max_turns, dimension_for_model)
        self.dimension_for_model = dimension_for_model
        self.norm = nn.LayerNorm(dimension_for_model)

    def forward(self, x, roles, turns):
        '''
        Look up and sum the three embeddings for every position.

        x: LongTensor of token indices
        roles: LongTensor of role indices (same shape as x)
        turns: LongTensor of turn indices (same shape as x)

        Returns a tensor of shape (*x.shape, dimension_for_model).
        '''
        var = self.lut(x)
        # BUG FIX: roles and turns must be looked up in their own tables.
        # The original called self.lut(...) for all three inputs, which
        # ignored lut_roles/lut_turns entirely and raised an IndexError
        # whenever a role/turn index was >= the token vocabulary size.
        var = var + self.lut_roles(roles)
        var = var + self.lut_turns(turns)

        # Scale by sqrt(d_model), as in the standard Transformer recipe.
        # NOTE(review): the LayerNorm right after largely cancels this
        # uniform scaling -- kept to preserve the original behavior.
        var = var * math.sqrt(self.dimension_for_model)
        var = self.norm(var)
        return var
|
|
|
if __name__ == '__main__':
    d_model = 512

    # Character-level vocabulary: the lowercase letters plus a space.
    alphabet = list("abcdefghijklmnopqrstuvwxyz ")
    char2idx = {ch: i for i, ch in enumerate(alphabet)}
    vocab = len(alphabet)

    # Each speaker role maps to its own embedding index.
    look_up_table_roles = {'system': 0, 'user': 1}

    # Sample line: two-digit turn number, a role, then the utterance.
    input_str = "01 system: hello world"

    # Peel off the leading two-digit turn number, then split on ':'
    # to separate the role label from the conversation text.
    position = int(input_str[0:2].strip())
    remainder = input_str[2:]
    fields = remainder.split(':')
    role = fields[0].strip()
    conversation = fields[1].strip()

    # Build one index per recognized character; the role and turn
    # indices are repeated so all three sequences share one length.
    conversation_indices = []
    position_indices = []
    role_indices = []
    for ch in conversation:
        if ch in char2idx:
            conversation_indices.append(char2idx[ch])
            position_indices.append(position)
            role_indices.append(look_up_table_roles[role])

    # Batch dimension of 1 around each index sequence.
    conversations = torch.LongTensor([conversation_indices])
    roles = torch.LongTensor([role_indices])
    positions = torch.LongTensor([position_indices])

    # Run the sample through the embedding module and show the result.
    emb = Embeddings(vocab, d_model)
    embr = emb(conversations, roles, positions)
    print("embr:", embr)
|
|
|
|