# Copyright (c) Meta Platforms, Inc. and affiliates.
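"""Sequence packing iterator.

Packs preprocessed examples into fixed-length sequences of patches. Tokens,
masks, and patch lengths are accumulated into a buffer of
buffer_size * output_seq_len patches; once the buffer is full, it is split into
buffer_size sequences of exactly output_seq_len patches each, and the sequences
are yielded in a randomly permuted order.

Illustrative usage (the PreprocessIterator construction is elided and the
argument values below are placeholders):

    packing_args = SequencePackingArgs(output_seq_len=512, buffer_size=64)
    sequence_iterator = SequenceIterator(
        preprocess_iterator,
        sequence_packing_args=packing_args,
        rng_state=np.random.default_rng(0).bit_generator.state,
    )
    for sequence in sequence_iterator.create_iter():
        ...  # each `sequence` is a BltSequence
"""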
from logging import getLogger
from typing import Any

import numpy as np
from pydantic import BaseModel, ConfigDict

from bytelatent.data.data_types import BltSequence
from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
from bytelatent.data.iterators.preprocess_iterator import (
    PreprocessIterator,
    PreprocessIteratorState,
)

logger = getLogger()


class SequencePackingArgs(BaseModel):
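    """Configuration for sequence packing.

    output_seq_len: number of patches in each packed sequence.
    buffer_size: number of sequences accumulated and shuffled per buffer.
    """
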
    model_config = ConfigDict(extra="forbid")
    output_seq_len: int
    buffer_size: int


class SequenceIteratorState(BaseModel, IteratorState):
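    """Serializable state needed to rebuild a SequenceIterator.

    Captures the packing configuration, the wrapped preprocess iterator's
    state, and the numpy bit-generator state used for shuffling.
    """
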
    model_config = ConfigDict(extra="forbid")
    sequence_packing_args: SequencePackingArgs
    preprocess_iterator_state: PreprocessIteratorState
    rng_state: dict[str, Any]

    def build(self):
        preprocess_iterator = self.preprocess_iterator_state.build()
        return SequenceIterator(
            preprocess_iterator,
            sequence_packing_args=self.sequence_packing_args,
            rng_state=self.rng_state,
        )


class SequenceIterator(StatefulIterator):
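    """Stateful iterator that packs preprocessed examples into fixed-size sequences.

    Examples from a PreprocessIterator are buffered and re-emitted as
    BltSequence objects, each containing exactly output_seq_len patches.
    """
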
    def __init__(
        self,
        preprocess_iterator: PreprocessIterator,
        *,
        rng_state: dict[str, Any],
        sequence_packing_args: SequencePackingArgs,
    ):
        self.preprocess_iterator = preprocess_iterator
        self.sequence_packing_args = sequence_packing_args
        self.output_seq_len = sequence_packing_args.output_seq_len
        self.buffer_size = sequence_packing_args.buffer_size
        self.rng = np.random.default_rng()
        self.rng.bit_generator.state = rng_state

    def get_state(self):
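        """Snapshot this iterator as a SequenceIteratorState for checkpointing."""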
        # TODO: need to also persist the current shuffle buffer
        return SequenceIteratorState(
            sequence_packing_args=self.sequence_packing_args,
            preprocess_iterator_state=self.preprocess_iterator.get_state(),
            rng_state=self.rng.bit_generator.state,
        )

    def create_iter(self):
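        """Yield BltSequence objects containing exactly output_seq_len patches each.

        Incoming examples are appended to flat token/mask/patch-length buffers.
        Whenever at least buffer_size * output_seq_len patches have accumulated,
        the buffer is split into buffer_size sequences, which are yielded in a
        random order drawn from self.rng.
        """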
        example_iter = self.preprocess_iterator.create_iter()
        n_buffer_patches = self.buffer_size * self.output_seq_len

        patch_lengths: list[int] = []
        tokens: list[int] = []
        mask: list[bool] = []
        first = True
        for example in example_iter:
            assert example.tokens is not None
            assert example.mask is not None
            assert example.patch_lengths is not None
            assert len(example.tokens) != 0
            assert len(example.mask) != 0
            assert len(example.tokens) == len(example.mask)
            assert len(example.tokens) == sum(example.patch_lengths)

            tokens.extend(example.tokens)
            mask.extend(example.mask)
            patch_lengths.extend(example.patch_lengths)

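            # Flush every complete buffer: each flush yields buffer_size
            # sequences of output_seq_len patches each.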
            while len(patch_lengths) >= n_buffer_patches:
                if first:
                    first = False
                    logger.info("First buffer complete")

                x_patches = np.array(patch_lengths[:n_buffer_patches]).reshape(
                    self.buffer_size, self.output_seq_len
                )
                seq_tokens = []
                seq_mask = []
                start_id = 0
                # The number of patches per sequence is fixed (and therefore the
                # number of global steps per batch), so the number of tokens per
                # sequence varies and must be accounted for here.
                for num_tokens in x_patches.sum(axis=-1):
                    seq_tokens.append(tokens[start_id : start_id + num_tokens])
                    seq_mask.append(mask[start_id : start_id + num_tokens])
                    start_id += num_tokens

                assert start_id == x_patches.sum()

                # Drop the patches, tokens, and mask entries we just packed from the buffers
                patch_lengths = patch_lengths[n_buffer_patches:]
                tokens = tokens[x_patches.sum() :]
                mask = mask[x_patches.sum() :]

                seq_patch_lengths: list[list[int]] = x_patches.tolist()
                assert len(seq_patch_lengths) == self.buffer_size
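                # Emit the packed sequences in a shuffled order; the permutation
                # is drawn from self.rng, so it is reproducible from rng_state.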
                for idx in self.rng.permutation(len(seq_patch_lengths)):
                    assert len(seq_patch_lengths[idx]) == self.output_seq_len
                    assert (
                        sum(seq_patch_lengths[idx])
                        == len(seq_tokens[idx])
                        == len(seq_mask[idx])
                    ), f"{sum(seq_patch_lengths[idx])}, {len(seq_tokens[idx])} {len(seq_mask[idx])}, idx={idx}"
                    assert seq_patch_lengths[idx][0] > 0, f"{seq_patch_lengths[idx]}"
                    yield BltSequence(
                        tokens=seq_tokens[idx],
                        mask=seq_mask[idx],
                        patch_lengths=seq_patch_lengths[idx],
                    )