raw2logit / utils /augmentation.py
willis
reorganize
0220054
import random
import numpy as np
import torch
import torchvision.transforms as T
class RandomRotate90():  # Note: not the same as T.RandomRotation(90)
    """Rotate the input by a random multiple of 90 degrees.

    The number of quarter turns is drawn from {0, 1, 2, 3} with Python's
    ``random`` module, and the rotation acts on the last two dimensions
    of the input.
    """

    def __call__(self, x):
        """Return ``x`` rotated by a randomly chosen number of quarter turns."""
        quarter_turns = random.randint(0, 3)
        rotated = x.rot90(quarter_turns, dims=(-1, -2))
        return rotated

    def __repr__(self):
        return type(self).__name__
class AddGaussianNoise():
    """Add Gaussian noise scaled by ``std`` to the input.

    Args:
        std: Factor applied to the standard-normal sample before it is
            added, i.e. the standard deviation of the added noise.
    """

    def __init__(self, std=0.01):
        self.std = std

    def __call__(self, x):
        """Return ``x`` plus freshly sampled noise of the same shape as ``x``."""
        noise = torch.randn_like(x) * self.std
        return x + noise

    def __repr__(self):
        return f'{type(self).__name__}(std={self.std})'
def set_global_seed(seed):
    """Seed the global RNGs of PyTorch, NumPy and Python's ``random``.

    PyTorch and ``random`` receive ``seed`` unchanged. NumPy receives
    ``seed % (2**32 - 1)`` so that large seeds are folded into the range
    ``np.random.seed`` accepts.
    """
    torch.manual_seed(seed)
    numpy_seed = seed % (2**32 - 1)
    np.random.seed(numpy_seed)
    random.seed(seed)
class ComposeState(T.Compose):
    """Compose transforms and optionally replay the same randomness.

    Each entry of ``transforms`` is either a transform or a
    ``(transform, apply_for_mask)`` tuple. Bare transforms are treated as
    ``apply_for_mask=True``. Transforms whose flag is false are left out of
    the pipeline used when ``mask_transform=True``.

    Typical use is a pair of calls: first the image with
    ``retain_state=True``, then the mask with ``mask_transform=True``.
    The second call re-seeds the global RNGs with the seed stored by the
    first, so random decisions are replayed, and then clears the seed.

    Note: seeding goes through ``set_global_seed``, so every call that uses
    a stored seed resets the global ``torch``, ``numpy`` and ``random``
    generators. A new seed is obtained from ``torch.seed()``.
    """

    def __init__(self, transforms):
        # NOTE(review): T.Compose.__init__ is not invoked; the attributes
        # this class uses are set here directly. Confirm that is intended.
        self.transforms = []
        self.mask_transforms = []
        for entry in transforms:
            if isinstance(entry, tuple):
                transform, apply_for_mask = entry
            else:
                transform, apply_for_mask = entry, True
            self.transforms.append(transform)
            if apply_for_mask:
                self.mask_transforms.append(transform)
        # Seed kept between calls; None means "no state to replay".
        self.seed = None

    def __call__(self, x, retain_state=False, mask_transform=False):
        """Apply the pipeline to ``x``.

        Args:
            x: Input passed through each transform in turn.
            retain_state: If true, keep (or create) the stored seed so the
                next call replays the same random decisions.
            mask_transform: If true, run only the transforms flagged for
                masks.

        Returns:
            The output of the last transform.
        """
        if retain_state:
            # Compare against None explicitly: a stored seed of 0 is valid
            # and must not be replaced by a fresh one.
            if self.seed is None:
                self.seed = torch.seed()
            set_global_seed(self.seed)
        else:
            # Replay a previously stored seed once, then forget it.
            if self.seed is not None:
                set_global_seed(self.seed)
            self.seed = None

        pipeline = self.mask_transforms if mask_transform else self.transforms
        for transform in pipeline:
            x = transform(x)
        return x
# Weak preset: flips and a random quarter-turn rotation only. All entries are
# bare transforms, so ComposeState also applies every one of them to masks.
augmentation_weak = ComposeState([
    T.RandomHorizontalFlip(p=0.5),
    T.RandomVerticalFlip(p=0.5),
    RandomRotate90(),
])
# Strong preset: flips and a random rotation, followed by noise and sharpness
# changes. Entries written as ``(transform, apply_for_mask)`` with a False
# flag are skipped when the pipeline is called with ``mask_transform=True``.
augmentation_strong = ComposeState([
    T.RandomHorizontalFlip(p=0.5),
    T.RandomVerticalFlip(p=0.5),
    T.RandomApply(transforms=[T.RandomRotation(degrees=90)], p=0.5),
    (T.RandomApply(transforms=[AddGaussianNoise(std=0.0005)], p=0.5), False),
    (T.RandomAdjustSharpness(sharpness_factor=0.5, p=0.5), False),
])
def get_augmentation(type):
    """Return the augmentation pipeline registered under ``type``.

    ``'weak'`` and ``'strong'`` map to the module-level presets. ``'none'``
    and any other value yield ``None``.
    """
    # The parameter name shadows the builtin ``type``; it is kept so that
    # existing keyword callers continue to work.
    if type == 'weak':
        selected = augmentation_weak
    elif type == 'strong':
        selected = augmentation_strong
    else:
        # Covers 'none' as well as unrecognised names.
        selected = None
    return selected
if __name__ == '__main__':
    # Visual check: show one sample unaugmented, then three rounds of the
    # 'strong' augmentation applied to the image and its mask.
    import os

    # Step up one directory when README.md is not in the current one
    # (presumably so the script can be launched from utils/ -- confirm).
    if not os.path.exists('README.md'):
        os.chdir('..')

    from dataset import get_dataset
    import matplotlib.pyplot as plt

    def _show_side_by_side(left, right, title):
        """Draw ``left`` and ``right`` next to each other in a new figure."""
        plt.figure(figsize=(14, 8))
        for position, picture in ((121, left), (122, right)):
            plt.subplot(position)
            plt.imshow(picture)
        plt.suptitle(title)
        plt.show()

    dataset = get_dataset('DS')  # drone segmentation
    img, mask = dataset[10]
    # Rescale the mask values; the rescaled mask is also the one augmented below.
    mask = (mask + 0.2) / 1.2
    _show_side_by_side(img, mask, 'no augmentation')

    from utils.base import np2torch, torch2np

    img, mask = np2torch(img), np2torch(mask)
    augmentation = get_augmentation('strong')
    set_global_seed(1)
    for round_idx in range(1, 4):
        # The image call stores a seed; the mask call replays it and clears
        # it, so both go through the same random decisions.
        aug_img = torch2np(augmentation(img.unsqueeze(0), retain_state=True)).squeeze()
        aug_mask = torch2np(augmentation(mask.unsqueeze(0), mask_transform=True)).squeeze()
        _show_side_by_side(aug_img, aug_mask, f'augmentation test {round_idx}')