File size: 5,062 Bytes
74d4655
 
 
 
 
 
 
59c6d5c
74d4655
 
 
59c6d5c
74d4655
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59c6d5c
74d4655
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import torch
import torch.utils.data as data
from typing import List, Union, Tuple
from collections import OrderedDict
from .vocab import Vocab
from .custom_types import Message, MessageId, Conversation
from torch.nn.utils.rnn import pad_sequence
from .constants import PAD_TOKEN
import re
import json


class ChatDataset(data.Dataset):
    """Torch dataset of pre-batched (prompt, reply) pairs built from a chat JSON export.

    The JSON file (Telegram-style export with a top-level ``"messages"`` list)
    is parsed into normalized token lists, paired into conversations, and
    pre-padded into fixed batches at construction time. Each item of this
    dataset is one whole batch, not one sample.
    """

    def __init__(self, path: str, max_message_count: int = None, batch_size=5):
        """Load messages from ``path`` and build padded training batches.

        :param path: path to the chat JSON export file.
        :param max_message_count: optional cap on the number of raw messages
            read from the file (``None`` means no limit).
        :param batch_size: number of conversation pairs per batch.
        """
        super().__init__()
        self.path = path
        self.batch_size = batch_size
        self.messages: OrderedDict[MessageId, Message] = self.__load_messages_from_json(path, max_message_count)
        self.conversations: List[Conversation] = ChatDataset.__conversations_from_messages(self.messages)
        self.vocab = Vocab(list(self.messages.values()))  # TODO: try changing this cast to something more applicable

        self.batches_X, self.batches_y, self.lengths, self.mask = self.__batches_from_conversations()

        self.length = len(self.batches_X)

    def __batches_from_conversations(self) -> Tuple[List[torch.LongTensor], List[torch.LongTensor], List[torch.LongTensor], List[torch.BoolTensor]]:
        """Group conversations into padded batches.

        Conversations are sorted by input length first so each batch pads to a
        similar maximum length, wasting fewer PAD positions.

        :returns: ``(batches_X, batches_y, lengths, mask)``; each tensor in a
            batch has shape ``(batch_size, max_len_in_batch)``.
        """
        # Sort by input sequence length so a batch holds similar-length inputs.
        conversations = sorted(self.conversations, key=lambda x: len(x[0]))
        batches_X: List[torch.LongTensor] = list()
        batches_y: List[torch.LongTensor] = list()
        lengths: List[torch.LongTensor] = list()
        mask: List[torch.BoolTensor] = list()
        for i in range(0, len(conversations), self.batch_size):
            batch = conversations[i:i + self.batch_size]  # final batch may be short
            batches_X.append(pad_sequence([self.vocab.sentence_indices(pair[0] + ["<eos>"]) for pair in batch], batch_first=True, padding_value=0))
            batches_y.append(pad_sequence([self.vocab.sentence_indices(pair[1] + ["<eos>"]) for pair in batch], batch_first=True, padding_value=0))
            # NOTE(review): lengths count the raw input tokens only, without the
            # appended "<eos>" — confirm downstream consumers expect that.
            lengths.append(torch.tensor([len(pair[0]) for pair in batch]))
            # NOTE(review): padding is written with value 0 above but the mask
            # compares against PAD_TOKEN — verify PAD_TOKEN is the index 0.
            mask.append(batches_y[-1] != PAD_TOKEN)
        return batches_X, batches_y, lengths, mask

    @classmethod
    def __load_messages_from_json(cls, path: str, max_message_count: int = None) -> OrderedDict[MessageId, Message]:
        """Read the chat export at ``path`` into an id -> message mapping.

        Service entries (``type != "message"``) and messages whose normalized
        text is empty are skipped; reply links are preserved under
        ``"reply_to_id"`` when present.
        """
        messages: OrderedDict[MessageId, Message] = OrderedDict()
        with open(path, "r", encoding="utf-8") as file:
            chat_json = json.load(file)
            for i, message in enumerate(chat_json["messages"]):
                # Explicit None check: a limit of 0 must mean "read nothing",
                # not "no limit" (the old truthiness test ignored 0).
                if max_message_count is not None and i >= max_message_count:
                    break
                if message["type"] != "message":
                    continue
                new_message = {
                    "id": message["id"],
                    "text": cls.__normalize(message["text"])
                }
                if not new_message["text"]:  # Skip messages that normalize to nothing
                    continue
                if "reply_to_message_id" in message:
                    new_message["reply_to_id"] = message["reply_to_message_id"]

                messages[new_message["id"]] = new_message
        return messages

    @classmethod
    def __conversations_from_messages(cls, messages: OrderedDict[MessageId, Message]) -> List[Conversation]:
        """Pair each message's text with the next message's text.

        When a message carries ``reply_to_id``, the replied-to message is used
        as the prompt instead; pairs whose reply target is missing (e.g. it was
        filtered out during loading) are dropped.
        """
        # (A dead nested helper that mis-indexed the OrderedDict by position
        # was removed here.)
        conversations: List[Conversation] = []

        messages_values = list(messages.values())
        # The last message has no answer, hence len(...) - 1.
        for i in range(len(messages_values) - 1):
            prev_message = messages_values[i]
            # NOTE(review): the reply link of message i (not i+1) redirects the
            # prompt, yet the reply side stays messages_values[i+1] — confirm
            # this pairing is intended.
            if "reply_to_id" in messages_values[i]:  # Message answers the message with id `reply_to_id`
                try:
                    prev_message = messages[messages_values[i]["reply_to_id"]]
                except KeyError:
                    continue
            conversations.append((prev_message["text"], messages_values[i + 1]["text"]))
        return conversations

    @classmethod
    def __normalize(cls, text: Union[str, List]) -> List[str]:
        """Lower-case, strip, and tokenize ``text`` into a list of words.

        Exports may store text as a list of fragments (strings mixed with
        entity dicts); non-string fragments are discarded. Only Cyrillic
        letters and ``.!?`` survive; punctuation becomes its own token and
        ``ё`` is folded into ``е``.
        """
        if isinstance(text, list):  # `typing.List` in isinstance checks is deprecated; use the builtin
            text = " ".join(word for word in text if isinstance(word, str))
        text = text.lower().strip()
        text = re.sub(r"([.!?])", r" \1 ", text)  # detach sentence punctuation
        text = re.sub(r"ё", r"е", text)
        # NOTE(review): the class `А-я` spans all basic Cyrillic letters; `А-Я`
        # was probably intended, but it is harmless after lower-casing above.
        text = re.sub(r"[^а-яА-я.!?]+", r" ", text)
        text = re.sub(r"\s+", r" ", text).strip()
        return text.split()

    def __getitem__(self, item):
        """Return the ``item``-th batch as ``(X, y, lengths, mask)``."""
        return self.batches_X[item], self.batches_y[item], self.lengths[item], self.mask[item]

    def __len__(self):
        """Number of batches (not samples) in the dataset."""
        return self.length