Spaces:
Running
Running
feat(train): simplify tokenizer loading
Browse files- tools/train/train.py +12 -10
tools/train/train.py
CHANGED
|
@@ -55,7 +55,7 @@ from dalle_mini.model import (
|
|
| 55 |
)
|
| 56 |
|
| 57 |
cc.initialize_cache(
|
| 58 |
-
"/home/boris/dalle-mini/jax_cache", max_cache_size_bytes=5 * 2**30
|
| 59 |
)
|
| 60 |
|
| 61 |
|
|
@@ -104,6 +104,11 @@ class ModelArguments:
|
|
| 104 |
state_artifact: str = field(init=False)
|
| 105 |
|
| 106 |
def __post_init__(self):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
if self.restore_state:
|
| 108 |
assert self.model_name_or_path is not None and (
|
| 109 |
"/model-" in self.model_name_or_path
|
|
@@ -511,15 +516,9 @@ def main():
|
|
| 511 |
)
|
| 512 |
|
| 513 |
# Load tokenizer
|
| 514 |
-
|
| 515 |
-
|
| 516 |
-
|
| 517 |
-
)
|
| 518 |
-
else:
|
| 519 |
-
tokenizer = DalleBartTokenizer.from_pretrained(
|
| 520 |
-
model_args.model_name_or_path,
|
| 521 |
-
use_fast=True,
|
| 522 |
-
)
|
| 523 |
|
| 524 |
# get PartitionSpec for model params (required to be a dict)
|
| 525 |
param_spec = set_partitions(model.params)
|
|
@@ -532,6 +531,9 @@ def main():
|
|
| 532 |
|
| 533 |
dataset.preprocess(tokenizer=tokenizer, config=model.config)
|
| 534 |
|
|
|
|
|
|
|
|
|
|
| 535 |
# Initialize our training
|
| 536 |
dropout_rng = jax.random.PRNGKey(training_args.seed_model)
|
| 537 |
|
|
|
|
| 55 |
)
|
| 56 |
|
| 57 |
cc.initialize_cache(
|
| 58 |
+
"/home/boris/dalle-mini/jax_cache", max_cache_size_bytes=5 * 2 ** 30
|
| 59 |
)
|
| 60 |
|
| 61 |
|
|
|
|
| 104 |
state_artifact: str = field(init=False)
|
| 105 |
|
| 106 |
def __post_init__(self):
|
| 107 |
+
if self.tokenizer_name is None:
|
| 108 |
+
self.tokenizer_name == self.model_name_or_path
|
| 109 |
+
assert (
|
| 110 |
+
self.tokenizer_name is not None
|
| 111 |
+
), "Tokenizer name or model name/path needs to be specified"
|
| 112 |
if self.restore_state:
|
| 113 |
assert self.model_name_or_path is not None and (
|
| 114 |
"/model-" in self.model_name_or_path
|
|
|
|
| 516 |
)
|
| 517 |
|
| 518 |
# Load tokenizer
|
| 519 |
+
tokenizer = DalleBartTokenizer.from_pretrained(
|
| 520 |
+
model_args.tokenizer_name, use_fast=True
|
| 521 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 522 |
|
| 523 |
# get PartitionSpec for model params (required to be a dict)
|
| 524 |
param_spec = set_partitions(model.params)
|
|
|
|
| 531 |
|
| 532 |
dataset.preprocess(tokenizer=tokenizer, config=model.config)
|
| 533 |
|
| 534 |
+
# no dropout (hardcoded)
|
| 535 |
+
model.config.dropout = 0.0
|
| 536 |
+
|
| 537 |
# Initialize our training
|
| 538 |
dropout_rng = jax.random.PRNGKey(training_args.seed_model)
|
| 539 |
|