feat: refactor TrainingArguments
tools/train/train.py  +63 -42  CHANGED

@@ -65,7 +65,7 @@ class ModelArguments:
     config_name: Optional[str] = field(
         default=None,
         metadata={
-            "help": "Pretrained config name or path if not the same as
+            "help": "Pretrained config name or path if not the same as model_name_or_path"
         },
     )
     tokenizer_name: Optional[str] = field(

@@ -77,7 +77,7 @@ class ModelArguments:
     dtype: Optional[str] = field(
         default="float32",
         metadata={
-            "help": "Floating-point format in which the
+            "help": "Floating-point format in which the computations will be performed (not the model weights). Choose one of `[float32, float16, bfloat16]`."
         },
     )

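Note on the `dtype` help string above: in comparable Flax training scripts the string is mapped to a `jnp` type and passed to the model so that computations run in that precision while checkpoints keep their stored format. A minimal sketch of that mapping (illustrative only; `resolve_dtype` is not a function from this file):

import jax.numpy as jnp

def resolve_dtype(name: str):
    # Map the CLI string to a JAX type, restricted to the documented choices.
    if name not in ("float32", "float16", "bfloat16"):
        raise ValueError(f"Unsupported dtype: {name}")
    return getattr(jnp, name)

compute_dtype = resolve_dtype("bfloat16")  # e.g. run the forward pass in bfloat16
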
@@ -106,11 +106,15 @@ class DataTrainingArguments:
     )
     train_file: Optional[str] = field(
         default=None,
-        metadata={
+        metadata={
+            "help": "The input training data file (glob & braceexpand acceptable)."
+        },
     )
     validation_file: Optional[str] = field(
         default=None,
-        metadata={
+        metadata={
+            "help": "An optional input evaluation data file (glob & braceexpand acceptable)."
+        },
     )
     # data loading should not be a bottleneck so we use "streaming" mode by default
     streaming: Optional[bool] = field(

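The new help strings say `train_file` / `validation_file` accept glob and braceexpand patterns. A small sketch of expanding such a pattern into a concrete file list, assuming the `braceexpand` package is installed (`expand_files` is a hypothetical helper, not code from this diff):

from glob import glob
from braceexpand import braceexpand

def expand_files(pattern: str):
    # Brace expansion first ("part-{00..03}.jsonl"), then globbing ("part-*.jsonl").
    files = []
    for candidate in braceexpand(pattern):
        files.extend(glob(candidate))
    return sorted(files)

# e.g. expand_files("data/train-{000..127}-of-*.jsonl")
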
@@ -132,15 +136,13 @@ class DataTrainingArguments:
     max_train_samples: Optional[int] = field(
         default=None,
         metadata={
-            "help": "For debugging purposes or quicker training, truncate the number of training examples
-            "value if set."
+            "help": "For debugging purposes or quicker training, truncate the number of training examples."
         },
     )
     max_eval_samples: Optional[int] = field(
         default=None,
         metadata={
-            "help": "For debugging purposes or quicker training, truncate the number of evaluation examples
-            "value if set."
+            "help": "For debugging purposes or quicker training, truncate the number of evaluation examples."
         },
     )
     preprocessing_num_workers: Optional[int] = field(

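How `max_train_samples` / `max_eval_samples` are applied is not part of this hunk; with 🤗 Datasets the usual pattern differs between streaming and map-style datasets. A hedged sketch under that assumption (`truncate` is a hypothetical helper):

def truncate(dataset, max_samples, streaming):
    # dataset: datasets.IterableDataset when streaming, datasets.Dataset otherwise
    if max_samples is None:
        return dataset
    if streaming:
        return dataset.take(max_samples)       # lazily keep the first max_samples examples
    return dataset.select(range(max_samples))  # slice a map-style dataset
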
@@ -191,42 +193,42 @@ class TrainingArguments:
 
     do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
     do_eval: bool = field(
-        default=False, metadata={"help": "Whether to run eval on the
+        default=False, metadata={"help": "Whether to run eval on the validation set."}
     )
 
     per_device_train_batch_size: int = field(
-        default=8, metadata={"help": "Batch size per GPU/TPU
+        default=8, metadata={"help": "Batch size per GPU/TPU/CPU for training."}
     )
     per_device_eval_batch_size: int = field(
-        default=8, metadata={"help": "Batch size per GPU/TPU
+        default=8, metadata={"help": "Batch size per GPU/TPU/CPU for evaluation."}
     )
 
     gradient_accumulation_steps: int = field(
         default=1,
         metadata={
-            "help": "Number of updates steps to accumulate before performing
+            "help": "Number of updates steps to accumulate before performing an update pass."
         },
     )
 
     learning_rate: float = field(
         default=5e-5, metadata={"help": "The initial learning rate."}
     )
-
-        default=
-        metadata={
-
-
-        default=False,
-        metadata={"help": "Use Distributed Shampoo optimizer instead of AdamW."},
+    optim: str = field(
+        default="distributed_shampoo",
+        metadata={
+            "help": 'The optimizer to use. Can be "distributed_shampoo" (default), "adam" or "adafactor"'
+        },
     )
     weight_decay: float = field(
         default=None, metadata={"help": "Weight decay if we apply some."}
     )
-
-        default=0.9,
+    beta1: float = field(
+        default=0.9,
+        metadata={"help": "Beta1 for adam & distributed_shampoo optimizers"},
     )
-
-        default=0.999,
+    beta2: float = field(
+        default=0.999,
+        metadata={"help": "Beta2 for adam & distributed_shampoo optimizers"},
     )
     adam_epsilon: float = field(
         default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."}

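The hunk above replaces the old per-optimizer boolean flags with a single `optim` string plus shared `beta1`/`beta2` fields. Assuming these dataclasses are parsed with `transformers.HfArgumentParser`, as in the similar Flax example scripts (the parsing code is not shown in this diff), each field becomes a command-line flag. A trimmed-down sketch, with the hypothetical `OptimArgs` standing in for the real `TrainingArguments`:

from dataclasses import dataclass, field
from transformers import HfArgumentParser

@dataclass
class OptimArgs:
    optim: str = field(default="distributed_shampoo")
    beta1: float = field(default=0.9)
    beta2: float = field(default=0.999)

# e.g. `--optim adam --beta2 0.98` on the command line
(args,) = HfArgumentParser(OptimArgs).parse_args_into_dataclasses(
    ["--optim", "adam", "--beta2", "0.98"]
)
print(args.optim, args.beta1, args.beta2)  # adam 0.9 0.98
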
@@ -234,6 +236,16 @@ class TrainingArguments:
     max_grad_norm: float = field(
         default=1.0, metadata={"help": "Max gradient norm for Adafactor."}
     )
+    preconditioning_compute_steps: int = field(
+        default=10, metadata={"help": "Number of steps to update preconditioner."}
+    )
+    optim_quantized: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether to quantize optimizer (only supported with distributed_shampoo)."
+        },
+    )
+
     use_decay: bool = field(
         default=False,
         metadata={"help": "Whether to use decay in the learning rate scheduler."},

@@ -272,6 +284,13 @@ class TrainingArguments:
         metadata={"help": "Reference to a wandb artifact for resuming training."},
     )
 
+    def __post_init__(self):
+        assert self.optim in [
+            "distributed_shampoo",
+            "adam",
+            "adafactor",
+        ], f"Selected optimizer not supported: {self.optim}"
+
 
 class TrainState(train_state.TrainState):
     dropout_rng: jnp.ndarray = None

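With the `__post_init__` added above, an unsupported `--optim` value fails when the arguments are constructed instead of falling through the optimizer `if/elif` chain in `main()`. A small illustration using a stripped-down stand-in dataclass:

from dataclasses import dataclass, field

@dataclass
class Args:  # stand-in for the real TrainingArguments
    optim: str = field(default="distributed_shampoo")

    def __post_init__(self):
        assert self.optim in [
            "distributed_shampoo",
            "adam",
            "adafactor",
        ], f"Selected optimizer not supported: {self.optim}"

Args(optim="adam")  # ok
try:
    Args(optim="sgd")
except AssertionError as err:
    print(err)  # Selected optimizer not supported: sgd
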
@@ -551,29 +570,22 @@ def main():
         return traverse_util.unflatten_dict(flat_mask)
 
     # create adam optimizer
-    if training_args.
-        # We use the default parameters here to initialize adafactor,
-        # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
-        optimizer = optax.adafactor(
-            learning_rate=learning_rate_fn,
-            weight_decay_rate=training_args.weight_decay,
-            weight_decay_mask=decay_mask_fn,
-            clipping_threshold=training_args.max_grad_norm,
-        )
-    elif training_args.distributed_shampoo:
+    if training_args.optim == "distributed_shampoo":
         # parameters from https://github.com/tensorflow/lingvo/blob/03ee9d7cd50764b0424c7c863733c91fc0b053ec/lingvo/jax/optimizers.py#L729
         # Notes:
-        # - mask for weight decay is not implemented
+        # - mask for weight decay is not implemented
         optimizer = distributed_shampoo(
             learning_rate_fn,
             block_size=1024,  # recommended default for large LM is 1536
-            beta1=
-            beta2=
+            beta1=training_args.beta1,
+            beta2=training_args.beta2,
             diagonal_epsilon=1e-10,
             matrix_epsilon=1e-8,
-            weight_decay=
+            weight_decay=training_args.weight_decay
+            if training_args.weight_decay is not None
+            else 0.0,
             start_preconditioning_step=1001,
-            preconditioning_compute_steps=
+            preconditioning_compute_steps=training_args.preconditioning_compute_steps,
             statistics_compute_steps=1,
             best_effort_shape_interpretation=True,
             graft_type=GraftingType.RMSPROP_NORMALIZED,

@@ -585,20 +597,29 @@ def main():
             skip_preconditioning_dim_size_gt=4096,
             clip_by_scaled_gradient_norm=None,
             precision=jax.lax.Precision.HIGHEST,
-            best_effort_memory_usage_reduction=
+            best_effort_memory_usage_reduction=training_args.optim_quantized,
         )
 
-
+    elif training_args.optim == "adam":
         optimizer = optax.adamw(
             learning_rate=learning_rate_fn,
-            b1=training_args.
-            b2=training_args.
+            b1=training_args.beta1,
+            b2=training_args.beta2,
             eps=training_args.adam_epsilon,
             weight_decay=training_args.weight_decay
             if training_args.weight_decay is not None
             else 0.0,
             mask=decay_mask_fn,
         )
+    elif training_args.optim == "adafactor":
+        # We use the default parameters here to initialize adafactor,
+        # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
+        optimizer = optax.adafactor(
+            learning_rate=learning_rate_fn,
+            weight_decay_rate=training_args.weight_decay,
+            weight_decay_mask=decay_mask_fn,
+            clipping_threshold=training_args.max_grad_norm,
+        )
 
     # add gradient accumulation
     if training_args.gradient_accumulation_steps > 1:

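The two hunks above rewrite optimizer creation in `main()` to dispatch on `training_args.optim`. A minimal optax-only sketch of the same dispatch for the `adam` and `adafactor` branches (the `distributed_shampoo` branch depends on the external scalable-Shampoo implementation imported by the script and is omitted here; `make_optimizer`, `learning_rate_fn` and `decay_mask_fn` are stand-in names):

import optax

def make_optimizer(optim, learning_rate_fn, decay_mask_fn, *,
                   beta1, beta2, adam_epsilon, weight_decay, max_grad_norm):
    if optim == "adam":
        return optax.adamw(
            learning_rate=learning_rate_fn,
            b1=beta1,
            b2=beta2,
            eps=adam_epsilon,
            # mirror the diff: treat a missing weight_decay as 0.0 for AdamW
            weight_decay=weight_decay if weight_decay is not None else 0.0,
            mask=decay_mask_fn,
        )
    if optim == "adafactor":
        return optax.adafactor(
            learning_rate=learning_rate_fn,
            weight_decay_rate=weight_decay,
            weight_decay_mask=decay_mask_fn,
            clipping_threshold=max_grad_norm,
        )
    raise ValueError(f"Selected optimizer not supported: {optim}")
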
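The body of the `# add gradient accumulation` branch is outside this diff. One common way to wrap whichever optimizer was built above is `optax.MultiSteps`, sketched here purely as an assumption about what such a branch could do (`maybe_accumulate` is a hypothetical helper):

import optax

def maybe_accumulate(optimizer, gradient_accumulation_steps):
    # Apply a real update only every k micro-batches; gradients are averaged in between.
    if gradient_accumulation_steps > 1:
        optimizer = optax.MultiSteps(
            optimizer, every_k_schedule=gradient_accumulation_steps
        )
    return optimizer
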