Update preprocess_entropies script to blt inference + add fsspec support (#23)

Files changed:
- bytelatent/data/patcher.py (+4, -4)
- bytelatent/preprocess/preprocess_entropies.py (+69, -36)
- requirements.txt (+1, -0)
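The "fsspec support" in this commit means the script resolves an fsspec filesystem for each path instead of opening local files directly, so the same code can read a local jsonl file or an `s3://` URI. The diff relies on a `get_fs` helper from `bytelatent/data/file_util.py` that is not shown here; as a rough, hypothetical sketch of the idea (the `resolve_fs` name and the scheme check are illustrative assumptions, not the repo's actual implementation):

```python
import fsspec


def resolve_fs(path: str, s3_profile: str | None = None) -> fsspec.AbstractFileSystem:
    """Hypothetical stand-in for bytelatent.data.file_util.get_fs (not shown in this diff)."""
    if path.startswith("s3://"):
        # s3fs (pulled in by fsspec[full]) accepts an AWS profile name via `profile`.
        return fsspec.filesystem("s3", profile=s3_profile)
    # Anything else is treated as a local path.
    return fsspec.filesystem("file")


# The same call site then works against either backend:
# resolve_fs("s3://bucket/shard.jsonl", "my-profile").open("s3://bucket/shard.jsonl")
# resolve_fs("data/shard.jsonl").open("data/shard.jsonl")
```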
bytelatent/data/patcher.py

@@ -82,16 +82,16 @@ def calculate_entropies(
        if device is not None:
            split = split.to(device)
        assert torch.all(split >= 0) and torch.all(split < 260)
+       pred = entropy_model(split)
        pred = pred.reshape(-1, pred.shape[-1])[
            : split.numel() - pad_size, :
        ]  # [batch_size * seq_len, vocab]
        pred_entropies = entropy(pred)
        entropies.append(pred_entropies)

+   concat_entropies = torch.cat(entropies, dim=0)
+   concat_entropies = concat_entropies.reshape(tokens.shape)
+   return concat_entropies


def patch_start_mask_from_entropy_with_monotonicity(entropies, t):
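After this change, `calculate_entropies` concatenates the per-split entropies and reshapes them to the input's shape before returning, i.e. one next-byte entropy per input position. A minimal sketch of calling it the way the preprocessing script below does; the checkpoint paths are the defaults this commit adds, the positional argument order mirrors the (truncated) call site in the next file, and feeding raw bytes instead of the BLT tokenizer's output is only meant to illustrate shapes:

```python
import torch

from bytelatent.data.patcher import calculate_entropies
from bytelatent.entropy_model import load_entropy_model

# Default checkpoint paths added by this commit; they assume the public_data
# files have been downloaded into the working directory.
entropy_model = load_entropy_model(
    "public_data/entropy_checkpoint",
    "public_data/entropy_model.pth",
    device="cpu",
)

# Raw byte values as token ids (all < 260, so the assert in calculate_entropies holds).
# The real script encodes text with the BLT tokenizer instead.
tokens = torch.tensor(list(b"Byte Latent Transformer"), dtype=torch.long)

scores = calculate_entropies(
    tokens,
    entropy_model,
    32,     # the script's patching_batch_size
    "cpu",  # the script's patching_device
)
assert scores.shape == tokens.shape  # one entropy value per input position
```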
bytelatent/preprocess/preprocess_entropies.py

@@ -1,14 +1,59 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 import time
-from pathlib import Path

+import fsspec
+import jsonlines
 import numpy as np
 import pyarrow as pa
 import torch
 import typer
 from rich.progress import Progress, TextColumn

+from bytelatent.data.file_util import get_fs
+from bytelatent.data.patcher import calculate_entropies
+from bytelatent.entropy_model import load_entropy_model
+from bytelatent.tokenizers.build_tokenizer import TokenizerArgs
+
+
+def get_id_from_doc(doc: dict) -> int:
+    """
+    We need a reliable way to ensure that samples from jsonl
+    and arrow are the same, but there is no unique id field,
+    so derive the best possible
+    """
+    if "sample_id" in doc:
+        sample_id = doc["sample_id"]
+    elif "title" in doc:
+        sample_id = doc["title"]
+    elif "qid" in doc:
+        sample_id = doc["qid"]
+    elif "paper_id" in doc:
+        sample_id = doc["paper_id"]
+    elif "path" in doc:
+        sample_id = doc["path"]
+    elif "url" in doc:
+        sample_id = doc["url"]
+    elif "id" in doc:
+        sample_id = doc["id"]
+    else:
+        raise ValueError(f"Could not find a id key from: {doc.keys()}")
+    return str(sample_id)
+
+
+def get_text(doc: dict):
+    if "text" in doc:
+        text = doc["text"]
+    elif "content" in doc:
+        text = doc["content"]
+    else:
+        raise ValueError(f"Could not find a text key from: {doc.keys()}")
+    return text
+
+
+def jsonl_file_iterator(fs: fsspec.AbstractFileSystem, path: str):
+    with fs.open(path) as f:
+        reader = jsonlines.Reader(f)
+        yield from reader


 def main(
@@ -16,39 +61,32 @@ def main(
     output_file: str,
     patching_device: str = "cuda",
     log_step: int = 10_000,
+    entropy_model_checkpoint_dir: str = "public_data/entropy_checkpoint",
+    entropy_model_state_dict_path: str = "public_data/entropy_model.pth",
+    bpe_tokenizer_path: str = "public_data/tokenizer.model",
     dry_run: bool = False,
+    s3_profile: str | None = None,
 ):
-    # TODO: Modify this to work with the new code
-    raise NotImplementedError()
-    iterator = ArrowFileIterator(
-        file_path=input_file,
-        worker_id=0,
-        num_workers=1,
-    )
-    tokenization_mode = "bytes"
     print(f"Preprocessing entropies, input: {input_file}, output: {output_file}")
     print("Loading entropy model", entropy_model_checkpoint_dir)
+    input_fs = get_fs(input_file, s3_profile=s3_profile)
+    input_doc_iterator = jsonl_file_iterator(input_fs, input_file)
+
     if dry_run:
         return
     entropy_model = load_entropy_model(
-        entropy_model_checkpoint_dir,
+        entropy_model_checkpoint_dir,
+        entropy_model_state_dict_path,
+        device=patching_device,
     )
+
     print("Creating patcher")
     patching_batch_size = 32
     print("Creating tokenizer")
-        tokenization_mode=tokenization_mode,
-        # BYTE_UNITS
-        vocab_size_unit_1=256,
-        bos=True,
-        eos=True,
-        bpe_delim=False,
-        # This isn't used, just stores a reference for other calls we don't use
-        patcher=None,
+    tokenizer_args = TokenizerArgs(
+        name="blt", init_kwargs={"bpe_tokenizer_path": bpe_tokenizer_path}
     )
+    tokenizer = tokenizer_args.build()
     step = 0
     print("starting")
     start_time = time.time()
@@ -59,8 +97,10 @@ def main(
     schema = pa.schema([sample_id_field, text_field, entropy_field])
     arrow_batch_size = 1_000

+    output_fs = get_fs(output_file, s3_profile=s3_profile)
+
     try:
+        with output_fs.open(output_file, "wb") as sink:
             with pa.ipc.new_file(sink, schema) as writer:
                 id_buffer = []
                 entropies_buffer = []
@@ -72,17 +112,9 @@ def main(
                     task = progress.add_task(
                         "[green]Calculating entropies...", total=None
                     )
+                    for doc in input_doc_iterator:
                         sample_id = get_id_from_doc(doc)
-
-                        if "text" in doc:
-                            text = doc["text"]
-                        elif "content" in doc:
-                            text = doc["content"]
-                        else:
-                            raise ValueError(
-                                f"Could not find a text key from: {doc.keys()}"
-                            )
+                        text = get_text(doc)
                         tokens = torch.tensor(tokenizer.encode(text))
                         patch_start = time.time()
                         scores = calculate_entropies(
@@ -128,9 +160,10 @@ def main(
                             entropies_buffer = []
                             id_buffer = []
                             text_buffer = []
+        output_fs.touch(f"{output_file}.complete")
     except:
+        if output_fs.exists(output_file):
+            output_fs.rm(output_file)
         raise
     elapsed = time.time() - start_time
     print("steps", step)
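With these changes the script streams a jsonl input through the entropy model and writes a single Arrow IPC file, touching a `<output_file>.complete` marker on success and deleting a partial output on failure. A hedged sketch of reading the result back; the output path is hypothetical, and the column names are defined by `sample_id_field` / `text_field` / `entropy_field` outside the hunks shown here, so the snippet inspects the schema rather than assuming them:

```python
import fsspec
import pyarrow as pa

output_file = "shard_0.arrow"  # hypothetical path that was passed to the script

# fsspec.open handles local paths directly and s3:// URIs when s3fs is installed.
with fsspec.open(output_file, "rb") as f:
    reader = pa.ipc.open_file(f)
    table = reader.read_all()

# Column names come from the pa.field definitions in the script, not this diff,
# so print the schema instead of hard-coding them.
print(table.schema)
print(table.num_rows)
```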
requirements.txt

@@ -21,3 +21,4 @@ submitit
 typer
 rich
 fsspec[full]
+orjson