teepoat committed
Commit 74d4655 · unverified · 1 Parent(s): 95368ba

Initial commit

main.py ADDED
@@ -0,0 +1,60 @@
+ from typing import Final, Optional
+ from telegram import Update
+ from telegram.ext import Application, MessageHandler, filters, ContextTypes
+ import random
+ import os
+ from dotenv import load_dotenv
+ from models.seq2seq.model import Seq2SeqChatbot
+ import torch
+
+
+ load_dotenv()
+
+ TOKEN: Final = os.environ.get("TOKEN")
+ BOT_USERNAME: Final = os.environ.get("BOT_USERNAME")
+ CHAT_ID: Final = int(os.environ.get("CHAT_ID"))
+
+ CHECKPOINT_PATH: Final = "models/seq2seq/checkpoint/150_checkpoint.tar"
+
+ torch.manual_seed(0)
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # Infer the vocabulary size from the checkpoint's embedding weights so the
+ # model dimensions match the ones used during training.
+ checkpoint = torch.load(CHECKPOINT_PATH, map_location=device, weights_only=False)
+ vocab_size = checkpoint["embedding"]["weight"].shape[0]
+ chatbot = Seq2SeqChatbot(500, vocab_size, 2, 2, 0.1, device)
+ chatbot.load_checkpoint(CHECKPOINT_PATH)
+ chatbot.eval_mode()
+
+ def handle_response(text: str) -> Optional[str]:
+     response_chance = 1.0  # Always respond; lower this to reply only some of the time
+     if random.random() < response_chance:
+         return chatbot(text)
+     return None
+
+
+ async def handle_message(update: Update, context: ContextTypes.DEFAULT_TYPE):
+     if update.message.chat_id == CHAT_ID:
+         text: str = update.message.text.replace(BOT_USERNAME, '').strip().lower()
+         response: Optional[str] = handle_response(text)
+         if response:
+             await context.bot.send_message(update.message.chat_id, response)
+
+
+ async def error(update: Update, context: ContextTypes.DEFAULT_TYPE):
+     print(f"{update.message.from_user.username} in {update.message.chat.type} "
+           f"chat caused error \"{context.error}\"\n"
+           f"{update}")
+
+ def main() -> None:
+     """Run the bot."""
+     application = Application.builder().token(TOKEN).build()
+
+     application.add_handler(MessageHandler(filters.TEXT, handle_message))
+     application.add_error_handler(error)
+
+     application.run_polling(allowed_updates=Update.ALL_TYPES)
+
+
+ if __name__ == '__main__':
+     print("Running main...")
+     # print(chatbot("test"))
+     main()
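
Note: main.py reads its configuration from a `.env` file via python-dotenv. A minimal sketch with placeholder values (the variable names come from the code above; the values are illustrative only):

    TOKEN=<telegram-bot-token>
    BOT_USERNAME=@my_bot_username
    CHAT_ID=-1001234567890
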
models/seq2seq/__init__.py ADDED
File without changes
models/seq2seq/attention.py ADDED
@@ -0,0 +1,36 @@
+ import torch
+ import torch.nn.functional as F
+ import torch.nn as nn
+ from .custom_types import Method
+
+ class LuongAttention(nn.Module):
+     def __init__(self, method: Method, hidden_size: int):
+         super().__init__()
+         self.hidden_size = hidden_size
+         if not isinstance(method, Method):
+             raise ValueError(f"{method} should be a member of the `Method` enum")
+         match method:
+             case Method.DOT:
+                 self.method = self.dot
+             case Method.GENERAL:
+                 self.method = self.general
+                 self.Wa = nn.Linear(hidden_size, hidden_size)
+             case Method.CONCAT:
+                 self.method = self.concat
+                 self.Wa = nn.Linear(hidden_size * 2, hidden_size)
+                 self.Va = nn.Parameter(torch.FloatTensor(1, hidden_size))
+
+     def dot(self, hidden, encoder_outputs):
+         return torch.sum(hidden * encoder_outputs, dim=2)
+
+     def general(self, hidden, encoder_outputs):
+         return torch.sum(hidden * self.Wa(encoder_outputs), dim=2)
+
+     def concat(self, hidden, encoder_outputs):
+         # Expand the decoder state along the source-length dimension before concatenating
+         energy = self.Wa(torch.cat((hidden.expand(-1, encoder_outputs.size(1), -1), encoder_outputs), 2)).tanh()
+         return torch.sum(self.Va * energy, dim=2)
+
+     def forward(self, hidden, encoder_outputs):
+         attn_weights = self.method(hidden, encoder_outputs)
+         return F.softmax(attn_weights, dim=1).unsqueeze(1)
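
A minimal shape check for the dot-product variant (a sketch with hypothetical sizes; `hidden` follows the batch-first, single-step decoder output used in model.py):

    import torch
    from models.seq2seq.attention import LuongAttention
    from models.seq2seq.custom_types import Method

    attn = LuongAttention(Method.DOT, hidden_size=500)
    hidden = torch.randn(5, 1, 500)             # (batch, 1, hidden): one decoder step
    encoder_outputs = torch.randn(5, 12, 500)   # (batch, src_len, hidden)
    print(attn(hidden, encoder_outputs).shape)  # torch.Size([5, 1, 12])
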
models/seq2seq/chat_dataset.py ADDED
@@ -0,0 +1,97 @@
+ import torch
+ import torch.utils.data as data
+ from typing import List, Optional, Union, Tuple
+ from collections import OrderedDict
+ from .vocab import Vocab
+ from .custom_types import Message, MessageId, Conversation, Token
+ from torch.nn.utils.rnn import pad_sequence
+ import re
+ import json
+
+ class ChatDataset(data.Dataset):
+     def __init__(self, path: str, max_message_count: Optional[int] = None, batch_size: int = 5):
+         super().__init__()
+         self.path = path
+         self.batch_size = batch_size
+         self.messages: OrderedDict[MessageId, Message] = self.__load_messages_from_json(path, max_message_count)
+         self.conversations: List[Conversation] = ChatDataset.__conversations_from_messages(self.messages)
+         self.vocab = Vocab(list(self.messages.values()))  # TODO: try changing this cast to something more applicable
+
+         self.batches_X, self.batches_y, self.lengths, self.mask = self.__batches_from_conversations()
+
+         self.length = len(self.batches_X)
+
+     # Shape of each tensor in a batch: (batch_size, max_len_in_batch)
+     def __batches_from_conversations(self) -> Tuple[List[torch.LongTensor], List[torch.LongTensor], List[torch.LongTensor], List[torch.BoolTensor]]:
+         conversations = sorted(self.conversations, key=lambda x: len(x[0]))  # Sort by input sequence length
+         batches_X: List[torch.LongTensor] = list()
+         batches_y: List[torch.LongTensor] = list()
+         lengths: List[torch.LongTensor] = list()
+         mask: List[torch.BoolTensor] = list()
+         for i in range(0, len(conversations), self.batch_size):
+             batches_X.append(pad_sequence([self.vocab.sentence_indices(conversations[i+j][0] + ["<eos>"]) for j in range(self.batch_size) if i+j < len(conversations)], batch_first=True, padding_value=0))
+             batches_y.append(pad_sequence([self.vocab.sentence_indices(conversations[i+j][1] + ["<eos>"]) for j in range(self.batch_size) if i+j < len(conversations)], batch_first=True, padding_value=0))
+             lengths.append(torch.tensor([len(conversations[i+j][0]) for j in range(self.batch_size) if i+j < len(conversations)]))
+             mask.append(batches_y[-1] != Token.PAD_TOKEN.value)
+         return batches_X, batches_y, lengths, mask
+
+     @classmethod
+     def __load_messages_from_json(cls, path: str, max_message_count: Optional[int] = None) -> OrderedDict[MessageId, Message]:
+         messages: OrderedDict[MessageId, Message] = OrderedDict()
+         with open(path, "r", encoding="utf-8") as file:
+             chat_json = json.load(file)
+             for i, message in enumerate(chat_json["messages"]):
+                 if max_message_count and i == max_message_count:
+                     break
+                 if message["type"] != "message":
+                     continue
+                 new_message = {
+                     "id": message["id"],
+                     "text": cls.__normalize(message["text"])
+                 }
+                 if not new_message["text"]:  # Skip empty messages
+                     continue
+                 if "reply_to_message_id" in message:
+                     new_message["reply_to_id"] = message["reply_to_message_id"]
+
+                 messages[new_message["id"]] = new_message
+         return messages
+
+     @classmethod
+     def __conversations_from_messages(cls, messages: OrderedDict[MessageId, Message]) -> List[Conversation]:
+         conversations: List[Conversation] = []
+
+         messages_values = list(messages.values())  # TODO: try changing this cast to something more applicable
+         for i in range(len(messages) - 1):  # The last message has no reply, hence the -1
+             prev_message = messages_values[i]
+             if "reply_to_id" in messages_values[i]:  # Message is a reply to the message with id `reply_to_id`
+                 try:
+                     prev_message = messages[messages_values[i]["reply_to_id"]]
+                 except KeyError:
+                     continue
+             conversations.append((prev_message["text"], messages_values[i+1]["text"]))
+         return conversations
+
+     @classmethod
+     def __normalize(cls, text: Union[str, List]) -> List[str]:
+         if isinstance(text, list):
+             text = " ".join([word for word in text if isinstance(word, str)])
+         text = text.lower().strip()
+         text = re.sub(r"([.!?])", r" \1 ", text)
+         text = re.sub(r"ё", r"е", text)
+         text = re.sub(r"[^а-яА-Я.!?]+", r" ", text)
+         text = re.sub(r"\s+", r" ", text).strip()
+         return text.split()
+
+     def __getitem__(self, item):
+         return self.batches_X[item], self.batches_y[item], self.lengths[item], self.mask[item]
+
+     def __len__(self):
+         return self.length
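
To illustrate the preprocessing, here is a sketch of what `__normalize` produces (it is private, so it is reached through Python's name mangling, the same trick model.py uses): punctuation is spaced out, "ё" becomes "е", and anything outside Cyrillic letters and .!? is dropped.

    from models.seq2seq.chat_dataset import ChatDataset

    print(ChatDataset._ChatDataset__normalize("Привет!! Как дела? ёлки)))"))
    # ['привет', '!', '!', 'как', 'дела', '?', 'елки']
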
models/seq2seq/custom_types.py ADDED
@@ -0,0 +1,23 @@
+ from typing import List, TypedDict, NotRequired, Tuple  # NotRequired needs Python 3.11+ (else typing_extensions)
+ from enum import Enum, auto
+
+ MessageId = int
+ MessageText = List[str]
+ Conversation = Tuple[MessageText, MessageText]  # (prompt, reply) pair
+
+
+ class Message(TypedDict):
+     id: MessageId
+     text: MessageText
+     reply_to_id: NotRequired[int]
+
+ class Method(Enum):
+     DOT = auto()
+     GENERAL = auto()
+     CONCAT = auto()
+
+ class Token(Enum):
+     PAD_TOKEN = 0
+     BOS_TOKEN = 1
+     EOS_TOKEN = 2
+     UNK_TOKEN = 3
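
For reference, a message as ChatDataset builds it conforms to the Message TypedDict; a sketch with hypothetical values:

    from models.seq2seq.custom_types import Message

    msg: Message = {
        "id": 42,
        "text": ["привет", "!"],
        "reply_to_id": 41,  # NotRequired: present only when the message is a reply
    }
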
models/seq2seq/model.py ADDED
@@ -0,0 +1,208 @@
+ import torch
+ from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
+ import torch.nn.functional as F
+ import torch.nn as nn
+ import torch.optim as optim
+ from .chat_dataset import ChatDataset
+ from .attention import LuongAttention
+ from .custom_types import Method, Token
+ from .vocab import Vocab
+ from .searchers import GreedySearch
+ import os
+ import random
+ from tqdm import tqdm
+
+
+ class Seq2SeqEncoder(nn.Module):
+     def __init__(self, input_size: int, hidden_size: int, num_layers: int, embedding: nn.Embedding):
+         super().__init__()
+         self.input_size = input_size
+         self.hidden_size = hidden_size
+         self.num_layers = num_layers
+
+         self.embedding = embedding
+         self.rnn = nn.GRU(input_size, hidden_size, num_layers=num_layers, bidirectional=True, batch_first=True)  # batch_first is True, because I don't approve self-harm
+
+     def forward(self, x, lengths):
+         x = self.embedding(x)  # Output shape: (batch_size, max_len_in_batch, hidden_size)
+         packed_embedded = pack_padded_sequence(x, lengths.cpu(), batch_first=True, enforce_sorted=False)
+         outputs, hidden = self.rnn(packed_embedded)
+         outputs, _ = pad_packed_sequence(outputs, batch_first=True)
+         # Sum the forward and backward directions of the bidirectional GRU
+         return outputs[:, :, :self.hidden_size] + outputs[:, :, self.hidden_size:], hidden
+
+
+ class Seq2SeqDecoder(nn.Module):
+     def __init__(self, input_size: int, hidden_size: int, output_size: int, num_layers: int, attn, embedding: nn.Embedding, dropout: float = 0.1):
+         super().__init__()
+         self.input_size = input_size
+         self.hidden_size = hidden_size
+         self.output_size = output_size
+         self.num_layers = num_layers
+
+         self.attn = attn
+         self.embedding = embedding
+         self.embedding_dropout = nn.Dropout(dropout)
+         self.rnn = nn.GRU(input_size, hidden_size, num_layers=num_layers, batch_first=True)
+
+         self.concat = nn.Linear(hidden_size * 2, hidden_size)
+         self.out = nn.Linear(hidden_size, output_size)
+
+     def forward(self, x, last_hidden, encoder_outputs):
+         embedded = self.embedding(x)
+         embedded = self.embedding_dropout(embedded)
+         decoder_outputs, hidden = self.rnn(embedded, last_hidden)
+         attn_weights = self.attn(decoder_outputs, encoder_outputs)
+
+         # Attention-weighted sum of encoder outputs (Luong-style context vector)
+         context = attn_weights.bmm(encoder_outputs).squeeze(1)
+
+         concat_input = torch.cat((decoder_outputs.squeeze(1), context), 1)
+         concat_output = torch.tanh(self.concat(concat_input))
+         output = self.out(concat_output)
+
+         output = F.softmax(output, dim=1)
+         return output, hidden
+
+
+ class Seq2SeqChatbot(nn.Module):
+     def __init__(self, hidden_size: int, vocab_size: int, encoder_num_layers: int, decoder_num_layers: int, decoder_embedding_dropout: float, device: torch.device):
+         super().__init__()
+         self.hidden_size = hidden_size
+         self.encoder_num_layers = encoder_num_layers
+         self.decoder_num_layers = decoder_num_layers
+         self.decoder_embedding_dropout = decoder_embedding_dropout
+         self.vocab_size = vocab_size
+         self.epoch = 0
+
+         self.device = device
+         self.vocab = Vocab([])
+         self.embedding = nn.Embedding(vocab_size, hidden_size)
+         self.attn = LuongAttention(Method.DOT, hidden_size)
+         self.encoder = Seq2SeqEncoder(hidden_size, hidden_size, encoder_num_layers, self.embedding)
+         self.decoder = Seq2SeqDecoder(hidden_size, hidden_size, vocab_size, decoder_num_layers, self.attn, self.embedding, decoder_embedding_dropout)
+         self.encoder_optimizer = optim.Adam(self.encoder.parameters())
+         self.decoder_optimizer = optim.Adam(self.decoder.parameters())
+         self.searcher = GreedySearch(self.encoder, self.decoder, self.embedding, device)
+         self.to(device)
+         self.eval_mode()
+
+     # NOTE: deliberately shadows nn.Module.train; use train_mode()/eval_mode() to toggle modes
+     def train(self, epochs, train_data, teacher_forcing_ratio, device, save_dir, model_name, clip, save_every):
+         def maskNLLLoss(inp, target, mask):
+             crossEntropy = -torch.log(torch.gather(inp, 1, target.view(-1, 1)).squeeze(1))
+             loss = crossEntropy.masked_select(mask).mean()
+             loss = loss.to(device)
+             return loss
+
+         epoch_progress = tqdm(range(self.epoch, self.epoch + epochs), desc="Training", unit="epoch", leave=True)
+         epoch_progress.set_description("maskNLLLoss: None")
+
+         for epoch in epoch_progress:
+             for x_train, y_train, x_lengths, y_mask in train_data:
+                 self.encoder_optimizer.zero_grad()
+                 self.decoder_optimizer.zero_grad()
+                 # Squeeze because batches are made in the dataset and DataLoader is only used for shuffling
+                 x_train = x_train.squeeze(0).to(device)
+                 y_train = y_train.squeeze(0).to(device)
+                 x_lengths = x_lengths.squeeze(0)  # Lengths stay on the CPU
+                 y_mask = y_mask.squeeze(0).to(device)
+
+                 encoder_outputs, hidden = self.encoder(x_train, x_lengths)  # Output shape: (batch_size, max_len_in_batch, hidden_size)
+                 hidden = hidden[:self.decoder_num_layers]
+                 loss = 0
+                 decoder_input = torch.LongTensor([[Token.BOS_TOKEN.value] for _ in range(y_train.shape[0])])
+                 decoder_input = decoder_input.to(device)
+                 use_teacher_forcing = random.random() < teacher_forcing_ratio
+                 if use_teacher_forcing:
+                     for t in range(y_train.shape[1]):  # Process words in all batches for timestep t
+                         decoder_outputs, hidden = self.decoder(decoder_input, hidden, encoder_outputs)
+                         decoder_input = y_train[:, t].unsqueeze(1)  # Feed the ground-truth token to the next step
+                         mask_loss = maskNLLLoss(decoder_outputs, y_train[:, t], y_mask[:, t])
+                         loss += mask_loss
+                 else:
+                     for t in range(y_train.shape[1]):
+                         decoder_outputs, hidden = self.decoder(decoder_input, hidden, encoder_outputs)
+                         decoder_input = torch.argmax(decoder_outputs, dim=1).unsqueeze(1)  # Feed the model's own prediction
+                         mask_loss = maskNLLLoss(decoder_outputs, y_train[:, t], y_mask[:, t])
+                         loss += mask_loss
+
+                 loss.backward()
+
+                 _ = nn.utils.clip_grad_norm_(self.encoder.parameters(), clip)
+                 _ = nn.utils.clip_grad_norm_(self.decoder.parameters(), clip)
+
+                 self.encoder_optimizer.step()
+                 self.decoder_optimizer.step()
+
+             if (epoch % save_every == 0 and epoch != 0) or epoch == save_every - 1:
+                 directory = os.path.join(save_dir, model_name, '{}-{}_{}'.format(self.encoder_num_layers, self.decoder_num_layers, self.hidden_size))
+                 if not os.path.exists(directory):
+                     os.makedirs(directory)
+                 torch.save({
+                     'epoch': epoch,  # `epoch` already includes the offset loaded from the checkpoint
+                     'en': self.encoder.state_dict(),
+                     'de': self.decoder.state_dict(),
+                     'en_opt': self.encoder_optimizer.state_dict(),
+                     'de_opt': self.decoder_optimizer.state_dict(),
+                     'loss': loss,
+                     'voc_dict': self.vocab.__dict__,
+                     'embedding': self.embedding.state_dict()
+                 }, os.path.join(directory, '{}_{}.tar'.format(epoch, 'checkpoint')))
+
+             epoch_progress.set_description(f"maskNLLLoss: {loss:.8f}")
+
+     def to(self, device):
+         self.encoder = self.encoder.to(device)
+         self.decoder = self.decoder.to(device)
+         self.embedding = self.embedding.to(device)
+         self.attn = self.attn.to(device)
+
+     def train_mode(self):
+         self.encoder.train()
+         self.decoder.train()
+         self.embedding.train()
+         self.attn.train()
+
+     def eval_mode(self):
+         self.encoder.eval()
+         self.decoder.eval()
+         self.embedding.eval()
+         self.attn.eval()
+
+     def load_checkpoint(self, checkpoint_path: str):
+         checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False)
+         encoder_sd = checkpoint["en"]
+         decoder_sd = checkpoint["de"]
+         embedding_sd = checkpoint["embedding"]
+         self.vocab.__dict__ = checkpoint["voc_dict"]
+         encoder_optimizer_sd = checkpoint["en_opt"]
+         decoder_optimizer_sd = checkpoint["de_opt"]
+         self.epoch = checkpoint["epoch"]
+
+         self.encoder_optimizer.load_state_dict(encoder_optimizer_sd)
+         self.decoder_optimizer.load_state_dict(decoder_optimizer_sd)
+         self.embedding.load_state_dict(embedding_sd)
+         self.encoder.load_state_dict(encoder_sd)
+         self.decoder.load_state_dict(decoder_sd)
+
+     def forward(self, input_seq: str):
+         # __normalize is name-mangled, hence the _ChatDataset prefix
+         input_seq = ChatDataset._ChatDataset__normalize(input_seq)
+         input_seq = self.vocab.sentence_indices(input_seq + ["<eos>"]).unsqueeze(0).to(self.device)
+         output, _ = self.searcher(input_seq, torch.tensor(input_seq.shape[1]).view(1), 10)
+         output = [self.vocab.index2word[i.item()] for i in output]
+         output = [word for word in output if word not in ("<bos>", "<eos>", "<pad>")]
+         return " ".join(output)
+
+
+ if __name__ == "__main__":  # Run as a module: python -m models.seq2seq.model
+     from .chat_dataset import ChatDataset
+     import torch.utils.data as data
+
+     CHAT_HISTORY_PATH = "models/seq2seq/data/train/chat_history.json"
+     batch_size = 20
+     chat_dataset = ChatDataset(CHAT_HISTORY_PATH, max_message_count=10_000, batch_size=batch_size)
+     train_data = data.DataLoader(chat_dataset, batch_size=1, shuffle=True)
+
+     device = torch.device("cpu")
+     chatbot = Seq2SeqChatbot(500, chat_dataset.vocab.size, 2, 2, 0.1, device)
+     chatbot.load_checkpoint("models/seq2seq/checkpoint/150_checkpoint.tar")
+     chatbot.train_mode()
+     chatbot.train(3, train_data, 0.5, device, "./checkpoint/temp/", "frantics_fox", 50.0, 100)
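
A toy check of the masked loss used in `train` (hypothetical numbers): only positions where the mask is True contribute, so padded targets are ignored.

    import torch

    inp = torch.tensor([[0.7, 0.2, 0.1],
                        [0.1, 0.8, 0.1]])  # decoder output probabilities, (batch, vocab)
    target = torch.tensor([0, 1])          # reference token ids at this timestep
    mask = torch.tensor([True, False])     # the second position is padding

    crossEntropy = -torch.log(torch.gather(inp, 1, target.view(-1, 1)).squeeze(1))
    print(crossEntropy.masked_select(mask).mean())  # tensor(0.3567), i.e. -log(0.7)
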
models/seq2seq/searchers.py ADDED
@@ -0,0 +1,27 @@
+ import torch
+ import torch.nn as nn
+ from .custom_types import Token
+
+ class GreedySearch(nn.Module):
+     def __init__(self, encoder, decoder, embedding, device):
+         super().__init__()
+         self.encoder = encoder
+         self.decoder = decoder
+         self.embedding = embedding
+         self.device = device
+
+     def forward(self, x, input_length, max_length):
+         encoder_outputs, hidden = self.encoder(x, input_length)
+         decoder_hidden = hidden[:self.decoder.num_layers]
+         # Start decoding from a single <bos> token
+         decoder_input = torch.ones(1, 1, device=self.device, dtype=torch.long) * Token.BOS_TOKEN.value
+         all_tokens = torch.zeros([0], device=self.device, dtype=torch.long)
+         all_scores = torch.zeros([0], device=self.device)
+
+         # Always runs max_length steps; <bos>/<eos>/<pad> tokens are filtered downstream in model.forward
+         for _ in range(max_length):
+             decoder_outputs, decoder_hidden = self.decoder(decoder_input, decoder_hidden, encoder_outputs)
+             # Greedily take the highest-probability token at each step
+             decoder_scores, decoder_input = torch.max(decoder_outputs, dim=1)
+             all_tokens = torch.cat((all_tokens, decoder_input), dim=0)
+             all_scores = torch.cat((all_scores, decoder_scores), dim=0)
+             decoder_input.unsqueeze_(0)  # Reshape (1,) -> (1, 1) for the next decoder step
+
+         return all_tokens, all_scores
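
Seq2SeqChatbot.forward drives the searcher roughly like this (a sketch; assumes a loaded `chatbot` as in main.py):

    indices = chatbot.vocab.sentence_indices(["привет", "<eos>"]).unsqueeze(0)  # (1, seq_len)
    lengths = torch.tensor(indices.shape[1]).view(1)  # sequence lengths stay on the CPU
    tokens, scores = chatbot.searcher(indices, lengths, 10)  # decode at most 10 tokens
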
models/seq2seq/vocab.py ADDED
@@ -0,0 +1,43 @@
+ import torch
+ import torch.nn as nn
+ from typing import List, Dict
+ from .custom_types import Token
+
+
+ class Vocab(nn.Module):
+     def __init__(self, messages: List[Dict]):
+         super().__init__()
+         self.word2index: Dict[str, int] = {"<pad>": Token.PAD_TOKEN.value, "<bos>": Token.BOS_TOKEN.value,
+                                            "<eos>": Token.EOS_TOKEN.value, "<unk>": Token.UNK_TOKEN.value}
+         self.index2word: Dict[int, str] = {index: word for word, index in self.word2index.items()}
+         self.word_count: Dict[str, int] = dict()
+         self.size = 4  # The four special tokens above
+
+         for message in messages:
+             self.add_sentence(message["text"])
+
+         self.embedding = nn.Embedding(self.size, 300)
+
+     def add_sentence(self, sentence):
+         for word in sentence:
+             self.add_word(word)
+
+     def add_word(self, word):
+         if word not in self.word2index:
+             self.word2index[word] = self.size
+             self.index2word[self.size] = word
+             self.word_count[word] = 1
+             self.size += 1
+         else:
+             self.word_count[word] += 1
+
+     def sentence_indices(self, sentence: List[str]) -> torch.LongTensor:
+         indices = torch.LongTensor(len(sentence))
+         for i, word in enumerate(sentence):
+             indices[i] = self.word2index.get(word, Token.UNK_TOKEN.value)  # Out-of-vocabulary words map to <unk>
+         return indices
+
+     def forward(self, indices: torch.LongTensor):
+         return self.embedding(indices)
+
+     def __len__(self):
+         return self.size
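
A quick round trip (a sketch; the message dict follows the Message shape from custom_types):

    from models.seq2seq.vocab import Vocab

    vocab = Vocab([{"id": 1, "text": ["привет", "мир"]}])
    print(len(vocab))  # 6: four special tokens plus two words
    print(vocab.sentence_indices(["привет", "нет", "<eos>"]))  # tensor([4, 3, 2]); "нет" maps to <unk>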