feat(train): split artifact into model/state
tools/train/train.py  (+92 -97)
@@ -88,6 +88,24 @@ class ModelArguments:
             "help": "Floating-point format in which the computations will be performed (not the model weights). Choose one of `[float32, float16, bfloat16]`."
         },
     )
+    restore_state: Optional[bool] = field(
+        default=False,
+        metadata={
+            "help": "Restore optimizer and training state associated with a wandb checkpoint."
+        },
+    )
+
+    state_artifact: str = field(init=False)
+
+    def __post_init__(self):
+        if self.restore_state:
+            assert (
+                "/model-" in self.model_name_or_path
+            ), "Restoring state only available with W&B artifact reference"
+            self.state_artifact = self.model_name_or_path.replace(
+                "/model-", "/state-", 1
+            )
 
 
 @dataclass
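Note: the new `state_artifact` reference is derived from the model reference by plain string substitution, so the two artifacts of a run always share the same id. A minimal sketch of that mapping (the entity, project, run id and alias below are made up for illustration):

# Hypothetical W&B artifact reference of the form entity/project/model-<run_id>:<alias>
model_ref = "dalle-mini/dalle-mini/model-1a2b3c4d:latest"

# Mirrors ModelArguments.__post_init__ above: only the first "/model-" segment is swapped.
state_ref = model_ref.replace("/model-", "/state-", 1)
print(state_ref)  # dalle-mini/dalle-mini/state-1a2b3c4d:latest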
@@ -319,11 +337,6 @@ class TrainingArguments:
         },
     )
 
-    resume_from_checkpoint: Optional[str] = field(
-        default=None,
-        metadata={"help": "Reference to a wandb artifact for resuming training."},
-    )
-
     wandb_entity: Optional[str] = field(
         default=None,
         metadata={"help": "The wandb entity to use (for teams)."},
@@ -349,6 +362,8 @@ class TrainingArguments:
         },
     )
 
+    dp_devices: int = field(init=False)
+
     def __post_init__(self):
         assert self.optim in [
             "distributed_shampoo",
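Note: `dp_devices` is declared with `init=False`, so it is not parsed from the command line but filled in after parsing. This hunk does not show how it is computed; a plausible sketch, assuming it is the total device count divided by a model-parallel setting (here called `mp_devices`, which is an assumption, not shown in this diff):

import jax

mp_devices = 1  # assumed model-parallel device count
dp_devices = jax.device_count() // mp_devices  # devices left over for data parallelism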
@@ -470,62 +485,40 @@ def main():
         config=parser.parse_args(),
     )
 
-    if training_args.resume_from_checkpoint is not None:
-        if jax.process_index() == 0:
-            artifact = wandb.run.use_artifact(training_args.resume_from_checkpoint)
-        else:
-            artifact = wandb.Api().artifact(training_args.resume_from_checkpoint)
-        artifact_dir = artifact.download()
+    # Set up our new model config
+    if model_args.config_name:
+        config = DalleBartConfig.from_pretrained(model_args.config_name)
+    else:
+        config = None
 
-        # load model
+    # Load or create new model
+    if model_args.model_name_or_path:
         model = DalleBart.from_pretrained(
-            artifact_dir,
+            model_args.model_name_or_path,
+            config=config,
+            seed=training_args.seed_model,
             dtype=getattr(jnp, model_args.dtype),
             abstract_init=True,
             load_on_cpu=True,
         )
+    else:
+        model = DalleBart(
+            config,
+            seed=training_args.seed_model,
+            dtype=getattr(jnp, model_args.dtype),
+            load_on_cpu=True,
+        )
 
-        # load tokenizer
+    # Load tokenizer
+    if model_args.tokenizer_name is not None:
         tokenizer = DalleBartTokenizer.from_pretrained(
-            artifact_dir,
-            use_fast=True,
+            model_args.tokenizer_name, use_fast=True
         )
-
     else:
-        # Set up our new model config
-        if model_args.config_name:
-            config = DalleBartConfig.from_pretrained(model_args.config_name)
-        else:
-            config = None
-
-        # Load or create new model
-        if model_args.model_name_or_path:
-            model = DalleBart.from_pretrained(
-                model_args.model_name_or_path,
-                config=config,
-                seed=training_args.seed_model,
-                dtype=getattr(jnp, model_args.dtype),
-                abstract_init=True,
-                load_on_cpu=True,
-            )
-        else:
-            model = DalleBart(
-                config,
-                seed=training_args.seed_model,
-                dtype=getattr(jnp, model_args.dtype),
-                load_on_cpu=True,
-            )
-
-        # Load tokenizer
-        if model_args.tokenizer_name is not None:
-            tokenizer = DalleBartTokenizer.from_pretrained(
-                model_args.tokenizer_name, use_fast=True
-            )
-        else:
-            tokenizer = DalleBartTokenizer.from_pretrained(
-                model_args.model_name_or_path,
-                use_fast=True,
-            )
+        tokenizer = DalleBartTokenizer.from_pretrained(
+            model_args.model_name_or_path,
+            use_fast=True,
+        )
 
     # get PartitionSpec for model params (required to be a dict)
     param_spec = set_partitions(model.params)
@@ -698,7 +691,7 @@ def main():
     devices = np.asarray(jax.devices()).reshape(*mesh_shape)
     mesh = maps.Mesh(devices, ("batch", "mp"))
 
-    #
+    # define state spec
     state_spec = TrainState(
         params=param_spec,
         opt_state=opt_state_spec,
@@ -713,7 +706,7 @@ def main():
 
     # create training state
     with maps.mesh(mesh.devices, mesh.axis_names):
-        if training_args.resume_from_checkpoint is None:
+        if not model_args.restore_state:
 
             def init_state(params):
                 return TrainState.create(
@@ -731,6 +724,13 @@ def main():
             )(model.params)
 
         else:
+            # get state files from artifact
+            if jax.process_index() == 0:
+                artifact = wandb.run.use_artifact(model_args.state_artifact)
+            else:
+                artifact = wandb.Api().artifact(model_args.state_artifact)
+            artifact_dir = artifact.download()
+
             # restore opt_state
             with (Path(artifact_dir) / "opt_state.msgpack").open("rb") as f:
                 opt_state = from_bytes(opt_state_shape, f.read())
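Note: on resume, the main JAX process records the state artifact as a run input via `wandb.run.use_artifact`, while the other processes fetch it through the public API so they do not need an active run. A standalone sketch of the API path, with a hypothetical artifact reference:

import wandb

# Download the training-state artifact without attaching it to a run
# (the reference below is illustrative, not a real artifact).
artifact = wandb.Api().artifact("dalle-mini/dalle-mini/state-1a2b3c4d:latest")
artifact_dir = artifact.download()  # contains opt_state.msgpack and training_state.json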
@@ -998,51 +998,46 @@ def main():
                     f,
                 )
 
[… removed lines 1001-1038 (the previous single-artifact upload) were not recoverable from this page …]
-                    str(Path(training_args.output_dir) / "opt_state.msgpack")
-                )
-                artifact.add_file(
-                    str(Path(training_args.output_dir) / "training_state.json")
+        # save to W&B
+        if training_args.log_model:
+            # save some space
+            c = wandb.wandb_sdk.wandb_artifacts.get_artifacts_cache()
+            c.cleanup(wandb.util.from_human_size("10GB"))
+
+            metadata = dict(state_dict)
+            metadata["num_params"] = num_params
+            if eval_metrics is not None:
+                metadata["eval"] = eval_metrics
+
+            # create model artifact
+            artifact = wandb.Artifact(
+                name=f"model-{wandb.run.id}",
+                type="DalleBart_model",
+                metadata=metadata,
+            )
+            for filename in [
+                "config.json",
+                "flax_model.msgpack",
+                "merges.txt",
+                "special_tokens_map.json",
+                "tokenizer.json",
+                "tokenizer_config.json",
+                "vocab.json",
+            ]:
+                artifact.add_file(f"{Path(training_args.output_dir) / filename}")
+            wandb.run.log_artifact(artifact)
+
+            # create state artifact
+            artifact_state = wandb.Artifact(
+                name=f"state-{wandb.run.id}",
+                type="DalleBart_state",
+                metadata=metadata,
+            )
+            for filename in ["opt_state.msgpack", "training_state.json"]:
+                artifact_state.add_file(
+                    f"{Path(training_args.output_dir) / filename}"
                 )
-            wandb.run.log_artifact(artifact)
+            wandb.run.log_artifact(artifact_state)
 
     # init variables
     last_time = time.perf_counter()
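Note: with the upload split in two, a follow-up run that only needs the weights can pull the lighter model artifact, and the optimizer state is downloaded only when training actually resumes. A sketch of the consuming side, again with hypothetical artifact names and project:

import wandb

run = wandb.init(project="dalle-mini")  # project name is illustrative

# Inference or evaluation: config, weights and tokenizer files are enough.
model_dir = run.use_artifact("model-1a2b3c4d:latest").download()

# Resuming training additionally pulls the matching optimizer/training state.
state_dir = run.use_artifact("state-1a2b3c4d:latest").download()

run.finish()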