Spaces:
Running
on
Zero
Running
on
Zero
File size: 5,100 Bytes
bcc039b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
import pandas as pd
from pydantic import BaseModel
from bytelatent.constants import BLT_DATA
from bytelatent.data.data_types import BltExample
from bytelatent.data.iterators.abstract_iterator import IteratorState, StatefulIterator
from bytelatent.data.iterators.preprocess_iterator import PreprocessIterator
from bytelatent.data.patcher import PatcherArgs, PatchingModeEnum
from bytelatent.tokenizers.build_tokenizer import TokenizerArgs
class BltTestIteratorState(BaseModel, IteratorState):
    """Serializable checkpoint for a BltTestIterator."""

    # Number of examples already yielded.
    position: int
    # Total number of examples the iterator will produce.
    total: int

    def build(self):
        """Reconstruct the iterator this state was captured from.

        Bug fix: the original instantiated ``BltTestIteratorState(total=...)``,
        which returns a state (not an iterator) and raises a pydantic
        ValidationError because the required ``position`` field is missing.
        The state must rebuild a ``BltTestIterator`` positioned where the
        snapshot was taken.
        """
        blt_iter = BltTestIterator(total=self.total)
        blt_iter.position = self.position
        return blt_iter
class BltTestIterator(StatefulIterator):
    """Stateful test source yielding ``total`` synthetic BltExample records."""

    def __init__(self, total: int):
        self.position = 0
        self.total = total

    def get_state(self):
        # Snapshot progress so iteration can be checkpointed and resumed.
        return BltTestIteratorState(position=self.position, total=self.total)

    def create_iter(self):
        # `position` is bumped before each yield so a state captured
        # mid-stream resumes after the example currently being produced.
        i = 0
        while i < self.total:
            self.position += 1
            example = BltExample(
                sample_id=f"test_{i}",
                text=f"This is some test {i} text.",
                tokens=None,
                mask=None,
                entropies=None,
                patch_lengths=None,
            )
            yield example
            i += 1
class BltTestWithEntropiesIteratorState(BaseModel, IteratorState):
    """Serializable checkpoint for a BltTestWithEntropiesIterator."""

    # Number of examples already yielded.
    position: int
    # Total number of examples the iterator will produce.
    total: int

    def build(self):
        """Rebuild the BltTestWithEntropiesIterator this state came from.

        Bug fix: the original constructed the *state* class again — and
        without the required ``position`` field, a pydantic
        ValidationError — instead of the iterator it is meant to restore.
        """
        blt_iter = BltTestWithEntropiesIterator(total=self.total)
        blt_iter.position = self.position
        return blt_iter
class BltTestWithEntropiesIterator(StatefulIterator):
    """Yields ``total`` copies of a fixed sentence whose tokens and
    per-token entropies are loaded from a JSON fixture file."""

    def __init__(self, total: int):
        self.position = 0
        self.total = total

    def get_state(self):
        # Bug fix: the original returned BltTestIteratorState (copy-paste),
        # whose build() would reconstruct the wrong iterator class — one
        # that yields examples without entropies. Return the matching
        # state class so round-tripping restores this iterator.
        return BltTestWithEntropiesIteratorState(
            position=self.position, total=self.total
        )

    def create_iter(self):
        text = "Daenerys Targaryen is in Game of Thrones, a fantasy epic by George R.R. Martin."
        # Fixture holds precomputed token ids and entropies for `text`.
        df = pd.read_json("fixtures/tokens_with_entropies.json")
        tokens = df["token_ids"].tolist()
        entropies = df["entropies"].tolist()
        # BOS and EOS
        assert len(tokens) == len(text) + 2
        for i in range(self.total):
            self.position += 1
            yield BltExample(
                sample_id=f"test_{i}",
                text=text,
                tokens=tokens,
                mask=[True] * len(tokens),
                entropies=entropies,
                patch_lengths=None,
            )
def test_preprocess_iter():
    """Tokenization via PreprocessIterator fills `tokens` and a mask of
    equal length, framing the byte-level text with BOS and EOS."""
    total = 3
    tokenizer_args = TokenizerArgs(
        name="blt",
        init_kwargs={
            "bpe_tokenizer_path": BLT_DATA / "tokenizer_final_32k.minus_inf_ws.model"
        },
    )
    modes = [PatchingModeEnum.bpe, PatchingModeEnum.space]
    for mode in modes:
        example_it = PreprocessIterator(
            BltTestIterator(total),
            tokenizer_args=tokenizer_args,
            patcher_args=PatcherArgs(patching_mode=mode),
        )
        seen = 0
        for example in example_it.create_iter():
            assert isinstance(example.tokens, list)
            assert isinstance(example.tokens[0], int)
            # BOS and EOS
            assert len(example.tokens) == len(example.text) + 2
            assert example.mask is not None
            assert len(example.mask) == len(example.tokens)
            seen += 1
        assert seen == total
def test_non_entropy_patch_iter():
    """BPE and space patching must produce integer patch lengths that
    partition the token sequence exactly."""
    total = 3
    tokenizer_args = TokenizerArgs(
        name="blt",
        init_kwargs={
            "bpe_tokenizer_path": BLT_DATA / "tokenizer_final_32k.minus_inf_ws.model"
        },
    )
    for mode in (PatchingModeEnum.bpe, PatchingModeEnum.space):
        example_it = PreprocessIterator(
            BltTestIterator(total),
            tokenizer_args=tokenizer_args,
            patcher_args=PatcherArgs(patching_mode=mode),
        )
        seen = 0
        for example in example_it.create_iter():
            assert isinstance(example.patch_lengths, list)
            assert isinstance(example.patch_lengths[0], int)
            # Patch lengths cover every token exactly once.
            assert sum(example.patch_lengths) == len(example.tokens)
            seen += 1
        assert seen == total
def test_entropy_patch_iter():
    """Entropy-threshold patching over fixture entropies must yield integer
    patch lengths that sum to the token count."""
    total = 2
    # Threshold chosen to match the fixture's entropy distribution.
    patcher_args = PatcherArgs(
        patching_mode=PatchingModeEnum.entropy, threshold=1.335442066192627
    )
    tokenizer_args = TokenizerArgs(
        name="blt",
        init_kwargs={
            "bpe_tokenizer_path": BLT_DATA / "tokenizer_final_32k.minus_inf_ws.model"
        },
    )
    example_it = PreprocessIterator(
        BltTestWithEntropiesIterator(total),
        tokenizer_args=tokenizer_args,
        patcher_args=patcher_args,
    )
    seen = 0
    for example in example_it.create_iter():
        assert isinstance(example.patch_lengths, list)
        assert isinstance(example.patch_lengths[0], int)
        # Patch lengths cover every token exactly once.
        assert sum(example.patch_lengths) == len(example.tokens)
        seen += 1
    assert seen == total
|