fix(train): overwrite dropout only when specified
tools/train/train.py  +9 -7
@@ -131,7 +131,7 @@ class ModelArguments:
         ), "Restoring state only available with W&B artifact reference"

     def get_metadata(self):
-        if ":" in self.model_name_or_path:
+        if self.model_name_or_path is not None and ":" in self.model_name_or_path:
             if jax.process_index() == 0:
                 artifact = wandb.run.use_artifact(self.model_name_or_path)
             else:
@@ -685,12 +685,16 @@ def main():
         )

     # Set up our new model config
+    config_args = {
+        k: getattr(model_args, k)
+        for k in ["dropout", "activation_dropout", "attention_dropout"]
+        if getattr(model_args, k) is not None
+    }
     if model_args.config_name:
         config = DalleBartConfig.from_pretrained(model_args.config_name)
         config.gradient_checkpointing = training_args.gradient_checkpointing
-        config.dropout = model_args.dropout
-        config.activation_dropout = model_args.activation_dropout
-        config.attention_dropout = model_args.attention_dropout
+        for k, v in config_args.items():
+            setattr(config, k, v)
     else:
         config = None

@@ -703,9 +707,7 @@
             dtype=getattr(jnp, model_args.dtype),
             _do_init=False,  # we overwrite them with loaded checkpoint
             gradient_checkpointing=training_args.gradient_checkpointing,
-            dropout=model_args.dropout,
-            activation_dropout=model_args.activation_dropout,
-            attention_dropout=model_args.attention_dropout,
+            **config_args,
         )
     else:
         model = DalleBart(