Spaces:
Runtime error
Runtime error
| # Copyright 2020 The HuggingFace Evaluate Authors. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """ BERTScore metric. """ | |
| import functools | |
| from contextlib import contextmanager | |
| import bert_score | |
| import datasets | |
| from packaging import version | |
| import evaluate | |
@contextmanager
def filter_logging_context():
    """Temporarily hide the "This IS expected if you are initializing" log message.

    Installs a filter on the ``transformers.modeling_utils`` logger for the
    duration of the ``with`` block and removes it again on exit, including
    when the body raises.

    Yields:
        None. The context manager is used purely for its side effect.
    """

    def filter_log(record):
        # ``record.msg`` is not guaranteed to be a string (logging accepts any
        # object), so coerce it before the substring test.
        return "This IS expected if you are initializing" not in str(record.msg)

    logger = datasets.utils.logging.get_logger("transformers.modeling_utils")
    logger.addFilter(filter_log)
    try:
        yield
    finally:
        # Always detach the filter so it does not leak into later logging.
        logger.removeFilter(filter_log)
# BibTeX entry for the paper that introduced BERTScore.
_CITATION = """\
@inproceedings{bert-score,
title={BERTScore: Evaluating Text Generation with BERT},
author={Tianyi Zhang* and Varsha Kishore* and Felix Wu* and Kilian Q. Weinberger and Yoav Artzi},
booktitle={International Conference on Learning Representations},
year={2020},
url={https://openreview.net/forum?id=SkeHuCVFDr}
}
"""

# Short human-readable summary shown as the metric description.
_DESCRIPTION = """\
BERTScore leverages the pre-trained contextual embeddings from BERT and matches words in candidate and reference
sentences by cosine similarity.
It has been shown to correlate with human judgment on sentence-level and system-level evaluation.
Moreover, BERTScore computes precision, recall, and F1 measure, which can be useful for evaluating different language
generation tasks.
See the project's README at https://github.com/Tiiiger/bert_score#readme for more information.
"""

# Argument/return documentation for `compute`; kept in sync with the
# keyword arguments accepted by `BERTScore._compute`.
_KWARGS_DESCRIPTION = """
Computes BERTScore precision, recall, and F1 of predictions against one or more references,
and reports the hashcode of the scoring configuration.

Args:
    predictions (list of str): Prediction/candidate sentences.
    references (list of str or list of list of str): Reference sentences.
    lang (str): Language of the sentences (e.g. 'en'). Required when `model_type`
        is not given, and when `rescale_with_baseline` is True.
    model_type (str): Bert specification, default using the suggested
        model for the target language; has to specify at least one of
        `model_type` or `lang`.
    num_layers (int): The layer of representation to use,
        default using the number of layers tuned on WMT16 correlation data.
    verbose (bool): Turn on intermediate status update.
    idf (bool or dict): Use idf weighting; can also be a precomputed idf_dict.
    device (str): Device on which the contextual embedding model will be allocated.
        If this argument is None, the model lives on cuda:0 if cuda is available.
    nthreads (int): Number of threads.
    batch_size (int): Bert score processing batch size.
    all_layers (bool): Forwarded to `bert_score.BERTScorer`; score with the
        representations of every layer instead of a single one.
    rescale_with_baseline (bool): Rescale bertscore with pre-computed baseline.
    baseline_path (str): Customized baseline file.
    use_fast_tokenizer (bool): `use_fast` parameter passed to HF tokenizer.
        Requires `bert-score>=0.3.10`.

Returns:
    precision: Precision.
    recall: Recall.
    f1: F1 score.
    hashcode: Hashcode of the library configuration used for scoring.

Examples:

    >>> predictions = ["hello there", "general kenobi"]
    >>> references = ["hello there", "general kenobi"]
    >>> bertscore = evaluate.load("bertscore")
    >>> results = bertscore.compute(predictions=predictions, references=references, lang="en")
    >>> print([round(v, 2) for v in results["f1"]])
    [1.0, 1.0]
"""
class BERTScore(evaluate.Metric):
    """BERTScore metric that delegates scoring to the ``bert_score`` package.

    The scorer object is cached on the instance (``cached_bertscorer``) and
    reused by later ``_compute`` calls as long as the scoring configuration
    has not changed.
    """

    def _info(self):
        """Return the metric metadata: descriptions, accepted input schemas, and links."""
        return evaluate.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            homepage="https://github.com/Tiiiger/bert_score",
            inputs_description=_KWARGS_DESCRIPTION,
            # Two accepted schemas: a list of references per prediction,
            # or a single reference string per prediction.
            features=[
                datasets.Features(
                    {
                        "predictions": datasets.Value("string", id="sequence"),
                        "references": datasets.Sequence(datasets.Value("string", id="sequence"), id="references"),
                    }
                ),
                datasets.Features(
                    {
                        "predictions": datasets.Value("string", id="sequence"),
                        "references": datasets.Value("string", id="sequence"),
                    }
                ),
            ],
            codebase_urls=["https://github.com/Tiiiger/bert_score"],
            reference_urls=[
                "https://github.com/Tiiiger/bert_score",
                "https://arxiv.org/abs/1904.09675",
            ],
        )

    def _compute(
        self,
        predictions,
        references,
        lang=None,
        model_type=None,
        num_layers=None,
        verbose=False,
        idf=False,
        device=None,
        batch_size=64,
        nthreads=4,
        all_layers=False,
        rescale_with_baseline=False,
        baseline_path=None,
        use_fast_tokenizer=False,
    ):
        """Score ``predictions`` against ``references`` with BERTScore.

        Args:
            predictions: Candidate sentences.
            references: One reference string, or a list of reference strings,
                per prediction.
            lang: Language code; used to pick a default model when
                ``model_type`` is None.
            model_type: Model identifier; resolved from ``lang`` when None.
            num_layers: Layer to use; looked up from ``model_type`` when None.
            verbose: Forwarded to the scorer's ``score`` call.
            idf: When truthy, the flattened references are handed to the
                scorer as the idf corpus.
            device, nthreads, all_layers, rescale_with_baseline, baseline_path:
                Forwarded to the scorer constructor.
            batch_size: Forwarded to both the scorer constructor and ``score``.
            use_fast_tokenizer: Forwarded to the scorer and hash function on
                ``bert-score>=0.3.10``.

        Returns:
            dict with ``precision``, ``recall`` and ``f1`` (lists) and the
            configuration ``hashcode``.

        Raises:
            ValueError: If neither ``lang`` nor ``model_type`` is given.
            ImportWarning: If ``use_fast_tokenizer`` is requested with
                ``bert-score<0.3.10``.
        """
        # Normalise single-reference input to the list-of-lists layout.
        # The length guard avoids an IndexError on empty input.
        if len(references) > 0 and isinstance(references[0], str):
            references = [[ref] for ref in references]

        idf_sents = [r for ref in references for r in ref] if idf else None

        get_hash = bert_score.utils.get_hash
        scorer = bert_score.BERTScorer

        # `use_fast_tokenizer` is only understood by bert-score >= 0.3.10.
        if version.parse(bert_score.__version__) >= version.parse("0.3.10"):
            get_hash = functools.partial(get_hash, use_fast_tokenizer=use_fast_tokenizer)
            scorer = functools.partial(scorer, use_fast_tokenizer=use_fast_tokenizer)
        elif use_fast_tokenizer:
            raise ImportWarning(
                "To use a fast tokenizer, the module `bert-score>=0.3.10` is required, and the current version of "
                "`bert-score` doesn't match this condition.\n"
                'You can install it with `pip install "bert-score>=0.3.10"`.'
            )

        if model_type is None:
            if lang is None:
                raise ValueError(
                    "Either 'lang' (e.g. 'en') or 'model_type' (e.g. 'microsoft/deberta-xlarge-mnli')"
                    " must be specified"
                )
            model_type = bert_score.utils.lang2model[lang.lower()]

        if num_layers is None:
            num_layers = bert_score.utils.model2layers[model_type]

        hashcode = get_hash(
            model=model_type,
            num_layers=num_layers,
            idf=idf,
            rescale_with_baseline=rescale_with_baseline,
            use_custom_baseline=baseline_path is not None,
        )

        # Nothing to score: skip building a scorer and report empty results.
        if len(predictions) == 0 and len(references) == 0:
            return {"precision": [], "recall": [], "f1": [], "hashcode": hashcode}

        # The hash does not cover every constructor argument, so also track
        # the remaining ones that are only given to the scorer at
        # construction time. Otherwise a scorer built with different
        # settings (or a different idf corpus) would be reused silently.
        scorer_config = (
            lang,
            device,
            all_layers,
            baseline_path,
            None if idf_sents is None else tuple(idf_sents),
        )

        cached = getattr(self, "cached_bertscorer", None)
        needs_rebuild = (
            cached is None
            or cached.hash != hashcode
            or getattr(self, "_cached_bertscorer_config", None) != scorer_config
        )

        if needs_rebuild:
            # Constructing the scorer is what the log filter is installed for.
            with filter_logging_context():
                self.cached_bertscorer = scorer(
                    model_type=model_type,
                    num_layers=num_layers,
                    batch_size=batch_size,
                    nthreads=nthreads,
                    all_layers=all_layers,
                    idf=idf,
                    idf_sents=idf_sents,
                    device=device,
                    lang=lang,
                    rescale_with_baseline=rescale_with_baseline,
                    baseline_path=baseline_path,
                )
            self._cached_bertscorer_config = scorer_config

        (P, R, F) = self.cached_bertscorer.score(
            cands=predictions,
            refs=references,
            verbose=verbose,
            batch_size=batch_size,
        )
        output_dict = {
            "precision": P.tolist(),
            "recall": R.tolist(),
            "f1": F.tolist(),
            "hashcode": hashcode,
        }
        return output_dict