Spaces:
Sleeping
Sleeping
import torch | |
from torch import nn | |
import kornia.augmentation as K | |
class ImageAugmentations(nn.Module):
    """Expand an image/mask pair into a batch of identically augmented copies.

    The (optionally resized) inputs are tiled ``augmentations_number`` times
    along the batch dimension. Each augmentation samples its random parameters
    once and those same parameters are applied to both the images and the
    masks, so every mask stays aligned with its image.

    Args:
        output_size: side length of the square output when ``resize`` is True.
        augmentations_number: number of copies of the input batch to produce
            (including the un-augmented copy when ``with_orig=True``).
        p: probability passed to each augmentation.
        resize: if True, inputs are adaptively average-pooled to
            ``(output_size, output_size)``; otherwise they pass through
            unchanged.
    """

    def __init__(self, output_size, augmentations_number, p=0.7, resize=True):
        super().__init__()
        self.output_size = output_size
        self.augmentations_number = augmentations_number

        # ModuleList (rather than a bare list) registers the transforms as
        # submodules, so they are visible to .to(), .modules(), repr(), etc.
        self.augmentations = nn.ModuleList(
            [
                K.RandomAffine(degrees=15, translate=0.1, p=p, padding_mode="border"),
                # K.RandomPerspective(0.7, p=p),
            ]
        )

        # nn.Identity instead of a lambda keeps the module picklable.
        self.resize = (
            nn.AdaptiveAvgPool2d((self.output_size, self.output_size))
            if resize
            else nn.Identity()
        )

    def _augment_pair(self, images, masks):
        """Apply every augmentation to images and masks with shared parameters."""
        # Nothing to augment (e.g. augmentations_number == 1 with the original
        # kept): do not push a zero-sized batch through the transforms.
        if images.shape[0] == 0:
            return images, masks

        for trans in self.augmentations:
            # Sample once, then reuse for the masks so both get the same warp.
            trans_params = trans.forward_parameters(images.shape)
            images = trans(images, trans_params)
            masks = trans(masks, trans_params)
        return images, masks

    def forward(self, image, mask, with_orig=True):
        """Extend the image and mask with identical augmentations.

        If the input consists of image I and mask M, the output is:
            [I_aug1, I_aug2, I_aug3, ...], [M_aug1, M_aug2, M_aug3, ...]
        If ``with_orig=True``, the first entry is left un-augmented:
            [I, I_aug1, I_aug2, ...], [M, M_aug1, M_aug2, ...]

        Args:
            image: tensor of shape [B, C, H, W] (typically B == 1).
            mask: tensor of shape [B, 1, H, W]; its batch size is assumed to
                match ``image``, since ``image.shape[0]`` is used to split both.
            with_orig: if True, the first B returned images and masks are the
                (resized) un-augmented inputs.

        Returns:
            Tuple ``(images, masks)`` with shapes
            ``[augmentations_number * B, C, H', W']`` and
            ``[augmentations_number * B, 1, H', W']``, where ``H', W'`` equal
            ``output_size`` when resizing is enabled and ``H, W`` otherwise.
        """
        # Duplicate the inputs, in contrast to regular augmentations that do
        # not change the number of samples.
        tiled_images = self.resize(image).repeat(self.augmentations_number, 1, 1, 1)
        tiled_masks = self.resize(mask).repeat(self.augmentations_number, 1, 1, 1)

        if not with_orig:
            return self._augment_pair(tiled_images, tiled_masks)

        # Keep the first copy of the batch untouched and augment the rest.
        batch_size = image.shape[0]
        aug_images, aug_masks = self._augment_pair(
            tiled_images[batch_size:], tiled_masks[batch_size:]
        )
        updated_input_batch = torch.cat([tiled_images[:batch_size], aug_images], dim=0)
        updated_mask_batch = torch.cat([tiled_masks[:batch_size], aug_masks], dim=0)
        return updated_input_batch, updated_mask_batch