Spaces:
Runtime error
Runtime error
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import torch | |
| from fairseq.models.bart import BARTModel | |
| import argparse | |
# Keyword arguments forwarded to ``bart.sample`` when ``--xsum-kwargs`` is given.
XSUM_KWARGS = {
    "beam": 6,
    "lenpen": 1.0,
    "max_len_b": 60,
    "min_len": 10,
    "no_repeat_ngram_size": 3,
}

# Keyword arguments forwarded to ``bart.sample`` otherwise (the default choice).
CNN_KWARGS = {
    "beam": 4,
    "lenpen": 2.0,
    "max_len_b": 140,
    "min_len": 55,
    "no_repeat_ngram_size": 3,
}
def generate(bart, infile, outfile="bart_hypo.txt", bsz=32, n_obs=None, **eval_kwargs):
    """Summarize ``infile`` line by line, writing one hypothesis per line to ``outfile``.

    Each input line is stripped, lines are grouped into batches of at most
    ``bsz`` and every batch is passed to ``bart.sample``. Each returned
    hypothesis is written to ``outfile`` followed by a newline. Blank input
    lines are kept (as empty strings) so output lines stay aligned with input
    lines.

    Args:
        bart: object exposing ``sample(list_of_str, **kwargs)`` that returns an
            iterable of strings.
        infile: path of the text file to read, one source text per line.
        outfile: path of the file that is (over)written with the hypotheses.
        bsz: maximum number of lines per ``bart.sample`` call; must be >= 1.
        n_obs: if not ``None``, process at most this many input lines.
        **eval_kwargs: forwarded unchanged to ``bart.sample``.

    Raises:
        ValueError: if ``bsz`` is smaller than 1.
    """
    if bsz < 1:
        # Checked before opening ``outfile`` so a bad call does not truncate it.
        raise ValueError("bsz must be a positive integer, got {}".format(bsz))

    def _write_batch(batch, fout):
        # Sample one batch and persist its hypotheses right away.
        for hypothesis in bart.sample(batch, **eval_kwargs):
            fout.write(hypothesis + "\n")
        fout.flush()

    with open(infile) as source, open(outfile, "w") as fout:
        batch = []
        for seen, sline in enumerate(source):
            # ``seen`` is the number of lines already accepted, so this stops
            # after exactly ``n_obs`` lines (and after none when n_obs == 0).
            if n_obs is not None and seen >= n_obs:
                break
            batch.append(sline.strip())
            if len(batch) == bsz:
                _write_batch(batch, fout)
                batch = []
        # Trailing partial batch; skipped entirely for an empty input file.
        if batch:
            _write_batch(batch, fout)
def main():
    """
    Parse command-line options, load a BART model and summarize ``--src``.

    The model is loaded through ``torch.hub.load`` when ``--model-dir`` is
    exactly ``pytorch/fairseq`` (``--model-file`` is then passed as the model
    name); otherwise it is loaded from ``--model-dir`` with
    ``BARTModel.from_pretrained``. The model is put in eval mode and, when
    CUDA is available, moved to the GPU in half precision. Summaries are
    written to ``--out`` by ``generate``.

    Usage::

         python examples/bart/summarize.py \
            --model-dir $HOME/bart.large.cnn \
            --model-file model.pt \
            --src $HOME/data-bin/cnn_dm/test.source
    """
    parser = argparse.ArgumentParser()
    # NOTE: because of required=True the default below is never used; it is
    # kept so the accepted command lines stay exactly as before.
    parser.add_argument(
        "--model-dir",
        required=True,
        type=str,
        default="bart.large.cnn/",
        help="path containing model file and src_dict.txt",
    )
    parser.add_argument(
        "--model-file",
        default="checkpoint_best.pt",
        help="where in model_dir are weights saved",
    )
    parser.add_argument(
        "--src", default="test.source", help="text to summarize", type=str
    )
    parser.add_argument(
        "--out", default="test.hypo", help="where to save summaries", type=str
    )
    # The help text used to be a copy of the --out help; describe the option.
    parser.add_argument(
        "--bsz", default=32, help="batch size used when generating summaries", type=int
    )
    parser.add_argument(
        "--n", default=None, help="how many examples to summarize", type=int
    )
    parser.add_argument(
        "--xsum-kwargs",
        action="store_true",
        default=False,
        help="if true use XSUM_KWARGS else CNN_KWARGS",
    )
    args = parser.parse_args()

    eval_kwargs = XSUM_KWARGS if args.xsum_kwargs else CNN_KWARGS

    if args.model_dir == "pytorch/fairseq":
        bart = torch.hub.load("pytorch/fairseq", args.model_file)
    else:
        bart = BARTModel.from_pretrained(
            args.model_dir,
            checkpoint_file=args.model_file,
            data_name_or_path=args.model_dir,
        )

    bart = bart.eval()
    if torch.cuda.is_available():
        bart = bart.cuda().half()

    generate(
        bart, args.src, bsz=args.bsz, n_obs=args.n, outfile=args.out, **eval_kwargs
    )
# Run the command-line entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()