import datasets
import evaluate
|
|
|
|
|
_DESCRIPTION = """
The Accuracy (ACC) metric measures the proportion of predicted sequences that exactly match the ground truth out of the total number of sequences.
The metric handles both integer and string inputs by converting them to strings before comparison.
ACC ranges from 0 to 1, where 1 indicates perfect accuracy (every prediction is correct) and 0 indicates complete failure (no prediction is correct).
It is particularly useful for tasks such as OCR, digit recognition, and sequence prediction, and for any other task where exact matches are required.
The accuracy is calculated with the formula:

ACC = (Number of Correct Predictions) / (Total Number of Predictions)

where a prediction is considered correct if it exactly matches the ground truth sequence after both are converted to strings.
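For example, if 2 out of 3 predictions exactly match their references, ACC = 2 / 3 ≈ 0.667.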
|
""" |
|
|
|
_KWARGS_DESCRIPTION = """
Args:
    predictions (`list` of `str` or `int`): Predicted labels (can be strings or integers).
    references (`list` of `str` or `int`): Ground truth labels (can be strings or integers).
Returns:
    acc (`float`): Accuracy score. Minimum possible value is 0. Maximum possible value is 1.0.
Examples:
    Example 1 - String inputs:
        >>> acc_metric = evaluate.load("Bekhouche/ACC")
        >>> results = acc_metric.compute(references=['123', '456', '789'], predictions=['123', '456', '789'])
        >>> print(results)
        {'acc': 1.0}

    Example 2 - Integer inputs:
        >>> results = acc_metric.compute(references=[123, 456, 789], predictions=[123, 456, 789])
        >>> print(results)
        {'acc': 1.0}

    Example 3 - Mixed inputs:
        >>> results = acc_metric.compute(references=['123', 456, '789'], predictions=[123, '456', 789])
        >>> print(results)
        {'acc': 1.0}
"""
|
|
|
_CITATION = """
@inproceedings{accuracy_metric,
    title={Accuracy as a fundamental metric for sequence prediction tasks},
    author={Various Authors},
    booktitle={Proceedings of various conferences},
    pages={1--10},
    year={2023},
    organization={Various}
}
"""
|
|
|
|
|
@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class ACC(evaluate.Metric):
    def _info(self):
        return evaluate.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            # The "multilabel" config expects a sequence of labels per example;
            # the default config expects a single value per example.
            features=datasets.Features(
                {
                    "predictions": datasets.Sequence(datasets.Value("string")),
                    "references": datasets.Sequence(datasets.Value("string")),
                }
                if self.config_name == "multilabel"
                else {
                    "predictions": datasets.Value("string"),
                    "references": datasets.Value("string"),
                }
            ),
            reference_urls=["https://huggingface.co/spaces/Bekhouche/ACC"],
        )
|
|
|
    def _compute(self, predictions, references):
        if isinstance(predictions, list) and isinstance(references, list):
            if len(predictions) != len(references):
                raise ValueError("Predictions and references must have the same length")
            if not predictions:
                raise ValueError("Predictions and references must not be empty")
            # Count exact matches, then normalize by the number of examples.
            correct_predictions = sum(
                self._compute_acc(prediction, reference)
                for prediction, reference in zip(predictions, references)
            )
            acc = correct_predictions / len(predictions)
        elif isinstance(predictions, (str, int)) and isinstance(references, (str, int)):
            acc = self._compute_acc(predictions, references)
        else:
            raise ValueError("Predictions and references must be either a list of str/int or a single str/int")
        return {"acc": float(acc)}
|
|
|
    @staticmethod
    def _compute_acc(prediction, reference):
        # Exact-match comparison after casting both values to strings,
        # so integer and string inputs are handled uniformly.
        pred_str = str(prediction)
        ref_str = str(reference)
        return 1.0 if pred_str == ref_str else 0.0
|
|
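# A minimal local smoke test (a sketch, not part of the published metric): it assumes this file
# can be used directly as a standalone `evaluate` module. When hosted on the Hub, the metric
# would normally be loaded with `evaluate.load("Bekhouche/ACC")` instead, as in the docstring
# examples above.
if __name__ == "__main__":
    metric = ACC()
    result = metric.compute(predictions=["123", "456"], references=["123", "457"])
    print(result)  # expected: {'acc': 0.5}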