frantics-bot / models /seq2seq /chat_dataset.py
WonderWaffle4
GPT-2 bot with transformers and bot server
59c6d5c
import torch
import torch.utils.data as data
from typing import List, Union, Tuple
from collections import OrderedDict
from .vocab import Vocab
from .custom_types import Message, MessageId, Conversation
from torch.nn.utils.rnn import pad_sequence
from .constants import PAD_TOKEN
import re
import json
class ChatDataset(data.Dataset):
    """Dataset of padded (prompt, reply) index batches built from a JSON chat export.

    The JSON file must hold a top-level ``"messages"`` list. For each entry the
    loader reads ``"type"``, ``"id"``, ``"text"`` and, if present,
    ``"reply_to_message_id"``. Messages are normalized into word lists, paired
    into ``(prompt_words, reply_words)`` conversations, sorted by prompt length
    and grouped into batches of ``batch_size`` conversations.

    Indexing the dataset returns one whole pre-built batch
    ``(X, y, lengths, mask)``. A ``DataLoader`` wrapping it should therefore
    presumably use ``batch_size=None``. NOTE(review): confirm against the
    training loop.
    """

    def __init__(self, path: str, max_message_count: Union[int, None] = None, batch_size: int = 5):
        """Load the chat at ``path`` and build all batches eagerly.

        Args:
            path: Path to the UTF-8 encoded JSON chat export.
            max_message_count: If truthy, stop after scanning this many raw
                entries of ``"messages"``. Skipped non-message and empty
                entries still count. ``None`` or ``0`` means no limit.
            batch_size: Conversations per batch; the last batch may be smaller.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            # A zero step would make range() fail with a cryptic message, and a
            # negative step would silently produce an empty dataset.
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
        super().__init__()
        self.path = path
        self.batch_size = batch_size
        self.messages: OrderedDict[MessageId, Message] = self.__load_messages_from_json(path, max_message_count)
        self.conversations: List[Conversation] = ChatDataset.__conversations_from_messages(self.messages)
        self.vocab = Vocab(list(self.messages.values()))  # TODO: try changing this cast to something more applicable
        self.batches_X, self.batches_y, self.lengths, self.mask = self.__batches_from_conversations()
        self.length = len(self.batches_X)

    def __batches_from_conversations(
        self,
    ) -> Tuple[List[torch.LongTensor], List[torch.LongTensor], List[torch.LongTensor], List[torch.BoolTensor]]:
        """Turn ``self.conversations`` into padded index batches.

        Conversations are sorted by prompt length (ascending) so that each
        batch holds prompts of similar length. They are then split into chunks
        of ``self.batch_size``. An ``"<eos>"`` token is appended to every prompt
        and reply before conversion with ``self.vocab.sentence_indices``.

        Returns:
            Four parallel lists with one entry per batch. Each padded tensor
            has shape (batch_size, max_len_in_batch).

            - ``X``: prompt indices, padded with ``PAD_TOKEN``.
            - ``y``: reply indices, padded with ``PAD_TOKEN``.
            - ``lengths``: unpadded length of each prompt row, including ``<eos>``.
            - ``mask``: ``y != PAD_TOKEN``.

        NOTE(review): prompts are sorted ascending. ``pack_padded_sequence``
        with its default ``enforce_sorted=True`` requires descending lengths,
        so confirm how the model consumes ``lengths``.
        """
        conversations = sorted(self.conversations, key=lambda pair: len(pair[0]))  # Sort by input sequence length
        batches_X: List[torch.LongTensor] = []
        batches_y: List[torch.LongTensor] = []
        lengths: List[torch.LongTensor] = []
        mask: List[torch.BoolTensor] = []
        for start in range(0, len(conversations), self.batch_size):
            chunk = conversations[start:start + self.batch_size]
            inputs = [self.vocab.sentence_indices(prompt + ["<eos>"]) for prompt, _ in chunk]
            targets = [self.vocab.sentence_indices(reply + ["<eos>"]) for _, reply in chunk]
            # Pad with the same token the mask compares against. This assumes
            # PAD_TOKEN is the integer index of the padding token, which the
            # mask comparison below already relies on.
            padded_targets = pad_sequence(targets, batch_first=True, padding_value=PAD_TOKEN)
            batches_X.append(pad_sequence(inputs, batch_first=True, padding_value=PAD_TOKEN))
            batches_y.append(padded_targets)
            # Measure the actual index sequences (including the appended
            # "<eos>"), so each length matches its unpadded row in batches_X.
            lengths.append(torch.tensor([len(sequence) for sequence in inputs]))
            mask.append(padded_targets != PAD_TOKEN)
        return batches_X, batches_y, lengths, mask

    @classmethod
    def __load_messages_from_json(
        cls, path: str, max_message_count: Union[int, None] = None
    ) -> "OrderedDict[MessageId, Message]":
        """Read and normalize the text messages of a JSON chat export.

        Only entries whose ``"type"`` is ``"message"`` and whose normalized
        text is non-empty are kept. If ``max_message_count`` is truthy,
        scanning stops after that many raw entries.

        Returns:
            An ``OrderedDict`` keyed by message id, in file order. Each value is
            a dict with ``"id"``, ``"text"`` (list of normalized words) and,
            when the entry had ``"reply_to_message_id"``, ``"reply_to_id"``.
        """
        messages: OrderedDict[MessageId, Message] = OrderedDict()
        with open(path, "r", encoding="utf-8") as file:
            chat_json = json.load(file)
        for i, message in enumerate(chat_json["messages"]):
            if max_message_count and i == max_message_count:
                break
            if message["type"] != "message":
                continue
            words = cls.__normalize(message["text"])
            if not words:  # Nothing left after normalization
                continue
            new_message = {"id": message["id"], "text": words}
            if "reply_to_message_id" in message:
                new_message["reply_to_id"] = message["reply_to_message_id"]
            messages[new_message["id"]] = new_message
        return messages

    @classmethod
    def __conversations_from_messages(
        cls, messages: "OrderedDict[MessageId, Message]"
    ) -> "List[Conversation]":
        """Pair every message (except the first) with the message that prompted it.

        The prompt of a reply is the message it explicitly answers
        (``"reply_to_id"``) when present, otherwise the message right before
        it. If the answered message is not in ``messages`` (filtered out, or
        beyond ``max_message_count``), the reply is skipped.

        Returns:
            A list of ``(prompt_words, reply_words)`` tuples in message order.
        """
        conversations = []
        ordered = list(messages.values())
        for previous, reply in zip(ordered, ordered[1:]):
            prompt = previous
            if "reply_to_id" in reply:  # `reply` answers the message with id `reply_to_id`
                prompt = messages.get(reply["reply_to_id"])
                if prompt is None:
                    continue
            conversations.append((prompt["text"], reply["text"]))
        return conversations

    @classmethod
    def __normalize(cls, text: Union[str, List]) -> List[str]:
        """Lower-case, clean and tokenize a message text.

        ``text`` may be a plain string, or a list whose string items are joined
        with spaces; non-string items are dropped. ``ё`` is folded into ``е``.
        ``.``, ``!`` and ``?`` become separate tokens. Every other character
        outside the range ``А``-``я`` (U+0410-U+044F) acts as a separator.

        Returns:
            The list of whitespace-separated tokens, which may be empty.
        """
        if isinstance(text, list):
            text = " ".join(word for word in text if isinstance(word, str))
        text = text.lower().strip()
        text = re.sub(r"([.!?])", r" \1 ", text)
        text = text.replace("ё", "е")
        # "А-я" already spans upper- and lower-case basic Cyrillic, so "а-я" is redundant.
        text = re.sub(r"[^а-яА-я.!?]+", r" ", text)
        return text.split()  # split() also collapses and trims whitespace

    def __getitem__(self, item):
        """Return the pre-built batch ``(X, y, lengths, mask)`` at index ``item``."""
        return self.batches_X[item], self.batches_y[item], self.lengths[item], self.mask[item]

    def __len__(self):
        """Return the number of batches."""
        return self.length