Merge pull request #22 from borisdayma/feat-axis
seq2seq/run_seq2seq_flax.py  CHANGED  (+25 -57)

@@ -57,7 +57,6 @@ from transformers import (
     FlaxBartForConditionalGeneration,
     HfArgumentParser,
     TrainingArguments,
-    is_tensorboard_available,
 )
 from transformers.models.bart.modeling_flax_bart import *
 from transformers.file_utils import is_offline_mode

@@ -229,12 +228,6 @@ class DataTrainingArguments:
             "value if set."
         },
     )
-    eval_interval: Optional[int] = field(
-        default=400,
-        metadata={
-            "help": "Evaluation will be performed every eval_interval steps"
-        },
-    )
     log_model: bool = field(
         default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
     )

@@ -327,19 +320,6 @@ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuf
         yield batch


-def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
-    summary_writer.scalar("train_time", train_time, step)
-
-    train_metrics = get_metrics(train_metrics)
-    for key, vals in train_metrics.items():
-        tag = f"train_epoch/{key}"
-        for i, val in enumerate(vals):
-            summary_writer.scalar(tag, val, step - len(vals) + i + 1)
-
-    for metric_name, value in eval_metrics.items():
-        summary_writer.scalar(f"eval/{metric_name}", value, step)
-
-
 def create_learning_rate_fn(
     train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float, no_decay: bool
 ) -> Callable[[int], jnp.array]:

@@ -356,6 +336,14 @@ def create_learning_rate_fn(
     return schedule_fn


+def wandb_log(metrics, step=None, prefix=None):
+    if jax.process_index() == 0:
+        log_metrics = {f'{prefix}/{k}' if prefix is not None else k: jax.device_get(v) for k,v in metrics.items()}
+        if step is not None:
+            log_metrics = {**log_metrics, 'train/step': step}
+        wandb.log(log_metrics)
+
+
 def main():
     # See all possible arguments in src/transformers/training_args.py
     # or by passing the --help flag to this script.

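The new `wandb_log` helper gates logging on the primary process, prefixes every metric key, and attaches the global step to each call. Below is a self-contained sketch of how it behaves; the helper body mirrors the one added in this hunk, while the project name, metric values, and step number are invented for illustration.

```python
import jax
import wandb


def wandb_log(metrics, step=None, prefix=None):
    # Mirrors the helper added in this diff: log only from process 0,
    # prefix metric keys, and attach the global step as 'train/step'.
    if jax.process_index() == 0:
        log_metrics = {f"{prefix}/{k}" if prefix is not None else k: jax.device_get(v) for k, v in metrics.items()}
        if step is not None:
            log_metrics = {**log_metrics, "train/step": step}
        wandb.log(log_metrics)


wandb.init(project="wandb-log-demo", mode="offline")  # placeholder project, offline for the sketch
wandb_log({"loss": 2.31, "learning_rate": 3e-4}, step=100, prefix="train")
# Equivalent to wandb.log({"train/loss": 2.31, "train/learning_rate": 0.0003, "train/step": 100})
```
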
@@ -369,6 +357,9 @@ def main():
     else:
         model_args, data_args, training_args = parser.parse_args_into_dataclasses()

+    logger.warning(f"eval_steps has been manually hardcoded") # TODO: remove it later, convenient for now
+    training_args.eval_steps = 400
+
     if (
         os.path.exists(training_args.output_dir)
         and os.listdir(training_args.output_dir)

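The two added lines hardcode `training_args.eval_steps = 400` as a stop-gap (the removed `eval_interval` dataclass field used the same default). Since `eval_steps` is a standard `TrainingArguments` field, the same value could later come from the command line instead; a minimal sketch of that, with an illustrative output directory:

```python
# Sketch: parse eval_steps from CLI-style arguments instead of hardcoding it.
# The argument values here are illustrative.
from transformers import HfArgumentParser, TrainingArguments

parser = HfArgumentParser(TrainingArguments)
(training_args,) = parser.parse_args_into_dataclasses(
    args=["--output_dir", "./output", "--eval_steps", "400"]
)
print(training_args.eval_steps)  # -> 400, no hardcoding needed
```
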
@@ -382,13 +373,16 @@ def main():

     # Set up wandb run
     wandb.init(
-        sync_tensorboard=True,
         entity='wandb',
         project='hf-flax-dalle-mini',
         job_type='Seq2SeqVQGAN',
         config=parser.parse_args()
     )

+    # set default x-axis as 'train/step'
+    wandb.define_metric('train/step')
+    wandb.define_metric('*', step_metric='train/step')
+
     # Make one log on every process with the configuration for debugging.
     pylogging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",

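These `wandb.define_metric` calls are the point of the feat-axis branch: declaring `train/step` once and registering it as the `step_metric` for `'*'` makes W&B plot every subsequently logged key against the training step rather than its internal `_step` counter, so train and eval charts share the same x-axis. A standalone sketch of the same pattern (project name is a placeholder, and it assumes you are logged in to W&B):

```python
import wandb

run = wandb.init(project="axis-demo")  # placeholder project name

# Register 'train/step' as the default x-axis for all metrics logged afterwards.
wandb.define_metric("train/step")
wandb.define_metric("*", step_metric="train/step")

for step in range(1, 4):
    # Each point is plotted against 'train/step' because it is logged in the same call.
    wandb.log({"train/loss": 1.0 / step, "train/step": step})

run.finish()
```
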
@@ -583,24 +577,6 @@ def main():
         result = {k: round(v, 4) for k, v in result.items()}
         return result

-    # Enable tensorboard only on the master node
-    has_tensorboard = is_tensorboard_available()
-    if has_tensorboard and jax.process_index() == 0:
-        try:
-            from flax.metrics.tensorboard import SummaryWriter
-
-            summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
-        except ImportError as ie:
-            has_tensorboard = False
-            logger.warning(
-                f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
-            )
-    else:
-        logger.warning(
-            "Unable to display metrics through TensorBoard because the package is not installed: "
-            "Please run pip install tensorboard to enable."
-        )
-
     # Initialize our training
     rng = jax.random.PRNGKey(training_args.seed)
     rng, dropout_rng = jax.random.split(rng)

@@ -780,10 +756,8 @@ def main():
         eval_metrics = get_metrics(eval_metrics)
         eval_metrics = jax.tree_map(jnp.mean, eval_metrics)

-
-
-        wandb.log({"eval/step": global_step})
-        wandb.log({f"eval/{k}": jax.device_get(v)})
+        # log metrics
+        wandb_log(eval_metrics, step=global_step, prefix='eval')

         # compute ROUGE metrics
         rouge_desc = ""

@@ -796,6 +770,7 @@ def main():
         desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']} | {rouge_desc})"
         epochs.write(desc)
         epochs.desc = desc
+
         return eval_metrics

     for epoch in epochs:

@@ -804,7 +779,6 @@ def main():

         # Create sampling rng
         rng, input_rng = jax.random.split(rng)
-        train_metrics = []

         # Generate an epoch by shuffling sampling indices from the train dataset
         train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)

@@ -814,32 +788,26 @@ def main():
             global_step +=1
             batch = next(train_loader)
             state, train_metric = p_train_step(state, batch)
-            train_metrics.append(train_metric)

             if global_step % data_args.log_interval == 0 and jax.process_index() == 0:
-
-
-                wandb.log({"train/step": global_step})
-                wandb.log({f"train/{k}": jax.device_get(v)})
+                # log metrics
+                wandb_log(unreplicate(train_metric), step=global_step, prefix='train')

-                if global_step % data_args.eval_interval == 0:
+                if global_step % training_args.eval_steps == 0:
                     run_evaluation()
+
+        # log final train metrics
+        wandb_log(unreplicate(train_metric), step=global_step, prefix='train')

         train_time += time.time() - train_start
-
         train_metric = unreplicate(train_metric)
-
         epochs.write(
             f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
         )

+        # Final evaluation
         eval_metrics = run_evaluation()

-        # Save metrics
-        if has_tensorboard and jax.process_index() == 0:
-            cur_step = epoch * (len(train_dataset) // train_batch_size)
-            write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)
-
         # save checkpoint after each epoch and push checkpoint to the hub
         if jax.process_index() == 0:
             params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
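One detail worth noting in the last hunk: the train metric is passed through `unreplicate` (and through `jax.device_get` inside `wandb_log`) before logging. `p_train_step` runs under `pmap`, so its outputs carry a leading device axis. A small sketch of why the unreplication matters, with an invented loss value:

```python
# Sketch of why unreplicate() precedes logging: pmap-style outputs carry a
# leading device axis, and wandb should receive plain scalars.
import jax
import jax.numpy as jnp
from flax.jax_utils import replicate, unreplicate

replicated_metric = replicate({"loss": jnp.float32(2.31)})  # simulate p_train_step output
print(replicated_metric["loss"].shape)  # (local_device_count,)

train_metric = unreplicate(replicated_metric)  # keep the copy from the first device
print(jax.device_get(train_metric["loss"]))  # 2.31
```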