import torch
import torch.nn as nn
from typing import List, Dict
from .constants import PAD_TOKEN, BOS_TOKEN, EOS_TOKEN, UNK_TOKEN


class Vocab(nn.Module):
    """Vocabulary built from pre-tokenized messages, with a learned embedding table."""

    def __init__(self, messages: List[Dict]):
        super().__init__()
        # Reserve the four special tokens. Their constants must be the ids 0-3,
        # since regular words are assigned ids starting at len(self.word2index).
        self.word2index: Dict[str, int] = {"<pad>": PAD_TOKEN, "<bos>": BOS_TOKEN, "<eos>": EOS_TOKEN, "<unk>": UNK_TOKEN}
        self.index2word: Dict[int, str] = {PAD_TOKEN: "<pad>", BOS_TOKEN: "<bos>", EOS_TOKEN: "<eos>", UNK_TOKEN: "<unk>"}
        self.word_count: Dict[str, int] = {}
        self.size = len(self.word2index)

        for message in messages:
            self.add_sentence(message["text"])

        # One 300-dimensional embedding per vocabulary entry; the table is sized
        # once, after all messages have been added, and does not grow afterwards.
        self.embedding = nn.Embedding(self.size, 300)

    def add_sentence(self, sentence: List[str]):
        # Expects a pre-tokenized sentence; iterating a raw string here would
        # add individual characters to the vocabulary instead of words.
        for word in sentence:
            self.add_word(word)

    def add_word(self, word: str):
        if word not in self.word2index:
            # First occurrence: assign the next free id and start its count.
            self.word2index[word] = self.size
            self.index2word[self.size] = word
            self.word_count[word] = 1
            self.size += 1
        else:
            self.word_count[word] += 1

    def sentence_indices(self, sentence: List[str]) -> torch.LongTensor:
        # Map each token to its id, falling back to <unk> for out-of-vocabulary words.
        return torch.tensor(
            [self.word2index.get(word, UNK_TOKEN) for word in sentence],
            dtype=torch.long,
        )

    def forward(self, indices: torch.LongTensor) -> torch.Tensor:
        # Look up the embedding vectors for a tensor of token ids.
        return self.embedding(indices)

    def __len__(self):
        return self.size