|
import random |
|
import numpy as np |
|
|
|
import torch |
|
import torchvision.transforms as T |
|
|
|
|
|
class RandomRotate90(): |
|
def __call__(self, x): |
|
x = x.rot90(random.randint(0, 3), dims=(-1, -2)) |
|
return x |
|
|
|
def __repr__(self): |
|
return self.__class__.__name__ |
|
|
|
|
|
class AddGaussianNoise(): |
|
def __init__(self, std=0.01): |
|
self.std = std |
|
|
|
def __call__(self, x): |
|
|
|
|
|
|
|
|
|
|
|
return x + torch.randn_like(x) * self.std |
|
|
|
def __repr__(self): |
|
return self.__class__.__name__ + f'(std={self.std})' |
|
|
|
|
|
def set_global_seed(seed): |
|
torch.random.manual_seed(seed) |
|
np.random.seed(seed % (2**32 - 1)) |
|
random.seed(seed) |
|
|
|
|
|
class ComposeState(T.Compose): |
|
def __init__(self, transforms): |
|
self.transforms = [] |
|
self.mask_transforms = [] |
|
|
|
for t in transforms: |
|
apply_for_mask = True |
|
if isinstance(t, tuple): |
|
t, apply_for_mask = t |
|
self.transforms.append(t) |
|
if apply_for_mask: |
|
self.mask_transforms.append(t) |
|
|
|
self.seed = None |
|
|
|
|
|
def __call__(self, x, retain_state=False, mask_transform=False): |
|
if self.seed is not None: |
|
set_global_seed(self.seed) |
|
if retain_state: |
|
self.seed = self.seed or torch.seed() |
|
set_global_seed(self.seed) |
|
else: |
|
self.seed = None |
|
|
|
transforms = self.transforms if not mask_transform else self.mask_transforms |
|
for t in transforms: |
|
x = t(x) |
|
return x |
|
|
|
|
|
augmentation_weak = ComposeState([ |
|
T.RandomHorizontalFlip(), |
|
T.RandomVerticalFlip(), |
|
RandomRotate90(), |
|
]) |
|
|
|
|
|
augmentation_strong = ComposeState([ |
|
T.RandomHorizontalFlip(p=0.5), |
|
T.RandomVerticalFlip(p=0.5), |
|
T.RandomApply([T.RandomRotation(90)], p=0.5), |
|
|
|
(T.RandomApply([AddGaussianNoise(std=0.0005)], p=0.5), False), |
|
(T.RandomAdjustSharpness(0.5, p=0.5), False), |
|
]) |
|
|
|
|
|
def get_augmentation(type): |
|
if type == 'none': |
|
return None |
|
if type == 'weak': |
|
return augmentation_weak |
|
if type == 'strong': |
|
return augmentation_strong |
|
|
|
|
|
if __name__ == '__main__': |
|
import os |
|
if not os.path.exists('README.md'): |
|
os.chdir('..') |
|
|
|
|
|
from dataset import get_dataset |
|
import matplotlib.pyplot as plt |
|
|
|
dataset = get_dataset('DS') |
|
img, mask = dataset[10] |
|
mask = (mask + 0.2) / 1.2 |
|
|
|
plt.figure(figsize=(14, 8)) |
|
plt.subplot(121) |
|
plt.imshow(img) |
|
plt.subplot(122) |
|
plt.imshow(mask) |
|
plt.suptitle('no augmentation') |
|
plt.show() |
|
|
|
from utils.base import np2torch, torch2np |
|
img, mask = np2torch(img), np2torch(mask) |
|
|
|
|
|
augmentation = get_augmentation('strong') |
|
|
|
set_global_seed(1) |
|
|
|
for i in range(1, 4): |
|
plt.figure(figsize=(14, 8)) |
|
plt.subplot(121) |
|
plt.imshow(torch2np(augmentation(img.unsqueeze(0), retain_state=True)).squeeze()) |
|
plt.subplot(122) |
|
plt.imshow(torch2np(augmentation(mask.unsqueeze(0), mask_transform=True)).squeeze()) |
|
plt.suptitle(f'augmentation test {i}') |
|
plt.show() |
|
|