Spaces:
Runtime error
Runtime error
fix: actually replace state
Browse files
- dev/seq2seq/run_seq2seq_flax.py +10 -11
dev/seq2seq/run_seq2seq_flax.py
CHANGED
|
@@ -435,18 +435,16 @@ def main():
|
|
| 435 |
|
| 436 |
def restore_state(state, artifact_dir):
|
| 437 |
# restore optimizer state
|
| 438 |
-
|
| 439 |
-
|
| 440 |
-
opt_state = from_bytes(state.opt_state, f.read())
|
| 441 |
-
state.replace(opt_state=opt_state)
|
| 442 |
|
| 443 |
# restore steps
|
| 444 |
-
|
| 445 |
-
|
| 446 |
-
|
| 447 |
-
|
| 448 |
-
|
| 449 |
-
|
| 450 |
|
| 451 |
if model_args.from_checkpoint is not None:
|
| 452 |
artifact = wandb.run.use_artifact(model_args.from_checkpoint)
|
|
@@ -668,7 +666,8 @@ def main():
|
|
| 668 |
)
|
| 669 |
if model_args.from_checkpoint is not None:
|
| 670 |
# restore optimizer state, step and optimizer_step
|
| 671 |
-
restore_state(state, artifact_dir)
|
|
|
|
| 672 |
|
| 673 |
# label smoothed cross entropy
|
| 674 |
def loss_fn(logits, labels):
|
|
|
|
| 435 |
|
| 436 |
def restore_state(state, artifact_dir):
|
| 437 |
# restore optimizer state
|
| 438 |
+
with (Path(artifact_dir) / 'opt_state.msgpack').open('rb') as f:
|
| 439 |
+
opt_state = from_bytes(state.opt_state, f.read())
|
|
|
|
|
|
|
| 440 |
|
| 441 |
# restore steps
|
| 442 |
+
with (Path(artifact_dir) / 'training_state.json').open('r') as f:
|
| 443 |
+
training_state = json.load(f)
|
| 444 |
+
step = training_state['step']
|
| 445 |
+
optimizer_step = step // training_args.gradient_accumulation_steps
|
| 446 |
+
|
| 447 |
+
return step, optimizer_step, opt_state
|
| 448 |
|
| 449 |
if model_args.from_checkpoint is not None:
|
| 450 |
artifact = wandb.run.use_artifact(model_args.from_checkpoint)
|
|
|
|
| 666 |
)
|
| 667 |
if model_args.from_checkpoint is not None:
|
| 668 |
# restore optimizer state, step and optimizer_step
|
| 669 |
+
step, optimizer_step, opt_state = restore_state(state, artifact_dir)
|
| 670 |
+
state = state.replace(step=step, optimizer_step=optimizer_step, opt_state=opt_state)
|
| 671 |
|
| 672 |
# label smoothed cross entropy
|
| 673 |
def loss_fn(logits, labels):
|