import torch import torch.utils.data as data from typing import List, Union, Tuple from collections import OrderedDict from .vocab import Vocab from .custom_types import Message, MessageId, Conversation from torch.nn.utils.rnn import pad_sequence from .constants import PAD_TOKEN import re import json class ChatDataset(data.Dataset): def __init__(self, path: str, max_message_count: int = None, batch_size=5): super().__init__() self.path = path self.batch_size = batch_size self.messages: OrderedDict[MessageId, Message] = self.__load_messages_from_json(path, max_message_count) self.conversations: List[Conversation] = ChatDataset.__conversations_from_messages(self.messages) self.vocab = Vocab(list(self.messages.values())) # TODO: try changing this cast to something more applicable self.batches_X, self.batches_y, self.lengths, self.mask = self.__batches_from_conversations() self.length = len(self.batches_X) def __batches_from_conversations(self) -> Tuple[List[torch.LongTensor], List[torch.LongTensor], List[torch.LongTensor], List[torch.BoolTensor]]: # Shape of tensor in batch: (batch_size, max_len_in_batch) conversations = sorted(self.conversations, key=lambda x: len(x[0])) # Sort by input sequence length batches_X: List[torch.LongTensor] = list() batches_y: List[torch.LongTensor] = list() lengths: List[torch.LongTensor] = list() mask: List[torch.BoolTensor] = list() for i in range(0, len(conversations), self.batch_size): batches_X.append(pad_sequence([self.vocab.sentence_indices(conversations[i+j][0] + [""]) for j in range(self.batch_size) if i+j < len(conversations)], batch_first=True, padding_value=0)) batches_y.append(pad_sequence([self.vocab.sentence_indices(conversations[i+j][1] + [""]) for j in range(self.batch_size) if i+j < len(conversations)], batch_first=True, padding_value=0)) lengths.append(torch.tensor([len(conversations[i+j][0]) for j in range(self.batch_size) if i+j < len(conversations)])) mask.append(batches_y[-1] != PAD_TOKEN) return batches_X, batches_y, lengths, mask @classmethod def __load_messages_from_json(cls, path: str, max_message_count: int = None) -> OrderedDict[MessageId, Message]: messages: OrderedDict[MessageId, Message] = OrderedDict() with open(path, "r", encoding="utf-8") as file: chat_json = json.load(file) for i, message in enumerate(chat_json["messages"]): if max_message_count and i == max_message_count: break if message["type"] != "message": continue new_message = { "id": message["id"], "text": cls.__normalize(message["text"]) } if not new_message["text"]: # Check for empty message continue if "reply_to_message_id" in message.keys(): new_message["reply_to_id"] = message["reply_to_message_id"] messages[new_message["id"]] = new_message return messages @classmethod def __conversations_from_messages(cls, messages: OrderedDict[MessageId, Message]) -> List[Conversation]: # Search for message with `id` in the last `current_id` messages def _get_message_by_id(current_id: int, id: int) -> Message: for i in range(current_id - 1, -1, -1): if messages[i]["id"] == id: return messages[i] return None conversations: List[Conversation] = [] messages_values = list(messages.values()) # TODO: try changing this cast to something more applicable for i in range(len(messages) - 1): # There's no answer for last message so add -1 prev_message = messages_values[i] if "reply_to_id" in messages_values[i].keys(): # Message is answer for message with `id` of `reply_to_id` try: prev_message = messages[messages_values[i]["reply_to_id"]] except KeyError: continue conversations.append((prev_message["text"], messages_values[i+1]["text"])) return conversations @classmethod def __normalize(cls, text: Union[str, List]) -> List[str]: if isinstance(text, List): text = " ".join([word for word in text if isinstance(word, str)]) text = text.lower().strip() text = re.sub(r"([.!?])", r" \1 ", text) text = re.sub(r"ё", r"е", text) text = re.sub(r"[^а-яА-я.!?]+", r" ", text) text = re.sub(r"\s+", r" ", text).strip() return text.split() def __getitem__(self, item): return self.batches_X[item], self.batches_y[item], self.lengths[item], self.mask[item] def __len__(self): return self.length