# Copyright (c) Meta Platforms, Inc. and affiliates.
from typing import Any

import numpy as np
from pydantic import BaseModel, ConfigDict

from bytelatent.data.data_types import Batch, BltSequence
from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
from bytelatent.data.iterators.sampling_iterator import SamplingIteratorState


class PackingArgs(BaseModel):
    model_config = ConfigDict(extra="forbid")
    batch_size: int
    seq_len: int
    pad_id: int
    max_length: int | None
    pad_to_max_length: bool
    enable_byte_ngrams: bool


class PackingIteratorState(BaseModel, IteratorState):
    model_config = ConfigDict(extra="forbid")
    sequence_iterator_state: SamplingIteratorState
    packing_args: PackingArgs

    def build(self) -> "PackingIterator":
        return PackingIterator(
            sequence_iterator=self.sequence_iterator_state.build(),
            packing_args=self.packing_args,
        )
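
# Checkpoint/resume sketch (assumes a `packing_iterator` built elsewhere):
# pydantic v2 can serialize the full state to a dict, and build() then
# reconstructs an equivalent iterator from it.
#
#   state_dict = packing_iterator.get_state().model_dump()
#   restored = PackingIteratorState(**state_dict).build()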


def _merge_patch_seq_masks(bs: int, slen: int, mask_seqs: list[list[bool]]):
    """Merge per-sequence token masks into one (bs, slen) boolean array,
    shifted by one position to align with the targets y. Returns None when
    every mask is all-True and all sequences share the same length."""
    assert len(mask_seqs) == bs
    lens = [len(m) for m in mask_seqs]
    if all(all(m) for m in mask_seqs) and all(lens[0] == l for l in lens):
        return None
    assert slen == max(lens) - 1
    mask = np.zeros((bs, slen), dtype=bool)
    for i, m in enumerate(mask_seqs):
        if m is None:
            raise NotImplementedError(
                "None masks are not handled here: a None mask means all tokens"
                " are kept, so callers must materialize it as all-True first."
            )
        mask[i][: len(mask_seqs[i]) - 1] = mask_seqs[i][1:]
    return mask
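
# Worked sketch of the merge above (hypothetical masks): two sequences of
# length 4, the second with a padded tail, merge into a (2, 3) array that
# lines up with the one-token-shifted targets y.
#
#   _merge_patch_seq_masks(
#       2, 3, [[True, True, True, True], [True, True, True, False]]
#   )
#   -> array([[ True,  True,  True],
#             [ True,  True, False]])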


def truncate_batch(
    batch: Batch,
    max_length: int,
    pad_id: int,
    pad_to_max_length: bool = False,
    *,
    enable_byte_ngrams: bool,
):
    """
    Truncate the x to a given size, making sure we remove the corresponding patch sizes in patch_lenghts
    and fixing the batch.mask.

    batch.patch_lengths has unchanged shape
    x,y, and mask may reduce in size
    """
    if batch.patch_lengths is None:
        return batch

    seq_lengths = batch.patch_lengths.sum(axis=1)
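    # Patch lengths cover one more token than x/y hold, because x and y are
    # one-token-shifted views of the same stream; hence the +1.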
    max_length_adj = max_length + 1
    if np.any(seq_lengths > max_length_adj):
        for i in range(batch.x.shape[0]):
            if seq_lengths[i] > max_length_adj:
                # Find id of patch that tips over max_length + 1
                count, j = 0, 0
                while count + batch.patch_lengths[i, j] <= max_length_adj:
                    count += batch.patch_lengths[i, j]
                    j += 1
                # Edit the batch
                assert j < batch.patch_lengths.shape[1]
                batch.x[i, max_length:] = pad_id
                batch.y[i, max_length:] = pad_id
                if batch.mask is not None:
                    batch.mask[i, max_length:] = False
                batch.patch_lengths[i, j:] = 0
                batch.patch_lengths[i, j] = max_length_adj - count

        # Truncate if necessary.
        if max_length < batch.x.shape[1]:
            batch.x = batch.x[:, :max_length]
            batch.y = batch.y[:, :max_length]
            if batch.mask is not None:
                batch.mask = batch.mask[:, :max_length]

    # Right pad to max_length if necessary
    elif pad_to_max_length:
        if batch.x.shape[1] < max_length:
            # NOTE: this has to be done on an actual patch.
            non_zero_indices = (batch.patch_lengths != 0).sum(axis=1) - 1
            non_zero_indices = np.maximum(0, non_zero_indices)
            batch.patch_lengths[range(len(batch.patch_lengths)), non_zero_indices] += (
                max_length - batch.x.shape[1]
            )
            # TODO: We could get rid of many of these complications by moving this function directly into the dataloader.
            x = np.full((batch.x.shape[0], max_length), pad_id, dtype=batch.x.dtype)
            x[:, : batch.x.shape[1]] = batch.x
            batch.x = x
        if batch.y.shape[1] < max_length:
            y = np.full((batch.y.shape[0], max_length), pad_id, dtype=batch.y.dtype)
            y[:, : batch.y.shape[1]] = batch.y
            batch.y = y
        if batch.mask is not None and batch.mask.shape[1] < max_length:
            mask = np.full(
                (batch.mask.shape[0], max_length), False, dtype=batch.mask.dtype
            )
            mask[:, : batch.mask.shape[1]] = batch.mask
            batch.mask = mask

    assert batch.x.shape[1] <= max_length
    assert batch.y.shape[1] <= max_length
    assert batch.mask is None or batch.mask.shape[1] <= max_length
    assert np.all(max_length_adj - batch.patch_lengths.sum(axis=1) == 0)
    if pad_to_max_length:
        assert batch.x.shape[1] == max_length
        assert batch.y.shape[1] == max_length
        assert batch.mask is None or batch.mask.shape[1] == max_length
    if enable_byte_ngrams:
        # Not implemented yet. When it is, this branch should build ngram_ids
        # of shape (num_ngram, batch_size, seq_len) from batch.x, e.g.
        #   ngram_ids = np.array(tokenizer.encode_token_ngrams(batch.x))
        #   assert ngram_ids.shape[2] == batch.x.shape[1]
        raise NotImplementedError()
    else:
        ngram_ids = None
    batch.ngram_ids = ngram_ids
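
# Worked sketch of the truncation above (hypothetical numbers): with
# max_length=4 (so max_length_adj=5) and one row of patch_lengths [1, 2, 3]
# (sum 6 > 5), the loop stops at patch j=2 with count=3, patch_lengths
# becomes [1, 2, 2] (sum 5), and x, y, and mask are cut to width 4.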


class PackingIterator(StatefulIterator[Batch, PackingIteratorState]):
    def __init__(
        self,
        sequence_iterator: StatefulIterator[BltSequence, Any],
        *,
        packing_args: PackingArgs,
    ):
        self.sequence_iterator = sequence_iterator
        self.packing_args = packing_args

    def get_state(self):
        return PackingIteratorState(
            sequence_iterator_state=self.sequence_iterator.get_state(),
            packing_args=self.packing_args,
        )

    def create_iter(self):
        sequence_iter = self.sequence_iterator.create_iter()
        batch_size = self.packing_args.batch_size
        pad_id = self.packing_args.pad_id
        seq_len = self.packing_args.seq_len
        pad_to_max_length = self.packing_args.pad_to_max_length
        enable_byte_ngrams = self.packing_args.enable_byte_ngrams
        max_length = self.packing_args.max_length
        while True:
            tokens: list[list[int]] = []
            masks: list[list[bool]] = []
            patch_lengths: list[list[int]] = []

            for _ in range(batch_size):
                sequence = next(sequence_iter)
                _tokens = sequence.tokens
                _mask = sequence.mask
                _patch_lengths = sequence.patch_lengths
                assert len(sequence.patch_lengths) == seq_len
                last_patch_length = 0
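                # Force the first patch to length 1: move the remainder of
                # patch 0 into its own slot and drop the final patch (and,
                # below, its tokens) so the patch count stays at seq_len.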
                if _patch_lengths[0] > 1:
                    last_patch_length = _patch_lengths[-1]
                    _patch_lengths[0] -= 1
                    _patch_lengths = [1] + _patch_lengths[:-1]
                tokens.append(_tokens[: len(_tokens) - last_patch_length])
                masks.append(_mask[: len(_mask) - last_patch_length])
                patch_lengths.append(_patch_lengths)

            x_patch_lengths = np.array(patch_lengths)
            # pad batch to same length
            tok_seq_len = max([len(toks) for toks in tokens]) - 1
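            # The -1 reflects the one-token shift between inputs x and
            # next-token targets y below.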
            x = np.full((batch_size, tok_seq_len), fill_value=pad_id)
            y = np.full((batch_size, tok_seq_len), fill_value=pad_id)

            for i, tok_seq in enumerate(tokens):
                x[i, : len(tok_seq) - 1] = tok_seq[:-1]
                y[i, : len(tok_seq) - 1] = tok_seq[1:]
                # Adjust patch lengths to match x
                x_patch_lengths[i, -1] += tok_seq_len - (len(tok_seq) - 1)

            assert x_patch_lengths.shape == (batch_size, seq_len)

            if enable_byte_ngrams:
                raise NotImplementedError()
            else:
                ngram_ids = None

            batch = Batch(
                x=x,
                y=y,
                patch_lengths=x_patch_lengths,
                ngram_ids=ngram_ids,
                mask=_merge_patch_seq_masks(batch_size, tok_seq_len, masks),
            )
            assert (
                x_patch_lengths.sum() == x.size + batch_size
            ), f"{x_patch_lengths.sum()} != {x.size + batch_size}"
            assert (
                batch.mask is None or np.sum(x != pad_id) == batch.mask.sum()
            ), f"{np.sum(x != pad_id)} != {batch.mask.sum()}"
            assert np.all(
                x_patch_lengths[:, 0] == 1
            ), f"first patch should always be 1, {x_patch_lengths[:, 0]}"
            # cuda_gb_allocated = (torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024)
            # cuda_gb_reserved = torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024
            # print(f"dataloader cuda_gb_allocated: {cuda_gb_allocated}, cuda_gb_reserved: {cuda_gb_reserved}")
            truncate_batch(
                batch,
                max_length=max_length,
                pad_id=pad_id,
                pad_to_max_length=pad_to_max_length,
                enable_byte_ngrams=enable_byte_ngrams,
            )
            yield batch
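
# End-to-end usage sketch (hypothetical `sampling_state` and `packing_args`):
# rebuild the stateful pipeline and pull fixed-shape Batch objects from it.
#
#   iterator = PackingIterator(
#       sequence_iterator=sampling_state.build(),
#       packing_args=packing_args,
#   )
#   batches = iterator.create_iter()
#   batch = next(batches)  # Batch with x, y, patch_lengths, mask, ngram_ids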