Spaces:
Running
Running
feat: split shards by host
Browse files- dalle_mini/data.py +30 -11
dalle_mini/data.py
CHANGED
|
@@ -4,9 +4,9 @@ from functools import partial
|
|
| 4 |
import jax
|
| 5 |
import jax.numpy as jnp
|
| 6 |
import numpy as np
|
|
|
|
| 7 |
from datasets import Dataset, load_dataset
|
| 8 |
from flax.training.common_utils import shard
|
| 9 |
-
from braceexpand import braceexpand
|
| 10 |
|
| 11 |
from .text import TextNormalizer
|
| 12 |
|
|
@@ -30,8 +30,10 @@ class Dataset:
|
|
| 30 |
train_dataset: Dataset = field(init=False)
|
| 31 |
eval_dataset: Dataset = field(init=False)
|
| 32 |
rng_dataset: jnp.ndarray = field(init=False)
|
|
|
|
| 33 |
|
| 34 |
def __post_init__(self):
|
|
|
|
| 35 |
# define data_files
|
| 36 |
if self.train_file is not None or self.validation_file is not None:
|
| 37 |
# accept braceexpand notation
|
|
@@ -39,6 +41,11 @@ class Dataset:
|
|
| 39 |
f = getattr(self, k)
|
| 40 |
if isinstance(f, str):
|
| 41 |
setattr(self, k, list(braceexpand(f)))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
data_files = {
|
| 43 |
"train": self.train_file,
|
| 44 |
"validation": self.validation_file,
|
|
@@ -169,17 +176,29 @@ class Dataset:
|
|
| 169 |
batch = shard(batch)
|
| 170 |
yield batch
|
| 171 |
|
| 172 |
-
def _dataloader_datasets_streaming(
|
|
|
|
|
|
|
|
|
|
| 173 |
keys = ["input_ids", "attention_mask", "labels", "decoder_input_ids"]
|
| 174 |
batch = {k: [] for k in keys}
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 183 |
|
| 184 |
if split == "train":
|
| 185 |
ds = self.train_dataset
|
|
@@ -191,7 +210,7 @@ class Dataset:
|
|
| 191 |
if self.streaming:
|
| 192 |
if split == "train":
|
| 193 |
ds.set_epoch(epoch)
|
| 194 |
-
return _dataloader_datasets_streaming(ds, batch_size)
|
| 195 |
else:
|
| 196 |
if split == "train":
|
| 197 |
self.rng_dataset, input_rng = jax.random.split(self.rng_dataset)
|
|
|
|
| 4 |
import jax
|
| 5 |
import jax.numpy as jnp
|
| 6 |
import numpy as np
|
| 7 |
+
from braceexpand import braceexpand
|
| 8 |
from datasets import Dataset, load_dataset
|
| 9 |
from flax.training.common_utils import shard
|
|
|
|
| 10 |
|
| 11 |
from .text import TextNormalizer
|
| 12 |
|
|
|
|
| 30 |
train_dataset: Dataset = field(init=False)
|
| 31 |
eval_dataset: Dataset = field(init=False)
|
| 32 |
rng_dataset: jnp.ndarray = field(init=False)
|
| 33 |
+
multi_hosts: bool = field(init=False)
|
| 34 |
|
| 35 |
def __post_init__(self):
|
| 36 |
+
# fix: jax.process_count is a function — the original compared the function
# object itself to 1, which raises TypeError; it must be called.
self.multi_hosts = jax.process_count() > 1
|
| 37 |
# define data_files
|
| 38 |
if self.train_file is not None or self.validation_file is not None:
|
| 39 |
# accept braceexpand notation
|
|
|
|
| 41 |
f = getattr(self, k)
|
| 42 |
if isinstance(f, str):
|
| 43 |
setattr(self, k, list(braceexpand(f)))
|
| 44 |
+
# for list of files, split training data shards by host
|
| 45 |
+
if isinstance(self.train_file, list) and self.multi_hosts:
|
| 46 |
+
self.train_file = self.train_file[
|
| 47 |
+
jax.process_index() :: jax.process_count()
|
| 48 |
+
]
|
| 49 |
data_files = {
|
| 50 |
"train": self.train_file,
|
| 51 |
"validation": self.validation_file,
|
|
|
|
| 176 |
batch = shard(batch)
|
| 177 |
yield batch
|
| 178 |
|
| 179 |
+
def _dataloader_datasets_streaming(
    dataset: Dataset, batch_size: int, epoch: int
):
    """Yield device-sharded batches from a streaming dataset.

    In multi-host mode this generator loops forever: all hosts must stop
    at the same step and we don't know how much data each host holds, so
    we never end an "epoch" — instead shards are reshuffled each pass via
    ``set_epoch``. In single-host mode it stops after one pass over the
    data. A trailing partial batch (fewer than ``batch_size`` items) is
    never yielded; in multi-host mode it carries over into the next pass.

    NOTE(review): ``epoch`` is only used for the multi-host reshuffle;
    relies on the enclosing scope's ``self.multi_hosts``, ``jnp`` and
    flax's ``shard``.
    """
    keys = ["input_ids", "attention_mask", "labels", "decoder_input_ids"]
    batch = {k: [] for k in keys}
    first_loop = True
    while self.multi_hosts or first_loop:
        # in multi-host, we run forever (no epoch) as hosts need to stop
        # at the same time and we don't know how much data is on each host
        if not first_loop:
            # multi-host setting: reshuffle shards for the next pass
            epoch += 1
            # assumes a streaming dataset exposing set_epoch — TODO confirm
            dataset.set_epoch(epoch)
        for item in dataset:
            for k, v in item.items():
                batch[k].append(v)
            if len(batch[keys[0]]) == batch_size:
                batch = {k: jnp.array(v) for k, v in batch.items()}
                # split the leading batch axis across local devices
                batch = shard(batch)
                yield batch
                batch = {k: [] for k in keys}
        first_loop = False
|
| 202 |
|
| 203 |
if split == "train":
|
| 204 |
ds = self.train_dataset
|
|
|
|
| 210 |
if self.streaming:
|
| 211 |
if split == "train":
|
| 212 |
ds.set_epoch(epoch)
|
| 213 |
+
return _dataloader_datasets_streaming(ds, batch_size, epoch)
|
| 214 |
else:
|
| 215 |
if split == "train":
|
| 216 |
self.rng_dataset, input_rng = jax.random.split(self.rng_dataset)
|