Spaces:
Runtime error
Runtime error
fix: remove breakpoint
Browse files
dev/seq2seq/run_seq2seq_flax.py
CHANGED
|
@@ -922,11 +922,8 @@ def main():
|
|
| 922 |
eval_metrics.append(metrics)
|
| 923 |
|
| 924 |
# normalize eval metrics
|
| 925 |
-
breakpoint()
|
| 926 |
eval_metrics = get_metrics(eval_metrics)
|
| 927 |
-
breakpoint()
|
| 928 |
eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
|
| 929 |
-
breakpoint()
|
| 930 |
|
| 931 |
# log metrics
|
| 932 |
wandb_log(eval_metrics, step=global_step, prefix="eval")
|
|
|
|
| 922 |
eval_metrics.append(metrics)
|
| 923 |
|
| 924 |
# normalize eval metrics
|
|
|
|
| 925 |
eval_metrics = get_metrics(eval_metrics)
|
|
|
|
| 926 |
eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
|
|
|
|
| 927 |
|
| 928 |
# log metrics
|
| 929 |
wandb_log(eval_metrics, step=global_step, prefix="eval")
|