feat(train): progress on pjit
- src/dalle_mini/data.py +0 -2
- tools/train/train.py +34 -31
src/dalle_mini/data.py
CHANGED
@@ -191,7 +191,6 @@ class Dataset:
                         lambda x: x.reshape((-1, per_device_batch_size) + x.shape[1:]),
                         batch,
                     )
-                batch = shard(batch)
                 yield batch
 
         def _dataloader_datasets_streaming(
@@ -232,7 +231,6 @@ class Dataset:
                                 ),
                                 batch,
                             )
-                        batch = shard(batch)
                         yield batch
                         batch = {k: [] for k in keys}
                         first_loop = False
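Dropping `shard(batch)` fits the move to `pjit`: `shard` reshapes a batch to a leading `(local_devices, per_device_batch_size, ...)` axis for `pmap`, whereas `pjit` takes the global batch and partitions it along a mesh axis itself. A minimal sketch of that idea with the same `jax.experimental` APIs used in this commit (the 1D mesh, `sum_fn`, and toy shapes are illustrative assumptions, not code from the repo):

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import PartitionSpec, maps
from jax.experimental.pjit import pjit

# Toy 1D mesh over all available devices, named "batch".
devices = np.asarray(jax.devices())
mesh = maps.Mesh(devices, ("batch",))

# pjit a trivial reduction; the first axis of the global input is split
# across the "batch" mesh axis, so no per-device reshape/shard is needed.
sum_fn = pjit(
    lambda batch: jnp.sum(batch),
    in_axis_resources=PartitionSpec("batch"),
    out_axis_resources=None,
)

global_batch = jnp.ones((8 * jax.device_count(), 4))
with maps.mesh(mesh.devices, mesh.axis_names):
    total = sum_fn(global_batch)  # replicated scalar result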
tools/train/train.py
CHANGED
@@ -34,13 +34,11 @@ import numpy as np
 import optax
 import transformers
 from datasets import Dataset
-from distributed_shampoo import GraftingType, distributed_shampoo
-from flax import
-from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
-from flax.jax_utils import unreplicate
+from distributed_shampoo import GraftingType, distributed_shampoo
+from flax.core.frozen_dict import freeze
 from flax.serialization import from_bytes, to_bytes
 from flax.training import train_state
-from flax.training.common_utils import get_metrics, onehot
+from flax.training.common_utils import get_metrics, onehot
 from jax.experimental import PartitionSpec, maps
 from jax.experimental.pjit import pjit
 from tqdm import tqdm
@@ -402,14 +400,14 @@ class MetricsLogger:
 
     def get_all_train_metrics(self, train_metrics, state):
         """Make a dict of training metrics to be logged"""
-        metrics =
+        metrics = train_metrics
         # get state parameters
         state_dict = {
-            k.split("_")[-1]:
+            k.split("_")[-1]: getattr(state, k)
             for k in ["epoch", "train_time", "train_samples"]
         }
         # timing metrics
-        new_step = int(
+        new_step = int(state.step)
         new_time = time.perf_counter()
         if new_step > self.step:
             time_per_step = (new_time - self.time) / (new_step - self.step)
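The updated `get_all_train_metrics` reads counters directly off the train state with `getattr` and renames keys via `k.split("_")[-1]`, so `train_time` is logged as `time` and `train_samples` as `samples`. A small standalone illustration of that logic (the `SimpleNamespace` is a stand-in for the repo's TrainState, not its real definition):

import time
from types import SimpleNamespace

# Stand-in for the fields the logger reads from the train state.
state = SimpleNamespace(step=10, epoch=1, train_time=12.5, train_samples=640)

state_dict = {
    k.split("_")[-1]: getattr(state, k)
    for k in ["epoch", "train_time", "train_samples"]
}
# -> {"epoch": 1, "time": 12.5, "samples": 640}

prev_step, prev_time = 0, time.perf_counter()
new_step = int(state.step)
new_time = time.perf_counter()
if new_step > prev_step:
    time_per_step = (new_time - prev_time) / (new_step - prev_step)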
@@ -551,7 +549,7 @@ def main():
 
     # Initialize our training
     rng = jax.random.PRNGKey(training_args.seed_model)
-    rng,
+    rng, dropout_rng = jax.random.split(rng)
 
     # Store some constant
    num_epochs = training_args.num_train_epochs
@@ -681,34 +679,39 @@ def main():
     devices = np.asarray(jax.devices()).reshape(*mesh_shape)
     mesh = maps.Mesh(devices, ("batch", "mp"))
 
-    #
+    # Setup train state
+    def init_state(params, opt_state):
+        return TrainState(
+            apply_fn=model.__call__,
+            tx=optimizer,
+            params=params,
+            opt_state=opt_state,
+            dropout_rng=dropout_rng,
+            step=0,
+        )
+
+    state_spec = init_state(param_spec, opt_state_spec)
+    state_spec = state_spec.replace(
+        dropout_rng=None,
+        step=None,
+        epoch=None,
+        train_time=None,
+        train_samples=None,
+    )
+
     with maps.mesh(mesh.devices, mesh.axis_names):
+        # move params & init opt_state over specified devices
         params, opt_state = pjit(
             lambda x: (x, optimizer.init(x)),
             in_axis_resources=None,
             out_axis_resources=(param_spec, opt_state_spec),
         )(freeze(model.params))
-
-
-
-
-
-            opt_state
-            tx=optimizer,
-            dropout_rng=dropout_rng,
-            step=0,
-        )
-
-        # create PartitionSpec for state
-        state_spec = {
-            "params": param_spec,
-            "opt_state": opt_state_spec,
-            "dropout_rng": PartitionSpec("batch", None),
-            "epoch": None,
-            "step": None,
-            "train_samples": None,
-            "train_time": None,
-        }
+        # create training state
+        state = pjit(
+            init_state,
+            in_axis_resources=(param_spec, opt_state_spec),
+            out_axis_resources=state_spec,
+        )(params, opt_state)
 
     if training_args.resume_from_checkpoint is not None:
         # restore optimizer state and other parameters