Spaces:
Runtime error
Runtime error
| # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. | |
| import torch | |
| # TODO | |
class BalancedPositiveNegativeSampler:
    """
    Samples a fixed-size batch of elements per image, ensuring that at most a
    fixed proportion of the batch consists of positive elements.
    """

    def __init__(self, batch_size_per_image, positive_fraction):
        """
        Arguments:
            batch_size_per_image (int): maximum number of elements to be
                selected per image (positives + negatives).
            positive_fraction (float): percentage of positive elements per
                batch; the positive quota is
                int(batch_size_per_image * positive_fraction).
        """
        self.batch_size_per_image = batch_size_per_image
        self.positive_fraction = positive_fraction

    def __call__(self, matched_idxs):
        """
        Arguments:
            matched_idxs (list[Tensor]): list of tensors containing -1, 0 or
                positive values. Each tensor corresponds to a specific image.
                -1 values are ignored, 0 values are considered negatives and
                values >= 1 are considered positives.

        Returns:
            pos_idx (list[Tensor]): one boolean mask per image, shaped like
                the corresponding input tensor, marking the positive elements
                that were selected.
            neg_idx (list[Tensor]): one boolean mask per image, shaped like
                the corresponding input tensor, marking the negative elements
                that were selected.

        Selection is random (via torch.randperm). At most the positive quota
        of positives is taken, and negatives fill the remainder of the batch,
        each limited by how many such elements the image actually has.
        """
        pos_idx = []
        neg_idx = []
        # The positive quota does not depend on the image; compute it once.
        max_num_pos = int(self.batch_size_per_image * self.positive_fraction)
        for matched_idxs_per_image in matched_idxs:
            positive = torch.nonzero(matched_idxs_per_image >= 1).squeeze(1)
            negative = torch.nonzero(matched_idxs_per_image == 0).squeeze(1)

            # protect against not enough positive examples
            num_pos = min(positive.numel(), max_num_pos)
            # negatives fill the rest of the batch; protect against not
            # enough negative examples
            num_neg = min(negative.numel(), self.batch_size_per_image - num_pos)

            # randomly select positive and negative examples; build the
            # permutations on the same device as the candidate indices so GPU
            # inputs are not indexed with a CPU-generated permutation
            perm1 = torch.randperm(
                positive.numel(), device=positive.device
            )[:num_pos]
            perm2 = torch.randperm(
                negative.numel(), device=negative.device
            )[:num_neg]

            pos_idx_per_image = positive[perm1]
            neg_idx_per_image = negative[perm2]

            # create binary masks from the selected indices
            pos_idx_per_image_mask = torch.zeros_like(
                matched_idxs_per_image, dtype=torch.bool
            )
            neg_idx_per_image_mask = torch.zeros_like(
                matched_idxs_per_image, dtype=torch.bool
            )
            pos_idx_per_image_mask[pos_idx_per_image] = True
            neg_idx_per_image_mask[neg_idx_per_image] = True

            pos_idx.append(pos_idx_per_image_mask)
            neg_idx.append(neg_idx_per_image_mask)

        return pos_idx, neg_idx