Spaces:
Runtime error
Runtime error
fix: weight decay Adam + speed logging
Browse files- tools/train/train.py +9 -5
tools/train/train.py
CHANGED
|
@@ -353,10 +353,12 @@ class MetricsLogger:
|
|
| 353 |
# timing metrics
|
| 354 |
new_step = int(unreplicate(state.step))
|
| 355 |
new_time = time.perf_counter()
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
|
|
|
|
|
|
|
| 360 |
|
| 361 |
@staticmethod
|
| 362 |
def log(metrics, step=None, prefix=None):
|
|
@@ -599,7 +601,9 @@ def main():
|
|
| 599 |
b1=training_args.adam_beta1,
|
| 600 |
b2=training_args.adam_beta2,
|
| 601 |
eps=training_args.adam_epsilon,
|
| 602 |
-
weight_decay=training_args.weight_decay
|
|
|
|
|
|
|
| 603 |
mask=decay_mask_fn,
|
| 604 |
)
|
| 605 |
|
|
|
|
| 353 |
# timing metrics
|
| 354 |
new_step = int(unreplicate(state.step))
|
| 355 |
new_time = time.perf_counter()
|
| 356 |
+
if new_step > self.step:
|
| 357 |
+
time_per_step = (new_time - self.time) / (new_step - self.step)
|
| 358 |
+
self.step = new_step
|
| 359 |
+
self.time = new_time
|
| 360 |
+
state_dict["time_per_step"] = time_per_step
|
| 361 |
+
return {**metrics, **state_dict}
|
| 362 |
|
| 363 |
@staticmethod
|
| 364 |
def log(metrics, step=None, prefix=None):
|
|
|
|
| 601 |
b1=training_args.adam_beta1,
|
| 602 |
b2=training_args.adam_beta2,
|
| 603 |
eps=training_args.adam_epsilon,
|
| 604 |
+
weight_decay=training_args.weight_decay
|
| 605 |
+
if training_args.weight_decay is not None
|
| 606 |
+
else 0.0,
|
| 607 |
mask=decay_mask_fn,
|
| 608 |
)
|
| 609 |
|