|
""" |
|
Adapted from https://github.com/nv-nguyen/template-pose/blob/main/src/utils/augmentation.py |
|
""" |
|
|
|
from torchvision import transforms |
|
from PIL import ImageEnhance, ImageFilter, Image |
|
import numpy as np |
|
import random |
|
import logging |
|
from torchvision.transforms import RandomResizedCrop, ToTensor |
|
|
|
|
|
class PillowRGBAugmentation:
    """Randomly apply an ``ImageEnhance``-style enhancer to an image.

    On each call, with probability ``p``, the wrapped enhancer is built on
    the input image and applied with a factor drawn uniformly from
    ``factor_interval``. Inputs whose mode is not ``"RGB"`` are converted
    first (a warning is logged); the enhanced result is converted to
    ``"RGB"`` before being returned.
    """

    def __init__(self, pillow_fn, p, factor_interval):
        """Store the enhancer factory and the sampling parameters.

        Args:
            pillow_fn: Callable that takes an image and returns an object
                exposing ``enhance(factor=...)`` (e.g. ``ImageEnhance.Contrast``).
            p: Probability of applying the enhancement on a given call.
            factor_interval: ``(low, high)`` bounds passed to ``random.uniform``.
        """
        self._pillow_fn = pillow_fn
        self.p = p
        self.factor_interval = factor_interval

    def __call__(self, PIL_image):
        """Return ``PIL_image``, enhanced with probability ``p``.

        Args:
            PIL_image: Image-like object with a ``mode`` attribute and a
                ``convert`` method.

        Returns:
            The untouched input when the augmentation is skipped, otherwise
            the enhanced image in ``"RGB"`` mode.
        """
        if random.random() <= self.p:
            factor = random.uniform(*self.factor_interval)
            if PIL_image.mode != "RGB":
                logging.warning(
                    f"Error when apply data aug, image mode: {PIL_image.mode}"
                )
                # Convert the image we were actually given; the previous code
                # referenced an undefined name here and crashed on any
                # non-RGB input.
                PIL_image = PIL_image.convert("RGB")
                logging.warning(f"Success to change to {PIL_image.mode}")
            enhancer = self._pillow_fn(PIL_image)
            PIL_image = enhancer.enhance(factor=factor).convert("RGB")
        return PIL_image
|
|
|
|
|
class PillowSharpness(PillowRGBAugmentation):
    """Random sharpness change backed by ``ImageEnhance.Sharpness``.

    By default it fires with probability 0.3 and samples its factor
    from the interval ``(0, 40.0)``.
    """

    def __init__(self, p=0.3, factor_interval=(0, 40.0)):
        """Forward ``p`` and ``factor_interval`` to the base augmentation."""
        super().__init__(ImageEnhance.Sharpness, p, factor_interval)
|
|
|
|
|
class PillowContrast(PillowRGBAugmentation):
    """Random contrast change backed by ``ImageEnhance.Contrast``.

    By default it fires with probability 0.3 and samples its factor
    from the interval ``(0.5, 1.6)``.
    """

    def __init__(self, p=0.3, factor_interval=(0.5, 1.6)):
        """Forward ``p`` and ``factor_interval`` to the base augmentation."""
        super().__init__(ImageEnhance.Contrast, p, factor_interval)
|
|
|
|
|
class PillowBrightness(PillowRGBAugmentation):
    """Random brightness change backed by ``ImageEnhance.Brightness``.

    By default it fires with probability 0.5 and samples its factor
    from the interval ``(0.5, 2.0)``.
    """

    def __init__(self, p=0.5, factor_interval=(0.5, 2.0)):
        """Forward ``p`` and ``factor_interval`` to the base augmentation."""
        super().__init__(ImageEnhance.Brightness, p, factor_interval)
|
|
|
|
|
class PillowColor(PillowRGBAugmentation):
    """Random color-balance change backed by ``ImageEnhance.Color``.

    By default it fires with probability 1 and samples its factor
    from the interval ``(0.0, 20.0)``.
    """

    def __init__(self, p=1, factor_interval=(0.0, 20.0)):
        """Forward ``p`` and ``factor_interval`` to the base augmentation."""
        super().__init__(ImageEnhance.Color, p, factor_interval)
|
|
|
|
|
class PillowBlur:
    """Gaussian blur applied with probability ``p``.

    The blur radius ``k`` is an integer drawn once, when the object is
    constructed, from the inclusive range ``factor_interval``; every call
    that applies the blur reuses that same radius.
    """

    def __init__(self, p=0.4, factor_interval=(1, 3)):
        """Record the probability and draw the fixed blur radius.

        Args:
            p: Probability of blurring on a given call.
            factor_interval: ``(low, high)`` integer bounds for the radius.
        """
        self.p = p
        radius = random.randint(*factor_interval)
        self.k = radius

    def __call__(self, PIL_image):
        """Return ``PIL_image`` blurred with radius ``k``, or unchanged."""
        should_blur = random.random() <= self.p
        if not should_blur:
            return PIL_image
        gaussian = ImageFilter.GaussianBlur(self.k)
        return PIL_image.filter(gaussian)
|
|
|
|
|
class NumpyGaussianNoise:
    """Additive Gaussian pixel noise applied with probability ``p``.

    An upper bound ``noise_ratio`` for the noise level is drawn once at
    construction from ``factor_interval``. The return value is always a
    PIL image built from a ``uint8`` array, whether or not noise was added.
    """

    def __init__(self, p, factor_interval=(0.01, 0.3)):
        """Draw the noise upper bound and record the probability.

        Args:
            p: Probability of adding noise on a given call.
            factor_interval: ``(low, high)`` bounds for ``noise_ratio``.
        """
        self.noise_ratio = random.uniform(*factor_interval)
        self.p = p

    def __call__(self, img):
        """Return ``img`` as a PIL image, optionally with noise added."""
        add_noise = random.random() <= self.p
        if add_noise:
            pixels = np.copy(img)
            # Per-call standard deviation, expressed as a fraction of 255.
            sigma = random.uniform(0, self.noise_ratio)
            noise = 255 * np.random.normal(0, sigma, pixels.shape)
            # Clamp so the uint8 cast below cannot wrap around.
            img = np.clip(pixels + noise, 0, 255)
        as_bytes = np.uint8(img)
        return Image.fromarray(as_bytes)
|
|
|
|
|
class StandardAugmentation:
    """Run a comma-separated selection of photometric augmentations in order.

    Each constructor argument after ``names`` is a callable taking an image
    and returning an image. ``names`` picks which of them run, and in what
    order, using the keys ``brightness``, ``contrast``, ``sharpness``,
    ``color``, ``blur`` and ``gaussian_noise``.
    """

    def __init__(
        self, names, brightness, contrast, sharpness, color, blur, gaussian_noise
    ):
        """Register the callables and parse the ordered list of names."""
        registry = {
            "brightness": brightness,
            "contrast": contrast,
            "sharpness": sharpness,
            "color": color,
            "blur": blur,
            "gaussian_noise": gaussian_noise,
        }
        # Each callable is also exposed as an attribute named after its key.
        for key, fn in registry.items():
            setattr(self, key, fn)

        self.names = names.split(",")
        self.augmentations = registry

    def __call__(self, img):
        """Apply the selected augmentations to ``img`` one after another."""
        result = img
        for key in self.names:
            step = self.augmentations[key]
            result = step(result)
        return result
|
|
|
|
|
class GeometricAugmentation:
    """Run a comma-separated selection of geometric augmentations in order.

    Each constructor argument after ``names`` is a callable taking an image
    and returning an image. ``names`` picks which of them run, and in what
    order, using the keys ``random_resized_crop``, ``random_horizontal_flip``,
    ``random_vertical_flip`` and ``random_rotation``.
    """

    def __init__(
        self,
        names,
        random_resized_crop,
        random_horizontal_flip,
        random_vertical_flip,
        random_rotation,
    ):
        """Register the callables and parse the ordered list of names."""
        registry = {
            "random_resized_crop": random_resized_crop,
            "random_horizontal_flip": random_horizontal_flip,
            "random_vertical_flip": random_vertical_flip,
            "random_rotation": random_rotation,
        }
        # Each callable is also exposed as an attribute named after its key.
        for key, fn in registry.items():
            setattr(self, key, fn)

        self.names = names.split(",")
        self.augmentations = registry

    def __call__(self, img):
        """Apply the selected augmentations to ``img`` one after another."""
        result = img
        for key in self.names:
            step = self.augmentations[key]
            result = step(result)
        return result
|
|
|
|
|
class ImageAugmentation:
    """Top-level pipeline chaining the named transform stages in order.

    ``names`` is a comma-separated list drawn from ``clip_transform``,
    ``standard_augmentation`` and ``geometric_augmentation``; each stage is
    a callable taking an image and returning the transformed image.
    """

    def __init__(
        self, names, clip_transform, standard_augmentation, geometric_augmentation
    ):
        """Register the stages, parse ``names`` and print the selection."""
        stages = {
            "clip_transform": clip_transform,
            "standard_augmentation": standard_augmentation,
            "geometric_augmentation": geometric_augmentation,
        }
        # Each stage is also exposed as an attribute named after its key.
        for key, stage in stages.items():
            setattr(self, key, stage)

        self.names = names.split(",")
        self.transforms = stages
        print(f"Image augmentation: {self.names}")

    def __call__(self, img):
        """Pass ``img`` through every selected stage, in order."""
        result = img
        for key in self.names:
            stage = self.transforms[key]
            result = stage(result)
        return result
|
|
|
|
|
if __name__ == "__main__":
    # Visual smoke test: instantiate the augmentation pipeline from its
    # config, apply it repeatedly to a handful of images and save grids in
    # which each row is one source image followed by its augmented variants.
    import glob

    import torch
    from hydra.utils import instantiate
    from omegaconf import OmegaConf
    from torchvision.utils import save_image

    augmentation_config = OmegaConf.load(
        "./configs/dataset/train_transform/augmentation.yaml"
    )
    augmentation_config.names = "standard_augmentation,geometric_augmentation"
    augmentation_transform = instantiate(augmentation_config)

    num_try = 20
    num_try_per_image = 8
    num_imgs = 8

    # Sort so the same images are picked on every run, and slice instead of
    # indexing so that fewer than ``num_imgs`` matches does not raise.
    img_paths = sorted(glob.glob("./datasets/osv5m/test/images/*.jpg"))[:num_imgs]
    if not img_paths:
        raise FileNotFoundError(
            "No .jpg images found under ./datasets/osv5m/test/images/"
        )

    to_tensor = ToTensor()
    for idx in range(num_try):
        imgs = []
        for img_path in img_paths:
            # Load eagerly and close the file handle; force RGB so every
            # tensor has the same channel count when stacked.
            with Image.open(img_path) as img_file:
                img = img_file.convert("RGB")
            # First column of the row: the un-augmented reference image.
            imgs.append(to_tensor(img.resize((224, 224))))
            for _ in range(num_try_per_image):
                imgs.append(to_tensor(augmentation_transform(img.copy())))
        save_image(
            torch.stack(imgs),
            f"augmentation_{idx:03d}.png",
            nrow=num_try_per_image + 1,
        )
|
|