Spaces:
Running
Running
feat: use pretrained weights
Browse files- dev/seq2seq/run_seq2seq_flax.py +52 -19
dev/seq2seq/run_seq2seq_flax.py
CHANGED
|
@@ -68,6 +68,12 @@ class ModelArguments:
|
|
| 68 |
"Don't set if you want to train a model from scratch."
|
| 69 |
},
|
| 70 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
image_vocab_size: Optional[int] = field(
|
| 72 |
default=None,
|
| 73 |
metadata={"help": "Vocab size of image encoder"},
|
|
@@ -82,9 +88,11 @@ class ModelArguments:
|
|
| 82 |
"help": "Pretrained tokenizer name or path if not the same as model_name_or_path"
|
| 83 |
},
|
| 84 |
)
|
| 85 |
-
normalize_text: bool = field(
|
| 86 |
-
default=
|
| 87 |
-
metadata={
|
|
|
|
|
|
|
| 88 |
)
|
| 89 |
dtype: Optional[str] = field(
|
| 90 |
default="float32",
|
|
@@ -125,8 +133,9 @@ class DataTrainingArguments:
|
|
| 125 |
"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."
|
| 126 |
},
|
| 127 |
)
|
|
|
|
| 128 |
streaming: bool = field(
|
| 129 |
-
default=
|
| 130 |
metadata={"help": "Whether to stream the dataset."},
|
| 131 |
)
|
| 132 |
use_auth_token: bool = field(
|
|
@@ -283,9 +292,9 @@ class TrainingArguments:
|
|
| 283 |
},
|
| 284 |
)
|
| 285 |
|
| 286 |
-
|
| 287 |
default=None,
|
| 288 |
-
metadata={"help": "
|
| 289 |
)
|
| 290 |
|
| 291 |
|
|
@@ -460,8 +469,8 @@ def main():
|
|
| 460 |
config=parser.parse_args(),
|
| 461 |
)
|
| 462 |
|
| 463 |
-
if training_args.
|
| 464 |
-
artifact = wandb.run.use_artifact(training_args.
|
| 465 |
artifact_dir = artifact.download()
|
| 466 |
|
| 467 |
# load model
|
|
@@ -476,9 +485,20 @@ def main():
|
|
| 476 |
else:
|
| 477 |
# Set up our new model config
|
| 478 |
# TODO: simplify with custom config class
|
| 479 |
-
|
| 480 |
-
|
| 481 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 482 |
# we append decoder bos to image vocab
|
| 483 |
config.decoder_start_token_id = config.image_vocab_size
|
| 484 |
# ensure we don't generate bos (in addition to decoder start token)
|
|
@@ -487,8 +507,8 @@ def main():
|
|
| 487 |
config.forced_eos_token_id = None # we don't need this token
|
| 488 |
|
| 489 |
config.tie_word_embeddings = False
|
| 490 |
-
config.min_length =
|
| 491 |
-
config.max_length =
|
| 492 |
|
| 493 |
# below tokens need to be set to avoid error during generation (converted to jnp.array)
|
| 494 |
# they are not expected to be used and are set to unreachable token id
|
|
@@ -497,12 +517,25 @@ def main():
|
|
| 497 |
config.eos_token_id = config.image_vocab_size + 1
|
| 498 |
|
| 499 |
# save whether we normalize the text
|
| 500 |
-
|
|
|
|
|
|
|
|
|
|
| 501 |
|
| 502 |
-
#
|
| 503 |
-
|
| 504 |
-
|
| 505 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 506 |
|
| 507 |
# Load tokenizer
|
| 508 |
if model_args.tokenizer_name is not None:
|
|
@@ -741,7 +774,7 @@ def main():
|
|
| 741 |
tx=optimizer,
|
| 742 |
dropout_rng=dropout_rng,
|
| 743 |
)
|
| 744 |
-
if training_args.
|
| 745 |
# restore optimizer state and other parameters
|
| 746 |
# we currently ignore partial epoch training: see https://github.com/borisdayma/dalle-mini/issues/105
|
| 747 |
state = state.restore_state(artifact_dir)
|
|
|
|
| 68 |
"Don't set if you want to train a model from scratch."
|
| 69 |
},
|
| 70 |
)
|
| 71 |
+
config_name: Optional[str] = field(
|
| 72 |
+
default=None,
|
| 73 |
+
metadata={
|
| 74 |
+
"help": "Pretrained config name or path if not the same as model_name"
|
| 75 |
+
},
|
| 76 |
+
)
|
| 77 |
image_vocab_size: Optional[int] = field(
|
| 78 |
default=None,
|
| 79 |
metadata={"help": "Vocab size of image encoder"},
|
|
|
|
| 88 |
"help": "Pretrained tokenizer name or path if not the same as model_name_or_path"
|
| 89 |
},
|
| 90 |
)
|
| 91 |
+
normalize_text: Optional[bool] = field(
|
| 92 |
+
default=None,
|
| 93 |
+
metadata={
|
| 94 |
+
"help": "Whether to normalize text or not. By default, we refer to base model or don't normalize for new models."
|
| 95 |
+
},
|
| 96 |
)
|
| 97 |
dtype: Optional[str] = field(
|
| 98 |
default="float32",
|
|
|
|
| 133 |
"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."
|
| 134 |
},
|
| 135 |
)
|
| 136 |
+
# data loading should not be a bottleneck so we use "streaming" mode by default
|
| 137 |
streaming: bool = field(
|
| 138 |
+
default=True,
|
| 139 |
metadata={"help": "Whether to stream the dataset."},
|
| 140 |
)
|
| 141 |
use_auth_token: bool = field(
|
|
|
|
| 292 |
},
|
| 293 |
)
|
| 294 |
|
| 295 |
+
resume_from_checkpoint: Optional[str] = field(
|
| 296 |
default=None,
|
| 297 |
+
metadata={"help": "Reference to a wandb artifact for resuming training."},
|
| 298 |
)
|
| 299 |
|
| 300 |
|
|
|
|
| 469 |
config=parser.parse_args(),
|
| 470 |
)
|
| 471 |
|
| 472 |
+
if training_args.resume_from_checkpoint is not None:
|
| 473 |
+
artifact = wandb.run.use_artifact(training_args.resume_from_checkpoint)
|
| 474 |
artifact_dir = artifact.download()
|
| 475 |
|
| 476 |
# load model
|
|
|
|
| 485 |
else:
|
| 486 |
# Set up our new model config
|
| 487 |
# TODO: simplify with custom config class
|
| 488 |
+
if model_args.config_name:
|
| 489 |
+
config = BartConfig.from_pretrained(model_args.config_name)
|
| 490 |
+
else:
|
| 491 |
+
config = BartConfig.from_pretrained(model_args.model_name_or_path)
|
| 492 |
+
if model_args.image_vocab_size:
|
| 493 |
+
config.image_vocab_size = model_args.image_vocab_size
|
| 494 |
+
assert (
|
| 495 |
+
getattr(config, "image_vocab_size") is not None
|
| 496 |
+
), "image_vocab_size must be specified when not present in base model/config"
|
| 497 |
+
if model_args.image_length:
|
| 498 |
+
config.image_length = model_args.image_length
|
| 499 |
+
assert (
|
| 500 |
+
getattr(config, "image_length") is not None
|
| 501 |
+
), "image_length must be specified when not present in base model/config"
|
| 502 |
# we append decoder bos to image vocab
|
| 503 |
config.decoder_start_token_id = config.image_vocab_size
|
| 504 |
# ensure we don't generate bos (in addition to decoder start token)
|
|
|
|
| 507 |
config.forced_eos_token_id = None # we don't need this token
|
| 508 |
|
| 509 |
config.tie_word_embeddings = False
|
| 510 |
+
config.min_length = config.image_length + 1
|
| 511 |
+
config.max_length = config.image_length + 1
|
| 512 |
|
| 513 |
# below tokens need to be set to avoid error during generation (converted to jnp.array)
|
| 514 |
# they are not expected to be used and are set to unreachable token id
|
|
|
|
| 517 |
config.eos_token_id = config.image_vocab_size + 1
|
| 518 |
|
| 519 |
# save whether we normalize the text
|
| 520 |
+
if model_args.normalize_text is not None:
|
| 521 |
+
config.normalize_text = model_args.normalize_text
|
| 522 |
+
else:
|
| 523 |
+
config.normalize_text = getattr(config, "normalize_text", False)
|
| 524 |
|
| 525 |
+
# Load or create new model
|
| 526 |
+
if model_args.model_name_or_path:
|
| 527 |
+
model = CustomFlaxBartForConditionalGeneration.from_pretrained(
|
| 528 |
+
model_args.model_name_or_path,
|
| 529 |
+
config=config,
|
| 530 |
+
seed=training_args.seed_model,
|
| 531 |
+
dtype=getattr(jnp, model_args.dtype),
|
| 532 |
+
)
|
| 533 |
+
else:
|
| 534 |
+
model = CustomFlaxBartForConditionalGeneration(
|
| 535 |
+
config,
|
| 536 |
+
seed=training_args.seed_model,
|
| 537 |
+
dtype=getattr(jnp, model_args.dtype),
|
| 538 |
+
)
|
| 539 |
|
| 540 |
# Load tokenizer
|
| 541 |
if model_args.tokenizer_name is not None:
|
|
|
|
| 774 |
tx=optimizer,
|
| 775 |
dropout_rng=dropout_rng,
|
| 776 |
)
|
| 777 |
+
if training_args.resume_from_checkpoint is not None:
|
| 778 |
# restore optimizer state and other parameters
|
| 779 |
# we currently ignore partial epoch training: see https://github.com/borisdayma/dalle-mini/issues/105
|
| 780 |
state = state.restore_state(artifact_dir)
|