Spaces:
Running
on
Zero
Running
on
Zero
File size: 2,691 Bytes
90a9dd3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 |
import numpy as np
import torch
def random_sq_bbox(img, mask_shape, image_size=256, margin=(16, 16)):
    """Build an inpainting mask containing one randomly placed rectangular hole.

    Args:
        img: 4-D tensor unpacked as (B, C, H, W); only its shape and device
            are used.
        mask_shape: (h, w) size of the rectangular hole.
        image_size: side length used to bound where the hole may be placed.
        margin: (vertical, horizontal) margin kept clear of the hole at the
            image border.

    Returns:
        (mask, top, bottom, left, right), where ``mask`` is a tensor of ones
        of shape (B, C, H, W) with ``mask[..., top:bottom, left:right] == 0``.
    """
    batch, channels, height, width = img.shape
    box_h, box_w = mask_shape
    margin_h, margin_w = margin

    # np.random.randint excludes its upper bound. Rows are drawn before
    # columns so that seeded runs keep producing the same boxes.
    # NOTE(review): bounds come from image_size, not from H/W; if the tensor
    # is not image_size x image_size, the hole may be clipped — confirm callers.
    top = np.random.randint(margin_h, image_size - margin_h - box_h)
    left = np.random.randint(margin_w, image_size - margin_w - box_w)
    bottom, right = top + box_h, left + box_w

    mask = torch.ones([batch, channels, height, width], device=img.device)
    mask[..., top:bottom, left:right] = 0
    return mask, top, bottom, left, right
class MaskGenerator:
    """Produce inpainting masks with the same shape as an input image batch.

    A mask is a tensor of ones except for the masked region, which is zero.
    Supported ``mask_type`` values:

    * ``'box'``     -- one random rectangle is zeroed (same for the whole batch).
    * ``'extreme'`` -- the complement of ``'box'``: only the rectangle is one.
    * ``'random'``  -- a random subset of pixel locations is zeroed; the same
      locations are used for every image in the batch and every channel.
    * ``'both'``    -- accepted by the constructor but not implemented;
      calling the generator raises ``NotImplementedError``.
    """

    _MASK_TYPES = ('box', 'random', 'both', 'extreme')

    def __init__(self, mask_type, mask_len_range=None, mask_prob_range=None,
                 image_size=256, margin=(16, 16)):
        """
        Args:
            mask_type: one of 'box', 'random', 'both', 'extreme'.
            mask_len_range: (min, max) side length of the box in each
                dimension, used by 'box' and 'extreme'. ``max`` is exclusive
                because it is sampled with ``np.random.randint``.
            mask_prob_range: (min, max) range from which the fraction of
                masked pixels is drawn uniformly, used by 'random'.
            image_size: spatial size the masks are generated for. The 'random'
                mask assumes images are ``image_size x image_size``.
            margin: (vertical, horizontal) margin kept clear of the box at
                the image border.

        Raises:
            ValueError: if ``mask_type`` is not a supported value.
        """
        # An explicit raise (not assert) so validation survives `python -O`.
        if mask_type not in self._MASK_TYPES:
            raise ValueError(
                f"mask_type must be one of {self._MASK_TYPES}, got {mask_type!r}")
        self.mask_type = mask_type
        self.mask_len_range = mask_len_range
        self.mask_prob_range = mask_prob_range
        self.image_size = image_size
        self.margin = margin

    def _retrieve_box(self, img):
        """Zero one random box; return (mask, top, bottom, left, right)."""
        low, high = self.mask_len_range
        low, high = int(low), int(high)
        # Height is drawn before width to keep seeded runs reproducible.
        mask_h = np.random.randint(low, high)
        mask_w = np.random.randint(low, high)
        return random_sq_bbox(img,
                              mask_shape=(mask_h, mask_w),
                              image_size=self.image_size,
                              margin=self.margin)

    def _retrieve_random(self, img):
        """Zero a random subset of pixel locations, shared over batch and channels."""
        size = self.image_size
        total = size * size
        low, high = self.mask_prob_range
        prob = np.random.uniform(low, high)
        # Sampling without replacement masks exactly int(total * prob)
        # distinct locations of the flattened size x size grid.
        mask_vec = torch.ones([1, total])
        samples = np.random.choice(total, int(total * prob), replace=False)
        mask_vec[:, samples] = 0
        # A (1, size, size) mask broadcasts over the batch and channel dims,
        # so images with any channel count are supported.
        mask_b = mask_vec.view(1, size, size)
        mask = torch.ones_like(img)
        mask[:, ...] = mask_b
        return mask

    def __call__(self, img):
        """Return a mask tensor with the same shape as ``img``.

        Raises:
            NotImplementedError: for ``mask_type == 'both'``, which has no
                implementation (previously this silently returned ``None``).
        """
        if self.mask_type == 'random':
            return self._retrieve_random(img)
        if self.mask_type == 'box':
            mask, *_ = self._retrieve_box(img)
            return mask
        if self.mask_type == 'extreme':
            mask, *_ = self._retrieve_box(img)
            return 1. - mask
        raise NotImplementedError(
            f"mask_type {self.mask_type!r} is not implemented")
|