import datasets
import evaluate

_DESCRIPTION = """
The Accuracy (ACC) metric measures the proportion of predicted sequences that exactly match their reference sequences.
It accepts both integer and string inputs, converting every value to a string before comparison.
ACC ranges from 0 to 1, where 1 indicates perfect accuracy (all predictions are correct) and 0 indicates that no prediction is correct.
It is particularly useful for OCR, digit recognition, sequence prediction, and any other task where exact matches are required.
The accuracy is calculated as:
    ACC = (Number of Correct Predictions) / (Total Number of Predictions)
where a prediction is considered correct if it exactly matches the ground-truth sequence after both are converted to strings.
"""

_KWARGS_DESCRIPTION = """
Args:
    predictions (`list` of `str` or `int`): Predicted labels (can be strings or integers).
    references (`list` of `str` or `int`): Ground truth labels (can be strings or integers).
Returns:
    acc (`float`): Accuracy score. Minimum possible value is 0. Maximum possible value is 1.0.
Examples:
    Example 1 - String inputs:
        >>> acc_metric = evaluate.load("Bekhouche/ACC")
        >>> results = acc_metric.compute(references=['123', '456', '789'], predictions=['123', '456', '789'])
        >>> print(results)
        {'acc': 1.0}
    Example 2 - Integer inputs:
        >>> results = acc_metric.compute(references=[123, 456, 789], predictions=[123, 456, 789])
        >>> print(results)
        {'acc': 1.0}
    Example 3 - Mixed inputs:
        >>> results = acc_metric.compute(references=['123', 456, '789'], predictions=[123, '456', 789])
        >>> print(results)
        {'acc': 1.0}
"""

_CITATION = """
@inproceedings{accuracy_metric,
    title={Accuracy as a fundamental metric for sequence prediction tasks},
    author={Various Authors},
    booktitle={Proceedings of various conferences},
    pages={1--10},
    year={2023},
    organization={Various}
}
"""


@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class ACC(evaluate.Metric):
    def _info(self):
        return evaluate.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(
                {
                    "predictions": datasets.Sequence(datasets.Value("string")),
                    "references": datasets.Sequence(datasets.Value("string")),
                }
                if self.config_name == "multilabel"
                else {
                    "predictions": datasets.Value("string"),
                    "references": datasets.Value("string"),
                }
            ),
            reference_urls=["https://huggingface.co/spaces/Bekhouche/ACC"],
        )

    def _compute(self, predictions, references):
        # List inputs: exact-match accuracy averaged over all prediction/reference pairs.
        if isinstance(predictions, list) and isinstance(references, list):
            if len(predictions) != len(references):
                raise ValueError("Predictions and references must have the same length")
            if not predictions:
                raise ValueError("Predictions and references must not be empty")
            correct_predictions = sum(
                self._compute_acc(prediction, reference)
                for prediction, reference in zip(predictions, references)
            )
            acc = correct_predictions / len(predictions)
        # Scalar inputs: a single exact-match comparison.
        elif isinstance(predictions, (str, int)) and isinstance(references, (str, int)):
            acc = self._compute_acc(predictions, references)
        else:
            raise ValueError("Predictions and references must be either a list[str/int] or str/int")
        return {"acc": float(acc)}

    @staticmethod
    def _compute_acc(prediction, reference):
        # Convert both prediction and reference to strings for comparison
        pred_str = str(prediction)
        ref_str = str(reference)
        return 1.0 if pred_str == ref_str else 0.0
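

# Minimal usage sketch (an assumption, not required by the evaluate loader):
# running this file as a script gives a quick local sanity check. It instantiates
# the ACC class defined above directly, avoiding a network download; in normal use
# the metric is loaded from the Hub with `evaluate.load("Bekhouche/ACC")`, as shown
# in the docstring examples. The inputs below are illustrative values only.
if __name__ == "__main__":
    acc_metric = ACC()

    # Two of the three predictions match their references exactly, so ACC = 2/3.
    results = acc_metric.compute(
        predictions=["123", "456", "780"],
        references=["123", "456", "789"],
    )
    print(results)  # expected output: {'acc': 0.666...}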