Spaces:
Runtime error
Runtime error
fix: shampoo -> distributed shampoo
Browse files — tools/train/train.py: +4 −4
tools/train/train.py
CHANGED
|
@@ -214,11 +214,11 @@ class TrainingArguments:
|
|
| 214 |
)
|
| 215 |
adafactor: bool = field(
|
| 216 |
default=False,
|
| 217 |
-
metadata={"help": "Use Adafactor instead of AdamW."},  [line truncated in capture — reconstructed from the new side of the diff]
|
| 218 |
)
|
| 219 |
-
shampoo: bool = field(
|
| 220 |
default=False,
|
| 221 |
-
metadata={"help": "Use Shampoo optimizer instead of AdamW."},
|
| 222 |
)
|
| 223 |
weight_decay: float = field(
|
| 224 |
default=None, metadata={"help": "Weight decay if we apply some."}
|
|
@@ -566,7 +566,7 @@ def main():
|
|
| 566 |
weight_decay_mask=decay_mask_fn,
|
| 567 |
clipping_threshold=training_args.max_grad_norm,
|
| 568 |
)
|
| 569 |
-
elif training_args.shampoo:
|
| 570 |
# parameters from https://github.com/tensorflow/lingvo/blob/03ee9d7cd50764b0424c7c863733c91fc0b053ec/lingvo/jax/optimizers.py#L729
|
| 571 |
# Notes:
|
| 572 |
# - mask for weight decay is not implemented but we don't use it anyway
|
|
|
|
| 214 |
)
|
| 215 |
adafactor: bool = field(
|
| 216 |
default=False,
|
| 217 |
+
metadata={"help": "Use Adafactor instead of AdamW."},
|
| 218 |
)
|
| 219 |
+
distributed_shampoo: bool = field(
|
| 220 |
default=False,
|
| 221 |
+
metadata={"help": "Use Distributed Shampoo optimizer instead of AdamW."},
|
| 222 |
)
|
| 223 |
weight_decay: float = field(
|
| 224 |
default=None, metadata={"help": "Weight decay if we apply some."}
|
|
|
|
| 566 |
weight_decay_mask=decay_mask_fn,
|
| 567 |
clipping_threshold=training_args.max_grad_norm,
|
| 568 |
)
|
| 569 |
+
elif training_args.distributed_shampoo:
|
| 570 |
# parameters from https://github.com/tensorflow/lingvo/blob/03ee9d7cd50764b0424c7c863733c91fc0b053ec/lingvo/jax/optimizers.py#L729
|
| 571 |
# Notes:
|
| 572 |
# - mask for weight decay is not implemented but we don't use it anyway
|