Pedro Cuenca
committed on
Commit
·
32dc2d8
1
Parent(s):
df3c7bd
* Only perform validation if requested
Browse files* Disable rouge metric
* Add sanity check for tpus.
* Add training command.
- seq2seq/do_run.sh +9 -0
- seq2seq/run_seq2seq_flax.py +38 -35
seq2seq/do_run.sh
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
python run_seq2seq_flax.py \
|
| 2 |
+
--max_source_length 128 \
|
| 3 |
+
--train_file /data/CC12M/encoded-small-train.tsv \
|
| 4 |
+
--validation_file /data/CC12M/encoded-small-valid.tsv \
|
| 5 |
+
--output_dir output \
|
| 6 |
+
--per_device_train_batch_size 16 \
|
| 7 |
+
--per_device_eval_batch_size 16 \
|
| 8 |
+
--do_train \
|
| 9 |
+
--do_eval \
|
seq2seq/run_seq2seq_flax.py
CHANGED
|
@@ -413,6 +413,8 @@ def main():
|
|
| 413 |
#config.min_length = data_args.max_target_length # Set only in decoder?
|
| 414 |
#config.max_length = data_args.max_target_length # Set only in decoder?
|
| 415 |
|
|
|
|
|
|
|
| 416 |
|
| 417 |
# Create a custom model and initialize it randomly
|
| 418 |
model = CustomFlaxBartForConditionalGeneration(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
|
|
@@ -534,7 +536,7 @@ def main():
|
|
| 534 |
)
|
| 535 |
|
| 536 |
# Metric
|
| 537 |
-
metric = load_metric("rouge")
|
| 538 |
|
| 539 |
def postprocess_text(preds, labels):
|
| 540 |
preds = [pred.strip() for pred in preds]
|
|
@@ -740,40 +742,41 @@ def main():
|
|
| 740 |
|
| 741 |
# ======================== Evaluating ==============================
|
| 742 |
eval_metrics = []
|
| 743 |
-
|
| 744 |
-
|
| 745 |
-
|
| 746 |
-
|
| 747 |
-
|
| 748 |
-
|
| 749 |
-
|
| 750 |
-
|
| 751 |
-
|
| 752 |
-
|
| 753 |
-
|
| 754 |
-
|
| 755 |
-
|
| 756 |
-
|
| 757 |
-
|
| 758 |
-
|
| 759 |
-
|
| 760 |
-
|
| 761 |
-
|
| 762 |
-
|
| 763 |
-
|
| 764 |
-
|
| 765 |
-
|
| 766 |
-
|
| 767 |
-
|
| 768 |
-
|
| 769 |
-
|
| 770 |
-
|
| 771 |
-
|
| 772 |
-
|
| 773 |
-
|
| 774 |
-
|
| 775 |
-
|
| 776 |
-
|
|
|
|
| 777 |
|
| 778 |
# Save metrics
|
| 779 |
if has_tensorboard and jax.process_index() == 0:
|
|
|
|
| 413 |
#config.min_length = data_args.max_target_length # Set only in decoder?
|
| 414 |
#config.max_length = data_args.max_target_length # Set only in decoder?
|
| 415 |
|
| 416 |
+
print(f"TPUs: {jax.device_count()}")
|
| 417 |
+
assert jax.device_count() == 8, "TPUs in use, please check running processes"
|
| 418 |
|
| 419 |
# Create a custom model and initialize it randomly
|
| 420 |
model = CustomFlaxBartForConditionalGeneration(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
|
|
|
|
| 536 |
)
|
| 537 |
|
| 538 |
# Metric
|
| 539 |
+
#metric = load_metric("rouge")
|
| 540 |
|
| 541 |
def postprocess_text(preds, labels):
|
| 542 |
preds = [pred.strip() for pred in preds]
|
|
|
|
| 742 |
|
| 743 |
# ======================== Evaluating ==============================
|
| 744 |
eval_metrics = []
|
| 745 |
+
if training_args.do_eval:
|
| 746 |
+
eval_preds = []
|
| 747 |
+
eval_labels = []
|
| 748 |
+
|
| 749 |
+
eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
|
| 750 |
+
eval_steps = len(eval_dataset) // eval_batch_size
|
| 751 |
+
for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
|
| 752 |
+
# Model forward
|
| 753 |
+
batch = next(eval_loader)
|
| 754 |
+
labels = batch["labels"]
|
| 755 |
+
|
| 756 |
+
metrics = p_eval_step(state.params, batch)
|
| 757 |
+
eval_metrics.append(metrics)
|
| 758 |
+
|
| 759 |
+
# generation
|
| 760 |
+
if data_args.predict_with_generate:
|
| 761 |
+
generated_ids = p_generate_step(state.params, batch)
|
| 762 |
+
eval_preds.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
|
| 763 |
+
eval_labels.extend(jax.device_get(labels.reshape(-1, labels.shape[-1])))
|
| 764 |
+
|
| 765 |
+
# normalize eval metrics
|
| 766 |
+
eval_metrics = get_metrics(eval_metrics)
|
| 767 |
+
eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
|
| 768 |
+
|
| 769 |
+
# compute ROUGE metrics
|
| 770 |
+
rouge_desc = ""
|
| 771 |
+
# if data_args.predict_with_generate:
|
| 772 |
+
# rouge_metrics = compute_metrics(eval_preds, eval_labels)
|
| 773 |
+
# eval_metrics.update(rouge_metrics)
|
| 774 |
+
# rouge_desc = " ".join([f"Eval {key}: {value} |" for key, value in rouge_metrics.items()])
|
| 775 |
+
|
| 776 |
+
# Print metrics and update progress bar
|
| 777 |
+
desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']} | {rouge_desc})"
|
| 778 |
+
epochs.write(desc)
|
| 779 |
+
epochs.desc = desc
|
| 780 |
|
| 781 |
# Save metrics
|
| 782 |
if has_tensorboard and jax.process_index() == 0:
|