Spaces:
Running
Running
Pedro Cuenca
committed on
Commit
·
bb3f53e
1
Parent(s):
08dd098
Update `resume_from_checkpoint` to use `from_pretrained`.
Browse files
- tools/train/train.py +3 -9
tools/train/train.py
CHANGED
|
@@ -434,22 +434,16 @@ def main():
|
|
| 434 |
)
|
| 435 |
|
| 436 |
if training_args.resume_from_checkpoint is not None:
|
| 437 |
-
if jax.process_index() == 0:
|
| 438 |
-
artifact = wandb.run.use_artifact(training_args.resume_from_checkpoint)
|
| 439 |
-
else:
|
| 440 |
-
artifact = wandb.Api().artifact(training_args.resume_from_checkpoint)
|
| 441 |
-
artifact_dir = artifact.download()
|
| 442 |
-
|
| 443 |
# load model
|
| 444 |
model = DalleBart.from_pretrained(
|
| 445 |
-
|
| 446 |
)
|
| 447 |
# avoid OOM on TPU: see https://github.com/google/flax/issues/1658
|
| 448 |
print(model.params)
|
| 449 |
|
| 450 |
# load tokenizer
|
| 451 |
tokenizer = AutoTokenizer.from_pretrained(
|
| 452 |
-
|
| 453 |
use_fast=True,
|
| 454 |
)
|
| 455 |
|
|
@@ -624,7 +618,7 @@ def main():
|
|
| 624 |
if training_args.resume_from_checkpoint is not None:
|
| 625 |
# restore optimizer state and other parameters
|
| 626 |
# we currently ignore partial epoch training: see https://github.com/borisdayma/dalle-mini/issues/105
|
| 627 |
-
state = state.restore_state(
|
| 628 |
|
| 629 |
# label smoothed cross entropy
|
| 630 |
def loss_fn(logits, labels):
|
|
|
|
| 434 |
)
|
| 435 |
|
| 436 |
if training_args.resume_from_checkpoint is not None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 437 |
# load model
|
| 438 |
model = DalleBart.from_pretrained(
|
| 439 |
+
training_args.resume_from_checkpoint, dtype=getattr(jnp, model_args.dtype), abstract_init=True
|
| 440 |
)
|
| 441 |
# avoid OOM on TPU: see https://github.com/google/flax/issues/1658
|
| 442 |
print(model.params)
|
| 443 |
|
| 444 |
# load tokenizer
|
| 445 |
tokenizer = AutoTokenizer.from_pretrained(
|
| 446 |
+
model.config.resolved_name_or_path,
|
| 447 |
use_fast=True,
|
| 448 |
)
|
| 449 |
|
|
|
|
| 618 |
if training_args.resume_from_checkpoint is not None:
|
| 619 |
# restore optimizer state and other parameters
|
| 620 |
# we currently ignore partial epoch training: see https://github.com/borisdayma/dalle-mini/issues/105
|
| 621 |
+
state = state.restore_state(model.config.resolved_name_or_path)
|
| 622 |
|
| 623 |
# label smoothed cross entropy
|
| 624 |
def loss_fn(logits, labels):
|