# Copyright (c) Meta Platforms, Inc. and affiliates.
import logging
import os
from typing import Any

import numpy as np
import yaml
from pydantic import BaseModel, ConfigDict

from bytelatent.checkpoint import CheckpointArgs
from bytelatent.data.data_types import Batch
from bytelatent.data.iterators.abstract_iterator import StatefulIterator
from bytelatent.data.iterators.arrow_iterator import (
    ArrowFileIterator,
    find_and_sanitize_chunks,
)
from bytelatent.data.iterators.looping_iterator import LoopingIterator
from bytelatent.data.iterators.multiprocess_iterator import MultiprocessIterator
from bytelatent.data.iterators.packing_iterator import PackingArgs, PackingIterator
from bytelatent.data.iterators.preprocess_iterator import PreprocessIterator
from bytelatent.data.iterators.sampling_iterator import SamplingIterator
from bytelatent.data.iterators.sequence_iterator import (
    SequenceIterator,
    SequencePackingArgs,
)
from bytelatent.data.patcher import PatcherArgs
from bytelatent.distributed import DistributedArgs, EnvironmentArgs
from bytelatent.metrics import LoggingArgs
from bytelatent.model.blt import ByteLatentTransformerArgs
from bytelatent.optim import OptimArgs
from bytelatent.profiling import ProfilerArgs
from bytelatent.tokenizers.build_tokenizer import TokenizerArgs
from bytelatent.transformer import LMTransformerArgs

logger = logging.getLogger()


def get_rng_state(seed: int, rank: int, world_size: int) -> dict[str, Any]:
    """Return the bit-generator state of a NumPy RNG seeded with (seed, rank, world_size)."""
    return np.random.default_rng((seed, rank, world_size)).bit_generator.state
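
# Example: get_rng_state(42, 0, 8) and get_rng_state(42, 1, 8) yield
# independent generator states, giving every rank its own random stream.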


def distribute_data_to_rank(
    *,
    dataset_path: str,
    preprocess_dir: str,
    entropy_model_name: str | None,
    arrow_batch_size: int,
    rank: int,
    world_size: int,
    s3_profile: str | None = None,
) -> ArrowFileIterator:
    """Split the dataset's Arrow chunks across ranks and return this rank's iterator.

    Each chunk is shared by world_size // len(dataset_chunks) workers; the
    (chunk, worker_id) pairs are enumerated in order and indexed by rank.
    """
    dataset_chunks = find_and_sanitize_chunks(
        dataset_path, world_size, s3_profile=s3_profile
    )
    n_workers_per_chunk = world_size // len(dataset_chunks)
    rank_to_arrow_iterator_params = []
    for chunk_path in dataset_chunks:
        for worker_id in range(n_workers_per_chunk):
            rank_to_arrow_iterator_params.append(
                ArrowFileIterator(
                    file_path=chunk_path,
                    worker_id=worker_id,
                    num_workers=n_workers_per_chunk,
                    preprocess_dir=preprocess_dir,
                    dataset_files=None,
                    entropy_model_name=entropy_model_name,
                    arrow_batch_size=arrow_batch_size,
                    s3_profile=s3_profile,
                )
            )
    return rank_to_arrow_iterator_params[rank]
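
# Example: with two chunks and world_size == 4, n_workers_per_chunk == 2 and
# ranks map to (chunk, worker_id) pairs in order:
# rank 0 -> (chunk0, 0), rank 1 -> (chunk0, 1),
# rank 2 -> (chunk1, 0), rank 3 -> (chunk1, 1).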


class DataloaderArgs(BaseModel):
    """Configuration for the training dataloader and the iterator pipeline it builds."""

    model_config = ConfigDict(extra="forbid")
    s3_profile: str | None = None
    root_dir: str | None = None
    # Mapping from dataset subdirectory (relative to root_dir) to sampling weight.
    sources: dict[str, float] = {}
    batch_size: int = 2
    seq_len: int = 2048
    seed: int = 42
    add_bos: bool = True
    add_eos: bool = True
    load_async: bool = True
    prefetch_size: int = 64
    preprocess_dir: str | None = None
    dataset_files: list[str] | None = None
    entropy_model_name: str | None = "transformer_100m"
    arrow_batch_size: int = 100
    buffer_size: int = 64
    pad_to_max_length: bool = True
    max_encoder_seq_length: int = 12288
    enable_byte_ngrams: bool = False
    tokenizer_args: TokenizerArgs = TokenizerArgs()
    patcher_args: PatcherArgs = PatcherArgs()

    def _create_sequence_iterators(
        self, rank: int, world_size: int
    ) -> dict[str, SequenceIterator]:
        """Build a tokenized, patched, shuffled sequence iterator per data source."""
        sequence_packing_args = SequencePackingArgs(
            output_seq_len=self.seq_len,
            buffer_size=self.buffer_size,
        )
        source_to_sequence_iterator: dict[str, SequenceIterator] = {}
        for dataset_path in self.sources:
            shuffle_rng_state = get_rng_state(self.seed + 1, rank, world_size)
            arrow_iterator = distribute_data_to_rank(
                dataset_path=os.path.join(self.root_dir, dataset_path),
                preprocess_dir=self.preprocess_dir,
                entropy_model_name=self.entropy_model_name,
                arrow_batch_size=self.arrow_batch_size,
                rank=rank,
                world_size=world_size,
                s3_profile=self.s3_profile,
            )
            looping_iterator = LoopingIterator(arrow_iterator)
            preprocess_iterator = PreprocessIterator(
                looping_iterator,
                patcher_args=self.patcher_args,
                tokenizer_args=self.tokenizer_args,
            )
            sequence_iterator = SequenceIterator(
                preprocess_iterator,
                sequence_packing_args=sequence_packing_args,
                rng_state=shuffle_rng_state,
            )
            source_to_sequence_iterator[dataset_path] = sequence_iterator
        return source_to_sequence_iterator

    def build_from_rank(
        self, rank: int, world_size: int
    ) -> StatefulIterator[Batch, Any]:
        """Assemble this rank's full data pipeline and return a stateful batch iterator."""
        source_to_sequence_iterators = self._create_sequence_iterators(rank, world_size)
        weight_rng_state = get_rng_state(self.seed + 1, rank, world_size)
        sampling_iterator = SamplingIterator(
            rng_state=weight_rng_state,
            source_to_weight=self.sources,
            source_to_iterator=source_to_sequence_iterators,
        )
        tokenizer = self.tokenizer_args.build()
        packing_args = PackingArgs(
            batch_size=self.batch_size,
            seq_len=self.seq_len,
            pad_id=tokenizer.boe_id,
            max_length=self.max_encoder_seq_length,
            pad_to_max_length=self.pad_to_max_length,
            enable_byte_ngrams=self.enable_byte_ngrams,
        )
        packing_iterator = PackingIterator(sampling_iterator, packing_args=packing_args)
        mp_iterator = MultiprocessIterator(
            packing_iterator, n_batches_to_prefetch=self.prefetch_size
        )
        return mp_iterator
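
# Pipeline assembled by build_from_rank, innermost to outermost:
# ArrowFileIterator -> LoopingIterator -> PreprocessIterator -> SequenceIterator
# -> SamplingIterator -> PackingIterator -> MultiprocessIterator.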


class TrainArgs(BaseModel):
    """Top-level training configuration."""

    model_config = ConfigDict(extra="forbid")
    name: str = "lingua"
    dump_dir: str = ""
    seed: int = 42
    debug_dynamo: bool = False
    # Number of gradient accumulation steps.
    # The effective batch size is batch_size * grad_acc_steps.
    grad_acc_steps: int = 1
    gc_collect_freq: int = 1000
    probe_freq: int | None = None
    # Number of optimizer steps to take.
    steps: int = 1000
    data: DataloaderArgs = DataloaderArgs()
    optim: OptimArgs = OptimArgs()
    model: ByteLatentTransformerArgs = ByteLatentTransformerArgs()
    # Only needed when training the entropy model.
    entropy_model: LMTransformerArgs | None = None
    # If True, train the entropy model instead of the main model.
    train_entropy_model: bool = False
    distributed: DistributedArgs = DistributedArgs()
    env: EnvironmentArgs = EnvironmentArgs()
    checkpoint: CheckpointArgs = CheckpointArgs()
    profiling: ProfilerArgs = ProfilerArgs()
    logging: LoggingArgs = LoggingArgs()
    # If None, eval runs locally; otherwise a new job is launched with the
    # given number of GPUs.
    async_eval_gpus: int | None = None
    eval: Any | None = None
    eval_on_gpus: int | None = None

    def dump_to_yaml_file(
        self, path: str, log_config: bool = True, sort_keys: bool = True
    ):
        """Serialize this config to YAML at `path`, optionally logging it first."""
        model_dict = self.model_dump(mode="json")
        yaml_str = yaml.dump(
            model_dict,
            allow_unicode=True,
            sort_keys=sort_keys,
            default_flow_style=False,
        )
        with open(path, "w") as f:
            if log_config:
                logger.info("Using the following config for this run:")
                logger.info(yaml_str)
            f.write(yaml_str)
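

if __name__ == "__main__":
    # Minimal usage sketch: build a single-rank dataloader from a
    # DataloaderArgs config and pull one batch. The root_dir and source name
    # below are hypothetical; point them at real preprocessed Arrow data.
    # Assumes the StatefulIterator returned by build_from_rank exposes
    # create_iter() to obtain a plain Python iterator over batches.
    dataloader_args = DataloaderArgs(
        root_dir="/path/to/data",
        sources={"example_dataset": 1.0},
        batch_size=2,
        seq_len=2048,
    )
    data_iterator = dataloader_args.build_from_rank(rank=0, world_size=1)
    batch_iterator = data_iterator.create_iter()
    first_batch = next(batch_iterator)
    print(first_batch)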