# Source: VTBench / src/vqvaes/xqgan/diffaug.py (Hugging Face mirror, commit 14ce5a9 by huaweilin)
# this file is taken from https://github.com/autonomousvision/stylegan-t/blob/36ab80ce76237fefe03e65e9b3161c040ae888e3/training/diffaug.py
import math
import torch
import torch.nn.functional as F
def load_png(file_name: str):
    """Load a PNG from disk as a 1xCxHxW float tensor scaled to [-1, 1]."""
    from torchvision.io import read_image

    img = read_image(file_name).float()
    img = img.div_(255).mul_(2).sub_(1)  # [0, 255] -> [-1, 1]
    return img.unsqueeze(0)  # prepend batch dimension
def show(tensor):
    """Display a [-1, 1] image tensor (CHW, 1CHW, or BCHW) in the system viewer."""
    from torchvision.utils import make_grid
    from torchvision.transforms.functional import to_pil_image

    if tensor.shape[0] == 1:
        tensor = tensor[0]  # drop singleton batch dim
    img = tensor.add(1).div_(2).clamp_(0, 1).detach().cpu()  # back to [0, 1]
    if img.ndim != 3:
        img = make_grid(img)  # tile a batch into one image
    to_pil_image(img).convert("RGB").show()
class DiffAug(object):
    """Differentiable augmentation for discriminator inputs (from StyleGAN-T).

    With independent probability ``prob`` each, applies a random integer
    translation, a random color jitter (brightness / saturation / contrast),
    and a random cutout; an optional Gaussian blur is applied during training
    warmup.  All operations are built from differentiable tensor ops.
    """

    def __init__(self, prob=1.0, cutout=0.2):  # todo: swin ratio = 0.5, T&XL = 0.2
        # Cache of meshgrid index tensors, keyed by (batch, height, width).
        self.grids = {}
        self.prob = abs(prob)
        # NOTE(review): the *sign* of ``prob`` toggles cutout — a negative
        # prob keeps the same augmentation probability (via abs) but disables
        # the cutout branch.  Looks deliberate; confirm against callers.
        self.using_cutout = prob > 0
        self.cutout = cutout
        self.img_channels = -1
        # Cached separable blur kernels; rebuilt only when the integer blur
        # radius changes (sigma changes within the same radius reuse them).
        self.last_blur_radius = -1
        self.last_blur_kernel_h = self.last_blur_kernel_w = None

    def get_grids(self, B, x, y, dev):
        """Return cached (batch, row, col) index grids of shape (B, x, y)."""
        if (B, x, y) in self.grids:
            return self.grids[(B, x, y)]
        self.grids[(B, x, y)] = ret = torch.meshgrid(
            torch.arange(B, dtype=torch.long, device=dev),
            torch.arange(x, dtype=torch.long, device=dev),
            torch.arange(y, dtype=torch.long, device=dev),
            indexing="ij",
        )
        return ret

    def aug(self, BCHW: torch.Tensor, warmup_blur_schedule: float = 0) -> torch.Tensor:
        """Augment a batch of images.

        Args:
            BCHW: image batch (cast to float32 if needed; values presumably
                in [-1, 1] — the color jitter assumes a centered range).
            warmup_blur_schedule: scales the warmup blur sigma; 0 disables
                blurring entirely.

        Returns:
            The augmented batch (may be the input tensor itself when no
            augmentation fires).
        """
        # warmup blurring
        if BCHW.dtype != torch.float32:
            BCHW = BCHW.float()
        if warmup_blur_schedule > 0:
            self.img_channels = BCHW.shape[1]
            # Base sigma scales with sqrt of image height; schedule in (0, 1]
            # anneals it toward 0 as warmup ends.
            sigma0 = (BCHW.shape[-2] * 0.5) ** 0.5
            sigma = sigma0 * warmup_blur_schedule
            blur_radius = math.floor(sigma * 3)  # 3-sigma is enough for Gaussian
            if blur_radius >= 1:
                if self.last_blur_radius != blur_radius:
                    self.last_blur_radius = blur_radius
                    gaussian = torch.arange(
                        -blur_radius,
                        blur_radius + 1,
                        dtype=torch.float32,
                        device=BCHW.device,
                    )
                    # Base-2 Gaussian (exp2 rather than exp) — kept as-is from
                    # the upstream StyleGAN-T implementation.
                    gaussian = gaussian.mul_(1 / sigma).square_().neg_().exp2_()
                    gaussian.div_(gaussian.sum())  # normalize
                    # Separable depthwise kernels: a (2r+1)x1 vertical tap and
                    # a 1x(2r+1) horizontal tap, one copy per image channel.
                    self.last_blur_kernel_h = (
                        gaussian.view(1, 1, 2 * blur_radius + 1, 1)
                        .repeat(self.img_channels, 1, 1, 1)
                        .contiguous()
                    )
                    self.last_blur_kernel_w = (
                        gaussian.view(1, 1, 1, 2 * blur_radius + 1)
                        .repeat(self.img_channels, 1, 1, 1)
                        .contiguous()
                    )
                # Reflect-pad once, then blur with two depthwise (groups=C)
                # 1-D passes: vertical, then horizontal.
                BCHW = F.pad(
                    BCHW,
                    [blur_radius, blur_radius, blur_radius, blur_radius],
                    mode="reflect",
                )
                BCHW = F.conv2d(
                    input=BCHW,
                    weight=self.last_blur_kernel_h,
                    bias=None,
                    groups=self.img_channels,
                )
                BCHW = F.conv2d(
                    input=BCHW,
                    weight=self.last_blur_kernel_w,
                    bias=None,
                    groups=self.img_channels,
                )
                # BCHW = filter2d(BCHW, f.div_(f.sum())) # no need to specify padding (filter2d will add padding in itself based on filter size)

        if self.prob < 1e-6:
            return BCHW

        # Decide independently, per batch (not per sample), which of the three
        # augmentations fire this call.
        trans, color, cut = torch.rand(3) <= self.prob
        trans, color, cut = trans.item(), color.item(), cut.item()
        B, dev = BCHW.shape[0], BCHW.device
        # 7 per-sample uniform draws shared across the branches below.
        rand01 = torch.rand(7, B, 1, 1, device=dev) if (trans or color or cut) else None
        raw_h, raw_w = BCHW.shape[-2:]
        if trans:
            # Random integer shift in [-delta, delta] per sample, realized by
            # gathering from a 1-pixel zero-padded copy of the batch.
            ratio = 0.125
            delta_h = round(raw_h * ratio)
            delta_w = round(raw_w * ratio)
            translation_h = (
                rand01[0].mul(delta_h + delta_h + 1).floor().long() - delta_h
            )
            translation_w = (
                rand01[1].mul(delta_w + delta_w + 1).floor().long() - delta_w
            )
            # translation_h = torch.randint(-delta_h, delta_h+1, size=(B, 1, 1), device=dev)
            # translation_w = torch.randint(-delta_w, delta_w+1, size=(B, 1, 1), device=dev)
            grid_B, grid_h, grid_w = self.get_grids(B, raw_h, raw_w, dev)
            # +1 shifts indices into the padded grid; clamping to [0, H+1]
            # makes out-of-range positions read the zero border.
            grid_h = (grid_h + translation_h).add_(1).clamp_(0, raw_h + 1)
            grid_w = (grid_w + translation_w).add_(1).clamp_(0, raw_w + 1)
            bchw_pad = F.pad(BCHW, [1, 1, 1, 1, 0, 0, 0, 0])
            BCHW = (
                bchw_pad.permute(0, 2, 3, 1)
                .contiguous()[grid_B, grid_h, grid_w]
                .permute(0, 3, 1, 2)
                .contiguous()
            )
        if color:
            # Brightness: add a per-sample offset in [-0.5, 0.5).
            BCHW = BCHW.add(rand01[2].unsqueeze(-1).sub(0.5))
            # BCHW.add_(torch.rand(B, 1, 1, 1, dtype=BCHW.dtype, device=dev).sub_(0.5))
            # Saturation: rescale deviation from the per-pixel channel mean.
            bchw_mean = BCHW.mean(dim=1, keepdim=True)
            BCHW = (
                BCHW.sub(bchw_mean).mul(rand01[3].unsqueeze(-1).mul(2)).add_(bchw_mean)
            )
            # BCHW.sub_(bchw_mean).mul_(torch.rand(B, 1, 1, 1, dtype=BCHW.dtype, device=dev).mul_(2)).add_(bchw_mean)
            # Contrast: rescale deviation from the per-sample global mean.
            bchw_mean = BCHW.mean(dim=(1, 2, 3), keepdim=True)
            BCHW = (
                BCHW.sub(bchw_mean)
                .mul(rand01[4].unsqueeze(-1).add(0.5))
                .add_(bchw_mean)
            )
            # BCHW.sub_(bchw_mean).mul_(torch.rand(B, 1, 1, 1, dtype=BCHW.dtype, device=dev).add_(0.5)).add_(bchw_mean)
        if self.using_cutout and cut:
            # Zero a random cutout_h x cutout_w window per sample; the window
            # center is clamped so it never indexes outside the image.
            ratio = self.cutout  # todo: styleswin ratio = 0.5, T&XL = 0.2
            cutout_h = round(raw_h * ratio)
            cutout_w = round(raw_w * ratio)
            offset_h = rand01[5].mul(raw_h + (1 - cutout_h % 2)).floor().long()
            offset_w = rand01[6].mul(raw_w + (1 - cutout_w % 2)).floor().long()
            # offset_h = torch.randint(0, raw_h + (1 - cutout_h % 2), size=(B, 1, 1), device=dev)
            # offset_w = torch.randint(0, raw_w + (1 - cutout_w % 2), size=(B, 1, 1), device=dev)
            grid_B, grid_h, grid_w = self.get_grids(B, cutout_h, cutout_w, dev)
            grid_h = (grid_h + offset_h).sub_(cutout_h // 2).clamp(min=0, max=raw_h - 1)
            grid_w = (grid_w + offset_w).sub_(cutout_w // 2).clamp(min=0, max=raw_w - 1)
            mask = torch.ones(B, raw_h, raw_w, dtype=BCHW.dtype, device=dev)
            mask[grid_B, grid_h, grid_w] = 0
            BCHW = BCHW.mul(mask.unsqueeze(1))
        return BCHW