JasonSmithSO's picture
Upload 777 files
0034848 verified
raw
history blame
3.36 kB
import random
from typing import Any, Dict, Optional, Tuple, Union
import cv2
import numpy as np
from skimage.measure import label
from ...core.transforms_interface import DualTransform, to_tuple
# Public names exported by a star-import of this module.
__all__ = ["MaskDropout"]
class MaskDropout(DualTransform):
    """Zero out image and mask regions belonging to randomly chosen mask objects.

    Connected components of the mask are located with ``skimage.measure.label``.
    A random number of them is selected and the pixels they cover are
    overwritten in both the image and the mask.

    Mask must be single-channel image, zero values treated as background.
    Image can be any number of channels.

    Inspired by https://www.kaggle.com/c/severstal-steel-defect-detection/discussion/114254

    Args:
        max_objects: Maximum number of labels that can be zeroed out.
            Can be tuple, in this case it's [min, max].
        image_fill_value: Fill value to use when filling image.
            Can be 'inpaint' to apply inpainting (works only for 3-channel images).
        mask_fill_value: Fill value to use when filling mask.
        always_apply: Forwarded unchanged to the base transform.
        p: Forwarded unchanged to the base transform.

    Targets:
        image, mask

    Image types:
        uint8, float32
    """

    def __init__(
        self,
        max_objects: Union[int, Tuple[int, int]] = 1,
        image_fill_value: Union[int, float, str] = 0,
        mask_fill_value: Union[int, float] = 0,
        always_apply: bool = False,
        p: float = 0.5,
    ):
        super().__init__(always_apply, p)
        # NOTE(review): ``to_tuple`` is presumed to expand a scalar into a
        # (low, high) pair using 1 as the other bound - confirm against
        # core.transforms_interface.
        self.max_objects = to_tuple(max_objects, 1)
        self.image_fill_value = image_fill_value
        self.mask_fill_value = mask_fill_value

    @property
    def targets_as_params(self):
        """Return the target names read by ``get_params_dependent_on_targets``."""
        return ["mask"]

    def get_params_dependent_on_targets(self, params) -> Dict[str, Any]:
        """Choose which mask objects to drop and build the boolean dropout mask.

        Args:
            params: Dictionary that must contain the ``"mask"`` array.

        Returns:
            The same ``params`` dictionary, updated in place with a
            ``"dropout_mask"`` entry. The entry is a boolean array marking the
            pixels to overwrite, or ``None`` when nothing has to be dropped
            (the mask has no objects, or zero objects were drawn).
        """
        mask = params["mask"]
        label_image, num_labels = label(mask, return_num=True)

        dropout_mask: Optional[np.ndarray] = None
        if num_labels > 0:
            low, high = int(self.max_objects[0]), int(self.max_objects[1])
            objects_to_drop = min(num_labels, random.randint(low, high))

            if objects_to_drop == num_labels:
                # Every object goes. Derive the mask from the labelling rather
                # than from ``mask > 0`` so that the notion of "foreground"
                # matches the one used to count and select objects (any
                # non-zero pixel, including negative values).
                dropout_mask = label_image > 0
            elif objects_to_drop != 0:
                # Labels are numbered 1..num_labels; 0 is the background.
                # A negative count still fails loudly inside random.sample.
                labels_index = random.sample(range(1, num_labels + 1), objects_to_drop)
                dropout_mask = np.isin(label_image, labels_index)
            # objects_to_drop == 0: keep ``None`` so that apply()/apply_to_mask()
            # return their input untouched instead of copying it or running
            # an inpaint pass with an empty mask.

        params.update({"dropout_mask": dropout_mask})
        return params

    def apply(self, img: np.ndarray, dropout_mask: Optional[np.ndarray] = None, **params) -> np.ndarray:
        """Overwrite (or inpaint) the image pixels selected by ``dropout_mask``.

        Args:
            img: Image to process.
            dropout_mask: Boolean mask of pixels to overwrite, or ``None`` to
                leave the image as is.
            **params: Additional transform parameters; unused here.

        Returns:
            ``img`` itself when ``dropout_mask`` is ``None``; otherwise a new
            array with the selected pixels replaced by ``image_fill_value`` or
            reconstructed with ``cv2.inpaint`` when the fill value is
            ``"inpaint"``.
        """
        if dropout_mask is None:
            return img

        if self.image_fill_value == "inpaint":
            # cv2.inpaint expects an 8-bit mask whose non-zero pixels mark the
            # area to reconstruct.
            dropout_mask = dropout_mask.astype(np.uint8)
            _, _, w, h = cv2.boundingRect(dropout_mask)
            # Scale the inpainting radius with the dropped area, capped at 3.
            radius = min(3, max(w, h) // 2)
            return cv2.inpaint(img, dropout_mask, radius, cv2.INPAINT_NS)

        # Work on a copy so that the caller's array is never modified.
        img = img.copy()
        img[dropout_mask] = self.image_fill_value
        return img

    def apply_to_mask(self, img: np.ndarray, dropout_mask: Optional[np.ndarray] = None, **params) -> np.ndarray:
        """Overwrite the mask pixels selected by ``dropout_mask``.

        Args:
            img: Mask to process.
            dropout_mask: Boolean mask of pixels to overwrite, or ``None`` to
                leave the mask as is.
            **params: Additional transform parameters; unused here.

        Returns:
            ``img`` itself when ``dropout_mask`` is ``None``; otherwise a copy
            with the selected pixels set to ``mask_fill_value``.
        """
        if dropout_mask is None:
            return img

        img = img.copy()
        img[dropout_mask] = self.mask_fill_value
        return img

    def get_transform_init_args_names(self) -> Tuple[str, ...]:
        """Return the names of the constructor arguments specific to this transform."""
        return "max_objects", "image_fill_value", "mask_fill_value"