Spaces:
Running
Running
feat: gradient accumulation
Browse files- seq2seq/run_seq2seq_flax.py +37 -10
seq2seq/run_seq2seq_flax.py
CHANGED
|
@@ -239,6 +239,8 @@ class DataTrainingArguments:
|
|
| 239 |
|
| 240 |
class TrainState(train_state.TrainState):
|
| 241 |
dropout_rng: jnp.ndarray
|
|
|
|
|
|
|
| 242 |
|
| 243 |
def replicate(self):
|
| 244 |
return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
|
|
@@ -590,14 +592,16 @@ def main():
|
|
| 590 |
# Store some constant
|
| 591 |
num_epochs = int(training_args.num_train_epochs)
|
| 592 |
train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
|
|
|
|
| 593 |
eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
|
| 594 |
steps_per_epoch = len(train_dataset) // train_batch_size
|
| 595 |
-
|
|
|
|
| 596 |
|
| 597 |
# Create learning rate schedule
|
| 598 |
linear_decay_lr_schedule_fn = create_learning_rate_fn(
|
| 599 |
len(train_dataset),
|
| 600 |
-
|
| 601 |
training_args.num_train_epochs,
|
| 602 |
training_args.warmup_steps,
|
| 603 |
training_args.learning_rate,
|
|
@@ -636,7 +640,14 @@ def main():
|
|
| 636 |
)
|
| 637 |
|
| 638 |
# Setup train state
|
| 639 |
-
state = TrainState.create(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 640 |
|
| 641 |
# label smoothed cross entropy
|
| 642 |
def loss_fn(logits, labels):
|
|
@@ -655,15 +666,28 @@ def main():
|
|
| 655 |
return loss
|
| 656 |
|
| 657 |
grad_fn = jax.value_and_grad(compute_loss)
|
| 658 |
-
loss,
|
| 659 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 660 |
|
| 661 |
-
new_state =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 662 |
|
| 663 |
-
metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.
|
| 664 |
metrics = jax.lax.pmean(metrics, axis_name="batch")
|
| 665 |
|
| 666 |
-
return new_state, metrics
|
| 667 |
|
| 668 |
# Define eval fn
|
| 669 |
def eval_step(params, batch):
|
|
@@ -702,8 +726,11 @@ def main():
|
|
| 702 |
logger.info(f" Num examples = {len(train_dataset)}")
|
| 703 |
logger.info(f" Num Epochs = {num_epochs}")
|
| 704 |
logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
|
| 705 |
-
logger.info(
|
| 706 |
-
|
|
|
|
|
|
|
|
|
|
| 707 |
|
| 708 |
train_time = 0
|
| 709 |
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
|
|
|
|
| 239 |
|
| 240 |
class TrainState(train_state.TrainState):
|
| 241 |
dropout_rng: jnp.ndarray
|
| 242 |
+
grad_accum: jnp.ndarray
|
| 243 |
+
optimizer_step: int
|
| 244 |
|
| 245 |
def replicate(self):
|
| 246 |
return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
|
|
|
|
| 592 |
# Store some constant
|
| 593 |
num_epochs = int(training_args.num_train_epochs)
|
| 594 |
train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
|
| 595 |
+
total_batch_size = int(train_batch_size) * training_args.gradient_accumulation_steps
|
| 596 |
eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
|
| 597 |
steps_per_epoch = len(train_dataset) // train_batch_size
|
| 598 |
+
total_steps = steps_per_epoch * num_epochs
|
| 599 |
+
total_optimization_steps = (len(train_dataset) // total_batch_size) * num_epochs
|
| 600 |
|
| 601 |
# Create learning rate schedule
|
| 602 |
linear_decay_lr_schedule_fn = create_learning_rate_fn(
|
| 603 |
len(train_dataset),
|
| 604 |
+
total_batch_size,
|
| 605 |
training_args.num_train_epochs,
|
| 606 |
training_args.warmup_steps,
|
| 607 |
training_args.learning_rate,
|
|
|
|
| 640 |
)
|
| 641 |
|
| 642 |
# Setup train state
|
| 643 |
+
state = TrainState.create(
|
| 644 |
+
apply_fn=model.__call__,
|
| 645 |
+
params=model.params,
|
| 646 |
+
tx=adamw,
|
| 647 |
+
dropout_rng=dropout_rng,
|
| 648 |
+
grad_accum=jax.tree_map(jnp.zeros_like, model.params),
|
| 649 |
+
optimizer_step=0,
|
| 650 |
+
)
|
| 651 |
|
| 652 |
# label smoothed cross entropy
|
| 653 |
def loss_fn(logits, labels):
|
|
|
|
| 666 |
return loss
|
| 667 |
|
| 668 |
grad_fn = jax.value_and_grad(compute_loss)
|
| 669 |
+
loss, grads = grad_fn(state.params)
|
| 670 |
+
grad_accum = jax.tree_multimap(lambda x, y: x + y, grads, state.grad_accum)
|
| 671 |
+
|
| 672 |
+
def update_fn():
|
| 673 |
+
grads = jax.tree_map(lambda x: x / training_args.gradient_accumulation_steps, grad_accum)
|
| 674 |
+
grads = jax.lax.pmean(grads, "batch")
|
| 675 |
+
new_state = state.apply_gradients(
|
| 676 |
+
grads=grads, grad_accum=jax.tree_map(jnp.zeros_like, grads), optimizer_step=state.optimizer_step
|
| 677 |
+
)
|
| 678 |
+
return new_state
|
| 679 |
|
| 680 |
+
new_state = jax.lax.cond(
|
| 681 |
+
state.step % training_args.gradient_accumulation_steps == 0,
|
| 682 |
+
lambda _: update_fn(),
|
| 683 |
+
lambda _: state.replace(grad_accum=grad_accum, step=state.step + 1),
|
| 684 |
+
None,
|
| 685 |
+
)
|
| 686 |
|
| 687 |
+
metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.optimizer_step)}
|
| 688 |
metrics = jax.lax.pmean(metrics, axis_name="batch")
|
| 689 |
|
| 690 |
+
return new_state.replace(dropout_rng=new_dropout_rng), metrics
|
| 691 |
|
| 692 |
# Define eval fn
|
| 693 |
def eval_step(params, batch):
|
|
|
|
| 726 |
logger.info(f" Num examples = {len(train_dataset)}")
|
| 727 |
logger.info(f" Num Epochs = {num_epochs}")
|
| 728 |
logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
|
| 729 |
+
logger.info(
|
| 730 |
+
f" Total train batch size (w. parallel & distributed) = {train_batch_size * training_args.gradient_accumulation_steps}"
|
| 731 |
+
)
|
| 732 |
+
logger.info(f" Total global steps = {total_steps}")
|
| 733 |
+
logger.info(f" Total optimization steps = {total_optimization_steps}")
|
| 734 |
|
| 735 |
train_time = 0
|
| 736 |
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
|