Spaces:
Sleeping
Sleeping
| import torch | |
| from torch import Tensor | |
| from torchmetrics.retrieval.base import RetrievalMetric | |
| from torchmetrics.utilities.checks import _check_retrieval_functional_inputs | |
| from deepscreen.models.metrics.rie import calc_rie | |
class BEDROC(RetrievalMetric):
    """Boltzmann-Enhanced Discrimination of ROC (BEDROC) as a retrieval metric.

    For each query the RIE value returned by ``calc_rie`` is rescaled with the
    ``rie_min`` / ``rie_max`` bounds computed from the fraction of actives:
    ``(rie - rie_min) / (rie_max - rie_min)``.

    Args:
        alpha: Exponential weighting parameter. ``exp(alpha)`` is handed to
            ``calc_rie`` and used in the normalisation bounds. Must be
            strictly positive.
        **kwargs: Forwarded unchanged to ``RetrievalMetric.__init__`` so the
            base-class options (for example ``empty_target_action`` or
            ``ignore_index``; see the torchmetrics documentation for the
            installed version) stay configurable.

    Raises:
        ValueError: If ``alpha`` is not strictly positive.
    """

    is_differentiable: bool = False
    higher_is_better: bool = True
    full_state_update: bool = False

    def __init__(
        self,
        alpha: float = 80.5,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # alpha == 0 makes ``1 - exp(alpha)`` zero, so the bounds in
        # ``_metric`` would divide by zero and yield nan; BEDROC is defined
        # for alpha > 0 only, so fail early instead of returning nan later.
        if alpha <= 0:
            raise ValueError(f"`alpha` must be strictly positive, got {alpha}.")
        self.alpha = alpha

    def _metric(self, preds: Tensor, target: Tensor) -> Tensor:
        """Compute BEDROC for the predictions and targets of a single query.

        Args:
            preds: Predicted scores; a higher score ranks earlier.
            target: Relevance labels; non-zero entries are counted as actives.

        Returns:
            A scalar tensor: ``0.0`` when there are no actives, ``1.0`` when
            every item is active, otherwise the normalised RIE.
        """
        preds, target = _check_retrieval_functional_inputs(preds, target)

        n_total = target.size(0)
        n_actives = target.sum()

        # Degenerate rankings: the normalisation below is undefined when the
        # fraction of actives is 0 or 1, so return fixed values instead.
        if n_actives == 0:
            return torch.tensor(0.0, device=preds.device)
        elif n_actives == n_total:
            return torch.tensor(1.0, device=preds.device)

        r_a = n_actives / n_total
        # NOTE(review): ``torch.tensor(self.alpha)`` uses the default dtype;
        # with float32, exp(alpha) overflows to inf for alpha above roughly
        # 88, which would turn the bounds below into nan. Confirm before
        # using larger values of alpha.
        exp_a = torch.exp(torch.tensor(self.alpha))

        # Stable descending sort keeps tied scores in their input order.
        idx = torch.argsort(preds, descending=True, stable=True)
        # 1-based positions of the actives within the ranking.
        active_ranks = torch.take(target, idx).nonzero() + 1

        rie = calc_rie(n_total, active_ranks, r_a, exp_a)
        # ``exp_a ** r_a`` equals exp(alpha * r_a); these are the smallest and
        # largest RIE values used to rescale ``rie``.
        rie_min = (1 - exp_a ** r_a) / (r_a * (1 - exp_a))
        rie_max = (1 - exp_a ** (-r_a)) / (r_a * (1 - exp_a ** (-1)))

        return (rie - rie_min) / (rie_max - rie_min)

    def plot(self, val=None, ax=None):
        """Plot the metric by delegating to the inherited ``_plot`` helper.

        Args:
            val: Forwarded to ``_plot`` as its first argument.
            ax: Forwarded to ``_plot`` as its second argument.

        Returns:
            Whatever ``_plot`` returns.
        """
        return self._plot(val, ax)