Spaces:
Runtime error
Runtime error
Merge branch 'add-tokenizer-save' into feat-model
Browse filesFormer-commit-id: 2cfaef4a020f43332a8f33b6a9bd8221ec9fae34
seq2seq/run_seq2seq_flax.py
CHANGED
|
@@ -818,13 +818,16 @@ def main():
|
|
| 818 |
params=params,
|
| 819 |
)
|
| 820 |
|
|
|
|
|
|
|
|
|
|
| 821 |
# save state
|
| 822 |
state = unreplicate(state)
|
| 823 |
with (Path(training_args.output_dir) / 'opt_state.msgpack').open('wb') as f:
|
| 824 |
f.write(to_bytes(state.opt_state))
|
| 825 |
with (Path(training_args.output_dir) / 'training_state.json').open('w') as f:
|
| 826 |
json.dump({'step': state.step.item()}, f)
|
| 827 |
-
|
| 828 |
# save to W&B
|
| 829 |
if data_args.log_model:
|
| 830 |
metadata = {'step': step, 'epoch': epoch}
|
|
@@ -834,6 +837,11 @@ def main():
|
|
| 834 |
name=f"model-{wandb.run.id}", type="bart_model", metadata=metadata
|
| 835 |
)
|
| 836 |
artifact.add_file(str(Path(training_args.output_dir) / 'flax_model.msgpack'))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 837 |
artifact.add_file(str(Path(training_args.output_dir) / 'config.json'))
|
| 838 |
artifact.add_file(str(Path(training_args.output_dir) / 'opt_state.msgpack'))
|
| 839 |
artifact.add_file(str(Path(training_args.output_dir) / 'training_state.json'))
|
|
|
|
| 818 |
params=params,
|
| 819 |
)
|
| 820 |
|
| 821 |
+
# save tokenizer
|
| 822 |
+
tokenizer.save_pretrained(training_args.output_dir)
|
| 823 |
+
|
| 824 |
# save state
|
| 825 |
state = unreplicate(state)
|
| 826 |
with (Path(training_args.output_dir) / 'opt_state.msgpack').open('wb') as f:
|
| 827 |
f.write(to_bytes(state.opt_state))
|
| 828 |
with (Path(training_args.output_dir) / 'training_state.json').open('w') as f:
|
| 829 |
json.dump({'step': state.step.item()}, f)
|
| 830 |
+
|
| 831 |
# save to W&B
|
| 832 |
if data_args.log_model:
|
| 833 |
metadata = {'step': step, 'epoch': epoch}
|
|
|
|
| 837 |
name=f"model-{wandb.run.id}", type="bart_model", metadata=metadata
|
| 838 |
)
|
| 839 |
artifact.add_file(str(Path(training_args.output_dir) / 'flax_model.msgpack'))
|
| 840 |
+
artifact.add_file(str(Path(training_args.output_dir) / 'tokenizer_config.json'))
|
| 841 |
+
artifact.add_file(str(Path(training_args.output_dir) / 'special_tokens_map.json'))
|
| 842 |
+
artifact.add_file(str(Path(training_args.output_dir) / 'vocab.json'))
|
| 843 |
+
artifact.add_file(str(Path(training_args.output_dir) / 'added_tokens.json'))
|
| 844 |
+
artifact.add_file(str(Path(training_args.output_dir) / 'merges.txt'))
|
| 845 |
artifact.add_file(str(Path(training_args.output_dir) / 'config.json'))
|
| 846 |
artifact.add_file(str(Path(training_args.output_dir) / 'opt_state.msgpack'))
|
| 847 |
artifact.add_file(str(Path(training_args.output_dir) / 'training_state.json'))
|