Spaces:
Runtime error
Runtime error
| # Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """Signal-to-Reconstruction Error (SRE) metric.""" | |
| import evaluate | |
| import datasets | |
| import numpy as np | |
| _DESCRIPTION = """\ | |
| Compute the Signal-to-Reconstruction Error (SRE) metric. This metric is commonly used to | |
| asses the performance of denoising, super-resolution and style transfer algorithms in | |
| audio and image processing. | |
| """ | |
# BibTeX entry passed to evaluate.MetricInfo(citation=...).
# NOTE(review): this appears to be the unmodified module-template placeholder
# ("A great new module"); replace it with a reference for the SRE metric.
_CITATION = """\
@InProceedings{huggingface:module,
title = {A great new module},
authors={huggingface, Inc.},
year={2020}
}
"""
| _KWARGS_DESCRIPTION = """ | |
| Args: | |
| predictions (`list` of `np.array`): Predicted labels. | |
| references (`list` of `np.array`): Ground truth labels. | |
| sample_weight (`list` of `float`): Sample weights Defaults to None. | |
| Returns: | |
| sre (`float`): Signal-to-Reconstruction Error (SRE) metric. The SRE values are | |
| positive and they are expressed in decibels (dB). The higher the SRE value, the better. | |
| Examples: | |
| Example 1-A simple example | |
| >>> sre = evaluate.load("jpxkqx/signal_to_reconstruction_error") | |
| >>> results = sre.compute(references=[[0, 0], [-1, -1]], predictions=[[0, 1], [0, 0]]) | |
| >>> print(results) | |
| {"Signal-to-Reconstruction Error": 23.01} | |
| """ | |
def signal_reconstruction_error(y_true: np.ndarray, y_hat: np.ndarray) -> float:
    """Return the Signal-to-Reconstruction Error of one sample, in decibels.

    SRE = 10 * log10(sum(y_true ** 2) / sum((y_true - y_hat) ** 2))

    Args:
        y_true: Reference signal (any array-like; lists are accepted).
        y_hat: Reconstructed signal, with the same shape as ``y_true``.

    Returns:
        The SRE as a Python float. It is ``inf`` for a perfect reconstruction and
        ``-inf`` when the reference is all zeros but the reconstruction is not.

    Raises:
        ValueError: If ``y_true`` and ``y_hat`` have different shapes.
    """
    # Cast to float64 so squaring integer inputs cannot overflow.
    y_true = np.asarray(y_true, dtype=np.float64)
    y_hat = np.asarray(y_hat, dtype=np.float64)
    # Refuse to broadcast: mismatched shapes would silently yield a wrong score.
    if y_true.shape != y_hat.shape:
        raise ValueError(
            f"Shape mismatch: reference {y_true.shape} vs prediction {y_hat.shape}"
        )
    signal_energy = np.sum(y_true ** 2)
    error_energy = np.sum((y_true - y_hat) ** 2)
    # A zero error (or zero signal) legitimately gives +/-inf; silence the
    # divide-by-zero warning NumPy would otherwise emit for it.
    with np.errstate(divide="ignore"):
        return float(10 * np.log10(signal_energy / error_energy))
class SignaltoReconstrutionError(evaluate.Metric):
    """`evaluate` metric returning the mean Signal-to-Reconstruction Error (SRE) in dB.

    The "multilist" config expects one 3-D array per example; any other config
    expects one 2-D array per example (see `_get_feature_types`).
    """

    def _info(self):
        """Return the metric metadata and input feature schema."""
        return evaluate.MetricInfo(
            module_type="metric",
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(self._get_feature_types()),
            # NOTE(review): the module id in the docs is spelled
            # "signal_to_reconstruction_error"; confirm this URL's spelling.
            homepage="https://huggingface.co/spaces/jpxkqx/signal_to_reconstrution_error",
        )

    def _get_feature_types(self):
        """Return the per-example feature schema for the active config."""
        if self.config_name == "multilist":
            return {
                # 1st Seq - num_samples, 2nd Seq - Height, 3rd Seq - Width
                "predictions": datasets.Sequence(
                    datasets.Sequence(datasets.Sequence(datasets.Value("float32")))
                ),
                "references": datasets.Sequence(
                    datasets.Sequence(datasets.Sequence(datasets.Value("float32")))
                ),
            }
        else:
            return {
                # 1st Seq - Height, 2nd Seq - Width
                "predictions": datasets.Sequence(
                    datasets.Sequence(datasets.Value("float32"))
                ),
                "references": datasets.Sequence(
                    datasets.Sequence(datasets.Value("float32"))
                ),
            }

    def _compute(self, predictions, references, sample_weight=None):
        """Return the (optionally weighted) mean SRE over all samples.

        Args:
            predictions: Reconstructed signals, one array-like per sample.
            references: Ground-truth signals, one array-like per sample.
            sample_weight: Optional per-sample weights forwarded to ``np.average``.

        Returns:
            dict mapping "Signal-to-Reconstruction Error" to the mean SRE (float, dB).

        Raises:
            ValueError: If ``predictions`` and ``references`` differ in length.
        """
        # zip() would silently drop the unmatched tail, so check explicitly.
        if len(predictions) != len(references):
            raise ValueError(
                f"Got {len(predictions)} predictions but {len(references)} references"
            )
        # Convert each sample on its own: stacking the whole batch with np.array
        # fails (or builds an object array) when samples differ in shape.
        scores = [
            signal_reconstruction_error(np.asarray(ref), np.asarray(pred))
            for ref, pred in zip(references, predictions)
        ]
        return {
            "Signal-to-Reconstruction Error": float(
                np.average(scores, weights=sample_weight)
            )
        }