| # Copyright (C) 2022-present Naver Corporation. All rights reserved. | |
| # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). | |
| # -------------------------------------------------------- | |
| # Masking utils | |
| # -------------------------------------------------------- | |
| import torch | |
| import torch.nn as nn | |
class RandomMask(nn.Module):
    """
    Random masking of patch tokens.

    Every call draws a fresh, independent random mask for each batch element.
    For 0 <= mask_ratio <= 1, each mask row has exactly ``num_mask`` True
    entries.
    """

    def __init__(self, num_patches, mask_ratio):
        """
        Args:
            num_patches: number of patches per sample (length of each mask row).
            mask_ratio: fraction of patches to mask; the resulting count is
                floored to an integer.
        """
        super().__init__()
        self.num_patches = num_patches
        # Floor of mask_ratio * num_patches. Rounding to 6 decimals first absorbs
        # binary floating-point error: e.g. 0.29 * 100 == 28.999999999999996,
        # which a bare int() would truncate to 28 instead of 29.
        self.num_mask = int(round(mask_ratio * self.num_patches, 6))

    def forward(self, x):
        """
        Build a random boolean mask for the batch in ``x``.

        Defined as ``forward`` (not an override of ``__call__``) so that
        ``module(x)`` goes through nn.Module's call machinery, including hooks.

        Args:
            x: tensor whose first dimension is the batch size; only
                ``x.size(0)`` and ``x.device`` are read.

        Returns:
            Bool tensor of shape (x.size(0), num_patches) on ``x.device``;
            True marks a masked patch.
        """
        noise = torch.rand(x.size(0), self.num_patches, device=x.device)
        # argsort of i.i.d. uniform noise is a uniformly random permutation of
        # 0..num_patches-1, so each row has exactly num_mask entries below
        # num_mask, at uniformly random positions.
        argsort = torch.argsort(noise, dim=1)
        return argsort < self.num_mask