File size: 5,100 Bytes
bcc039b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
# Copyright (c) Meta Platforms, Inc. and affiliates.
import pandas as pd
from pydantic import BaseModel

from bytelatent.constants import BLT_DATA
from bytelatent.data.data_types import BltExample
from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
from bytelatent.data.iterators.preprocess_iterator import PreprocessIterator
from bytelatent.data.patcher import PatcherArgs, PatchingModeEnum
from bytelatent.tokenizers.build_tokenizer import TokenizerArgs


class BltTestIteratorState(BaseModel, IteratorState):
    """Serializable checkpoint of a BltTestIterator's progress."""

    # Number of examples already yielded when the state was captured.
    position: int
    # Total number of synthetic examples the iterator produces.
    total: int

    def build(self) -> "BltTestIterator":
        """Reconstruct the iterator this state was captured from.

        Bug fix: previously this rebuilt another *state* object rather than the
        iterator itself, and did so without the required pydantic field
        `position`, which raises a ValidationError on construction.
        """
        blt_iter = BltTestIterator(total=self.total)
        blt_iter.position = self.position
        return blt_iter


class BltTestIterator(StatefulIterator):
    """Deterministic in-memory iterator yielding `total` synthetic text examples."""

    def __init__(self, total: int):
        self.total = total
        self.position = 0

    def get_state(self):
        """Snapshot the current progress as a serializable state object."""
        return BltTestIteratorState(position=self.position, total=self.total)

    def create_iter(self):
        """Yield `total` untokenized BltExample items, advancing `position` each time."""
        for idx in range(self.total):
            self.position += 1
            example = BltExample(
                sample_id=f"test_{idx}",
                text=f"This is some test {idx} text.",
                tokens=None,
                mask=None,
                entropies=None,
                patch_lengths=None,
            )
            yield example


class BltTestWithEntropiesIteratorState(BaseModel, IteratorState):
    """Serializable checkpoint of a BltTestWithEntropiesIterator's progress."""

    # Number of examples already yielded when the state was captured.
    position: int
    # Total number of examples the iterator produces.
    total: int

    def build(self) -> "BltTestWithEntropiesIterator":
        """Reconstruct the iterator this state was captured from.

        Bug fix: previously this rebuilt another *state* object rather than the
        iterator itself, and did so without the required pydantic field
        `position`, which raises a ValidationError on construction.
        """
        blt_iter = BltTestWithEntropiesIterator(total=self.total)
        blt_iter.position = self.position
        return blt_iter


class BltTestWithEntropiesIterator(StatefulIterator):
    """Iterator yielding a fixed sentence with precomputed tokens and entropies.

    The token ids and per-token entropies come from the JSON fixture
    `fixtures/tokens_with_entropies.json`, so entropy-threshold patching can be
    tested without running the entropy model.
    """

    def __init__(self, total: int):
        self.position = 0
        self.total = total

    def get_state(self):
        """Snapshot the current progress as a serializable state object.

        Bug fix: previously returned a `BltTestIteratorState` (copy-paste from
        the sibling class), so restoring the state would rebuild the wrong
        iterator type. Now returns the matching state class.
        """
        return BltTestWithEntropiesIteratorState(
            position=self.position, total=self.total
        )

    def create_iter(self):
        """Yield `total` copies of the fixture example, advancing `position`."""
        text = "Daenerys Targaryen is in Game of Thrones, a fantasy epic by George R.R. Martin."
        df = pd.read_json("fixtures/tokens_with_entropies.json")
        tokens = df["token_ids"].tolist()
        entropies = df["entropies"].tolist()
        # The tokenizer wraps the text in BOS and EOS, hence the +2.
        assert len(tokens) == len(text) + 2
        for i in range(self.total):
            self.position += 1
            yield BltExample(
                sample_id=f"test_{i}",
                text=text,
                tokens=tokens,
                mask=[True] * len(tokens),
                entropies=entropies,
                patch_lengths=None,
            )


def test_preprocess_iter():
    """Tokenization through PreprocessIterator yields BOS/EOS-wrapped tokens plus a mask."""
    total = 3
    tokenizer_args = TokenizerArgs(
        name="blt",
        init_kwargs={
            "bpe_tokenizer_path": BLT_DATA / "tokenizer_final_32k.minus_inf_ws.model"
        },
    )
    modes = (PatchingModeEnum.bpe, PatchingModeEnum.space)
    for mode in modes:
        preprocess_it = PreprocessIterator(
            BltTestIterator(total),
            tokenizer_args=tokenizer_args,
            patcher_args=PatcherArgs(patching_mode=mode),
        )
        seen = 0
        for ex in preprocess_it.create_iter():
            seen += 1
            assert isinstance(ex.tokens, list)
            assert isinstance(ex.tokens[0], int)
            # The tokenizer wraps the text in BOS and EOS, hence the +2.
            assert len(ex.tokens) == len(ex.text) + 2
            assert ex.mask is not None
            assert len(ex.mask) == len(ex.tokens)
        assert seen == total


def test_non_entropy_patch_iter():
    """BPE and space patching produce integer patch lengths that partition the tokens."""
    total = 3
    tokenizer_args = TokenizerArgs(
        name="blt",
        init_kwargs={
            "bpe_tokenizer_path": BLT_DATA / "tokenizer_final_32k.minus_inf_ws.model"
        },
    )
    modes = (PatchingModeEnum.bpe, PatchingModeEnum.space)
    for mode in modes:
        preprocess_it = PreprocessIterator(
            BltTestIterator(total),
            tokenizer_args=tokenizer_args,
            patcher_args=PatcherArgs(patching_mode=mode),
        )
        seen = 0
        for ex in preprocess_it.create_iter():
            seen += 1
            assert isinstance(ex.patch_lengths, list)
            assert isinstance(ex.patch_lengths[0], int)
            # Patches must exactly cover the token sequence.
            assert sum(ex.patch_lengths) == len(ex.tokens)
        assert seen == total


def test_entropy_patch_iter():
    """Entropy-threshold patching over fixture entropies partitions the token sequence."""
    total = 2
    tokenizer_args = TokenizerArgs(
        name="blt",
        init_kwargs={
            "bpe_tokenizer_path": BLT_DATA / "tokenizer_final_32k.minus_inf_ws.model"
        },
    )
    patcher_args = PatcherArgs(
        patching_mode=PatchingModeEnum.entropy, threshold=1.335442066192627
    )
    example_it = PreprocessIterator(
        BltTestWithEntropiesIterator(total),
        tokenizer_args=tokenizer_args,
        patcher_args=patcher_args,
    )
    num_examples = 0
    for ex in example_it.create_iter():
        num_examples += 1
        assert isinstance(ex.patch_lengths, list)
        assert isinstance(ex.patch_lengths[0], int)
        # Patches must exactly cover the token sequence.
        assert sum(ex.patch_lengths) == len(ex.tokens)
    assert num_examples == total