"""WonderWaffle4.

GPT-2 bot with transformers and bot server.
Commit: 59c6d5c
"""
import torch
import torch.nn as nn
from typing import List, Dict
from .constants import PAD_TOKEN, BOS_TOKEN, EOS_TOKEN, UNK_TOKEN
class Vocab(nn.Module):
    """Word-level vocabulary backed by a learnable embedding table.

    Maps tokens to integer indices (and back), tracks token frequencies,
    and exposes an ``nn.Embedding`` lookup via ``forward``.
    """

    def __init__(self, messages: List[Dict], embedding_dim: int = 300):
        """Build the vocabulary from ``messages``.

        Args:
            messages: dicts each holding a ``"text"`` entry whose items are
                added to the vocabulary. NOTE(review): ``"text"`` is iterated
                item by item — if it is a plain string this is a
                character-level vocabulary; confirm against the caller.
            embedding_dim: size of each embedding vector. Defaults to 300,
                matching the previously hard-coded value.
        """
        super().__init__()
        # The four special tokens occupy the first indices.
        self.word2index: Dict[str, int] = {
            "<pad>": PAD_TOKEN,
            "<bos>": BOS_TOKEN,
            "<eos>": EOS_TOKEN,
            "<unk>": UNK_TOKEN,
        }
        self.index2word: Dict[int, str] = {
            PAD_TOKEN: "<pad>",
            BOS_TOKEN: "<bos>",
            EOS_TOKEN: "<eos>",
            UNK_TOKEN: "<unk>",
        }
        self.word_count: Dict[str, int] = {}
        self.size = 4  # the four special tokens above
        for message in messages:
            self.add_sentence(message["text"])
        # NOTE(review): the embedding table is sized once, here. Words added
        # via add_word() after construction will be out of range for forward().
        self.embedding = nn.Embedding(self.size, embedding_dim)

    def add_sentence(self, sentence) -> None:
        """Add every token of ``sentence`` to the vocabulary."""
        for word in sentence:
            self.add_word(word)

    def add_word(self, word: str) -> None:
        """Register ``word``: bump its count, assigning a new index on first sight."""
        if word in self.word2index:
            self.word_count[word] += 1
        else:
            self.word2index[word] = self.size
            self.index2word[self.size] = word
            self.word_count[word] = 1
            self.size += 1

    def sentence_indices(self, sentence: List[str]) -> torch.LongTensor:
        """Convert a token sequence to a 1-D ``LongTensor`` of indices.

        Out-of-vocabulary tokens map to ``UNK_TOKEN``. An empty ``sentence``
        yields an empty tensor. Built directly from a list instead of filling
        an uninitialized pre-allocated tensor.
        """
        return torch.LongTensor(
            [self.word2index[word] if word in self.word2index else UNK_TOKEN
             for word in sentence]
        )

    def forward(self, indices: torch.LongTensor):
        """Look up embedding vectors for ``indices``."""
        return self.embedding(indices)

    def __len__(self) -> int:
        """Number of vocabulary entries, including the special tokens."""
        return self.size