feat: use_auth_token + seed for dataset and model
dev/seq2seq/run_seq2seq_flax.py  (+35 -12)
@@ -129,6 +129,12 @@ class DataTrainingArguments:
         default=False,
         metadata={"help": "Whether to stream the dataset."},
     )
+    use_auth_token: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether to use the authentication token for private datasets."
+        },
+    )
     max_source_length: Optional[int] = field(
         default=128,
         metadata={
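
Note: because DataTrainingArguments is a dataclass of CLI fields, the new flag lands on the command line automatically. A minimal sketch of the flow, assuming the script parses its dataclasses with transformers' HfArgumentParser (the usual pattern in these seq2seq scripts, not shown in this diff):

    from transformers import HfArgumentParser

    # `--use_auth_token` on the command line flips the bool field to True;
    # a later hunk forwards it to load_dataset for private dataset repos.
    parser = HfArgumentParser((DataTrainingArguments,))
    (data_args,) = parser.parse_args_into_dataclasses()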
@@ -256,9 +262,18 @@ class TrainingArguments:
         metadata={"help": "Log model to wandb at `save_steps` frequency."},
     )

-    seed: int = field(
+    seed_model: int = field(
         default=42,
-        metadata={"help": "Random seed that will be set at the beginning of training."},
+        metadata={
+            "help": "Random seed for the model that will be set at the beginning of training."
+        },
+    )
+    # default seed of None ensures we don't repeat the same items if script was interrupted during an epoch
+    seed_dataset: int = field(
+        default=None,
+        metadata={
+            "help": "Random seed for the dataset that will be set at the beginning of training."
+        },
     )

     push_to_hub: bool = field(
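
Note: the single `seed` is split into `seed_model` and `seed_dataset` so that weight initialization and data order can be reproduced independently. A toy illustration of the decoupling (standalone, not part of the commit):

    import jax

    # Two independent PRNG streams: fixing seed_model reproduces the same
    # parameter init even while seed_dataset changes the shuffle order.
    rng_model = jax.random.PRNGKey(42)    # seed_model: init / dropout
    rng_dataset = jax.random.PRNGKey(7)   # seed_dataset: shuffling only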
@@ -304,7 +319,9 @@ class TrainState(train_state.TrainState):


 def data_loader(
-    rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False
+    dataset: Dataset,
+    batch_size: int,
+    rng: jax.random.PRNGKey = None,
 ):
     """
     Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
@@ -312,7 +329,7 @@ def data_loader(
     """
     steps_per_epoch = len(dataset) // batch_size

-    if shuffle:
+    if rng is not None:
         batch_idx = jax.random.permutation(rng, len(dataset))
     else:
         batch_idx = jnp.arange(len(dataset))
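
Note: under the new signature, supplying an `rng` is what opts into shuffling, while `rng=None` keeps the deterministic `arange` order used for evaluation. A self-contained sketch of the pattern; the final batching step is simplified here (the real function also shards each batch over local devices, which this sketch omits):

    import jax
    import jax.numpy as jnp

    def toy_data_loader(dataset, batch_size, rng=None):
        """Yield batches; shuffle only when an rng key is supplied."""
        steps_per_epoch = len(dataset) // batch_size
        if rng is not None:
            batch_idx = jax.random.permutation(rng, len(dataset))
        else:
            batch_idx = jnp.arange(len(dataset))
        batch_idx = batch_idx[: steps_per_epoch * batch_size]  # drop last partial batch
        for idx in batch_idx.reshape((steps_per_epoch, batch_size)):
            yield [dataset[int(i)] for i in idx]

    # eval: deterministic order; train: pass a fresh key per epoch
    deterministic = list(toy_data_loader(list(range(10)), batch_size=4))
    shuffled = list(toy_data_loader(list(range(10)), 4, rng=jax.random.PRNGKey(0)))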
@@ -432,6 +449,7 @@ def main():
         data_args.dataset_repo_or_path,
         data_files=data_files,
         streaming=data_args.streaming,
+        use_auth_token=data_args.use_auth_token,
     )

     # Set up wandb run
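
Note: `use_auth_token=True` makes `datasets.load_dataset` authenticate with the token stored by `huggingface-cli login`, which is what allows pulling a private dataset repo (this was the datasets-library API at the time; newer versions call the parameter `token`). For example:

    from datasets import load_dataset

    # "user/private-captions" is a hypothetical private repo name.
    ds = load_dataset("user/private-captions", streaming=True, use_auth_token=True)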
@@ -483,7 +501,7 @@ def main():

     # Create a custom model and initialize it randomly
     model = CustomFlaxBartForConditionalGeneration(
-        config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
+        config, seed=training_args.seed_model, dtype=getattr(jnp, model_args.dtype)
     )

     # Load tokenizer
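
Note: in Flax transformers models, the `seed` constructor argument seeds random parameter initialization, so the same config plus the same `seed_model` yields identical starting weights across runs. Sketch with a toy config (illustrative; the commit uses the project's CustomFlaxBartForConditionalGeneration instead):

    import jax.numpy as jnp
    from transformers import BartConfig, FlaxBartForConditionalGeneration

    config = BartConfig(d_model=64, encoder_layers=1, decoder_layers=1)
    m1 = FlaxBartForConditionalGeneration(config, seed=42, dtype=jnp.float32)
    m2 = FlaxBartForConditionalGeneration(config, seed=42, dtype=jnp.float32)
    # m1 and m2 start from identical randomly-initialized parameters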
@@ -561,7 +579,14 @@ def main():
         else train_dataset.select(range(data_args.max_train_samples))
     )
     if data_args.streaming:
-        train_dataset = train_dataset.shuffle(1000, training_args.seed)
+        train_dataset = train_dataset.shuffle(1000, training_args.seed_dataset)
+    else:
+        seed_dataset = (
+            training_args.seed_dataset
+            if training_args.seed_dataset is not None
+            else np.random.get_state()[1][0]
+        )
+        rng_dataset = jax.random.PRNGKey(seed_dataset)
     if model.config.normalize_text:
         train_dataset = (
             train_dataset.map(normalize_text)
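
Note: the fallback `np.random.get_state()[1][0]` reads the first word of NumPy's auto-seeded Mersenne Twister state vector, which differs on every fresh process (unless `np.random.seed` was called earlier), so a run restarted mid-epoch draws a new shuffle order instead of replaying the old one. Quick check of the idea:

    import numpy as np

    # get_state() returns ("MT19937", key_vector, pos, ...); key_vector[0]
    # is a 32-bit word that varies per process when unseeded.
    fallback_seed = np.random.get_state()[1][0]
    print(fallback_seed)  # different value on each fresh interpreter run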
@@ -627,7 +652,7 @@ def main():
     )

     # Initialize our training
-    rng = jax.random.PRNGKey(training_args.seed)
+    rng = jax.random.PRNGKey(training_args.seed_model)
     rng, dropout_rng = jax.random.split(rng)

     # Store some constant
@@ -808,7 +833,7 @@ def main():
         if data_args.streaming:
             eval_loader = data_loader_streaming(eval_dataset, eval_batch_size)
         else:
-            eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
+            eval_loader = data_loader(eval_dataset, eval_batch_size)
         eval_steps = (
             len_eval_dataset // eval_batch_size
             if len_eval_dataset is not None
@@ -927,10 +952,8 @@ def main():
             train_dataset.set_epoch(epoch)  # shuffle dataset
             train_loader = data_loader_streaming(train_dataset, train_batch_size)
         else:
-            rng, input_rng = jax.random.split(rng)
-            train_loader = data_loader(
-                input_rng, train_dataset, train_batch_size, shuffle=True
-            )
+            rng_dataset, input_rng = jax.random.split(rng_dataset)
+            train_loader = data_loader(train_dataset, train_batch_size, rng=input_rng)
         # train
         for batch in tqdm(
             train_loader,
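
Note: splitting `rng_dataset` at the top of each epoch consumes the parent key and hands `data_loader` a fresh one, so every epoch gets a different permutation that is still fully determined by `seed_dataset`. The pattern in isolation:

    import jax

    rng_dataset = jax.random.PRNGKey(0)  # seed_dataset
    for epoch in range(3):
        # keep one half of the split for the next epoch, use the other now
        rng_dataset, input_rng = jax.random.split(rng_dataset)
        perm = jax.random.permutation(input_rng, 8)
        print(epoch, perm)  # new permutation each epoch, reproducible across runs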