import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from featup.adaptive_conv_cuda.adaptive_conv import AdaptiveConv

class SimpleImplicitFeaturizer(torch.nn.Module):
    """Fourier positional encoding: sin/cos of a normalized coordinate grid at
    n_freqs log-spaced frequencies, concatenated with the input channels."""

    def __init__(self, n_freqs=20):
        super().__init__()
        self.n_freqs = n_freqs
        self.dim_multiplier = 2  # two coordinate channels (h and w)

    def forward(self, original_image):
        b, c, h, w = original_image.shape
        grid_h = torch.linspace(-1, 1, h, device=original_image.device)
        grid_w = torch.linspace(-1, 1, w, device=original_image.device)
        feats = torch.cat([t.unsqueeze(0) for t in torch.meshgrid(grid_h, grid_w, indexing="ij")]).unsqueeze(0)
        feats = torch.broadcast_to(feats, (b, feats.shape[1], h, w))

        # Scale the (b, 1, 2, h, w) coordinate grid by n_freqs log-spaced frequencies.
        feats = feats.unsqueeze(1)
        freqs = torch.exp(torch.linspace(-2, 10, self.n_freqs, device=original_image.device)) \
            .reshape(1, self.n_freqs, 1, 1, 1)
        feats = (feats * freqs)

        feats = feats.reshape(b, self.n_freqs * self.dim_multiplier, h, w)

        # Output has 4 * n_freqs + c channels: sin and cos of each scaled coordinate,
        # plus the original image.
        all_feats = [torch.sin(feats), torch.cos(feats), original_image]

        return torch.cat(all_feats, dim=1)

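# Shape sanity check (illustrative sketch, not part of the original API): the
# featurizer maps (B, C, H, W) to (B, 4 * n_freqs + C, H, W), e.g. with the
# default n_freqs=20 an RGB image gains 80 positional channels:
#
#     pe = SimpleImplicitFeaturizer(n_freqs=20)
#     x = torch.randn(2, 3, 16, 16)
#     assert pe(x).shape == (2, 4 * 20 + 3, 16, 16)
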
class IFA(torch.nn.Module):
    """Implicit feature alignment (IFA)-style upsampler: a pointwise MLP over
    the nearest-upsampled features concatenated with Fourier features of the
    low-res/high-res coordinate offsets. Performs a single 2x step."""

    def __init__(self, feat_dim, num_scales=20):
        super().__init__()
        # Float dtype is required: torch.exp is not defined for integer tensors.
        # Note: self.scales is not used in forward.
        self.scales = 2 * torch.exp(torch.arange(1, num_scales + 1, dtype=torch.float32))
        self.feat_dim = feat_dim
        self.sin_feats = SimpleImplicitFeaturizer(n_freqs=num_scales)
        self.mlp = nn.Sequential(
            # Input channels: feat_dim features + 4 * num_scales sin/cos features
            # + 2 raw coordinate-offset channels.
            nn.Conv2d(feat_dim + (num_scales * 4) + 2, feat_dim, 1),
            nn.BatchNorm2d(feat_dim),
            nn.LeakyReLU(),
            nn.Conv2d(feat_dim, feat_dim, 1),
        )

    def forward(self, source, guidance):
        # `guidance` is unused; kept for the shared (source, guidance) interface.
        b, c, h, w = source.shape
        up_source = F.interpolate(source, (h * 2, w * 2), mode="nearest")
        assert h == w, "IFA assumes square feature maps"
        lr_cord = torch.linspace(0, h, steps=h, device=source.device)
        hr_cord = torch.linspace(0, h, steps=2 * h, device=source.device)
        lr_coords = torch.cat([x.unsqueeze(0) for x in torch.meshgrid(lr_cord, lr_cord, indexing="ij")], dim=0).unsqueeze(0)
        hr_coords = torch.cat([x.unsqueeze(0) for x in torch.meshgrid(hr_cord, hr_cord, indexing="ij")], dim=0).unsqueeze(0)
        up_lr_coords = F.interpolate(lr_coords, (h * 2, w * 2), mode="nearest")
        # Offset between each high-res pixel and its nearest low-res sample.
        coord_diff = up_lr_coords - hr_coords
        coord_diff_feats = self.sin_feats(coord_diff)
        c2 = coord_diff_feats.shape[1]
        bcast_coord_feats = torch.broadcast_to(coord_diff_feats, (b, c2, h * 2, w * 2))
        return self.mlp(torch.cat([up_source, bcast_coord_feats], dim=1))

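# Usage sketch (illustrative): IFA performs a single learned 2x step and keeps
# the feature dimension, so unlike the stacked upsamplers below it does not
# reach the guidance resolution on its own:
#
#     ifa = IFA(feat_dim=128)
#     lr_feats = torch.randn(1, 128, 14, 14)
#     img = torch.randn(1, 3, 224, 224)
#     assert ifa(lr_feats, img).shape == (1, 128, 28, 28)
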
class SAPAModule(nn.Module):
    """Similarity-Aware Point Affiliation (SAPA) upsampling: queries come from
    the high-res guidance y, keys from the low-res source x, and the softmaxed
    similarities blend the source values. Requires the `sapa` CUDA extension."""

    def __init__(self, dim_y, dim_x=None,
                 up_factor=2, up_kernel_size=5, embedding_dim=64,
                 qkv_bias=True, norm=nn.LayerNorm):
        super().__init__()
        dim_x = dim_x if dim_x is not None else dim_y

        self.up_factor = up_factor
        self.up_kernel_size = up_kernel_size
        self.embedding_dim = embedding_dim

        self.norm_y = norm(dim_y)
        self.norm_x = norm(dim_x)

        self.q = nn.Linear(dim_y, embedding_dim, bias=qkv_bias)
        self.k = nn.Linear(dim_x, embedding_dim, bias=qkv_bias)

        self.apply(self._init_weights)

    def forward(self, y, x):
        # Work channels-last for LayerNorm and the linear projections.
        y = y.permute(0, 2, 3, 1).contiguous()
        x = x.permute(0, 2, 3, 1).contiguous()
        y = self.norm_y(y)
        x_ = self.norm_x(x)

        q = self.q(y)
        k = self.k(x_)

        return self.attention(q, k, x).permute(0, 3, 1, 2).contiguous()

    def attention(self, q, k, v):
        from sapa import sim, atn  # local import: optional CUDA dependency

        attn = F.softmax(sim(q, k, self.up_kernel_size, self.up_factor), dim=-1)
        return atn(attn, v, self.up_kernel_size, self.up_factor)

    def _init_weights(self, m):
        from timm.models.layers import trunc_normal_

        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            # Kaiming-style init for any conv submodules.
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

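# Semantics of the imported kernels, stated here as an assumption since they
# live in the external `sapa` CUDA extension: sim(q, k, K, s) scores each
# high-res query against the K x K neighborhood of low-res keys it maps to
# under s-times upsampling, and atn(attn, v, K, s) blends the corresponding
# low-res values with those softmaxed weights.
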
class SAPAUpsampler(torch.nn.Module):
    """Four chained 2x SAPA stages (16x total), each queried by the guidance
    image pooled to that stage's target resolution."""

    def __init__(self, dim_x, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.up1 = SAPAModule(dim_x=dim_x, dim_y=3)
        self.up2 = SAPAModule(dim_x=dim_x, dim_y=3)
        self.up3 = SAPAModule(dim_x=dim_x, dim_y=3)
        self.up4 = SAPAModule(dim_x=dim_x, dim_y=3)

    def adapt_guidance(self, source, guidance):
        # Pool the guidance to twice the source resolution for one 2x stage.
        _, _, h, w = source.shape
        small_guidance = F.adaptive_avg_pool2d(guidance, (h * 2, w * 2))
        return small_guidance

    def forward(self, source, guidance):
        source_2 = self.up1(self.adapt_guidance(source, guidance), source)
        source_4 = self.up2(self.adapt_guidance(source_2, guidance), source_2)
        source_8 = self.up3(self.adapt_guidance(source_4, guidance), source_4)
        source_16 = self.up4(self.adapt_guidance(source_8, guidance), source_8)
        return source_16

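# Usage sketch (illustrative; the sim/atn ops require the sapa CUDA package):
#
#     up = SAPAUpsampler(dim_x=384).cuda()
#     lr_feats = torch.randn(1, 384, 14, 14, device="cuda")
#     img = torch.randn(1, 3, 224, 224, device="cuda")
#     assert up(lr_feats, img).shape == (1, 384, 224, 224)
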
class CarafeUpsampler(torch.nn.Module):
    """Four chained 2x CARAFE stages (16x total). CARAFE is content-aware but
    not guided: the guidance image is ignored."""

    def __init__(self, dim, kernel_size, *args, **kwargs):
        super().__init__(*args, **kwargs)
        from mmcv.ops import CARAFEPack  # local import: optional mmcv dependency
        # Note: `kernel_size` is accepted for interface parity; the CARAFE
        # reassembly kernel is fixed at 3 here.
        self.up1 = CARAFEPack(dim, up_kernel=3, up_group=1, scale_factor=2)
        self.up2 = CARAFEPack(dim, up_kernel=3, up_group=1, scale_factor=2)
        self.up3 = CARAFEPack(dim, up_kernel=3, up_group=1, scale_factor=2)
        self.up4 = CARAFEPack(dim, up_kernel=3, up_group=1, scale_factor=2)

    def forward(self, source, guidance):
        # `guidance` is unused; kept for the shared upsampler interface.
        source_2 = self.up1(source)
        source_4 = self.up2(source_2)
        source_8 = self.up3(source_4)
        source_16 = self.up4(source_8)
        return source_16

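# Usage sketch (illustrative; CARAFEPack comes from mmcv and needs its
# compiled ops):
#
#     up = CarafeUpsampler(dim=384, kernel_size=1).cuda()
#     lr_feats = torch.randn(1, 384, 14, 14, device="cuda")
#     img = torch.randn(1, 3, 224, 224, device="cuda")  # ignored by CARAFE
#     assert up(lr_feats, img).shape == (1, 384, 224, 224)
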
class LayeredResizeConv(torch.nn.Module):
    """Four resize-conv stages: bilinear 2x upsampling followed by a residual
    convolution over the upsampled features concatenated with the guidance."""

    def __init__(self, dim, kernel_size, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.conv1 = torch.nn.Conv2d(dim + 3, dim, kernel_size, padding="same")
        self.conv2 = torch.nn.Conv2d(dim + 3, dim, kernel_size, padding="same")
        self.conv3 = torch.nn.Conv2d(dim + 3, dim, kernel_size, padding="same")
        self.conv4 = torch.nn.Conv2d(dim + 3, dim, kernel_size, padding="same")

    def apply_conv(self, source, guidance, conv, activation):
        big_source = F.interpolate(source, scale_factor=2, mode="bilinear")
        _, _, h, w = big_source.shape
        small_guidance = F.interpolate(guidance, (h, w), mode="bilinear")
        output = activation(conv(torch.cat([big_source, small_guidance], dim=1)))
        return big_source + output

    def forward(self, source, guidance):
        source_2 = self.apply_conv(source, guidance, self.conv1, F.relu)
        source_4 = self.apply_conv(source_2, guidance, self.conv2, F.relu)
        source_8 = self.apply_conv(source_4, guidance, self.conv3, F.relu)
        # The last stage is linear so outputs are not clamped to be non-negative.
        source_16 = self.apply_conv(source_8, guidance, self.conv4, lambda x: x)
        return source_16

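# Usage sketch (illustrative; pure PyTorch, runs on CPU). get_upsampler below
# constructs this with kernel_size=1, so each stage is a bilinear 2x resize
# plus a guidance-conditioned residual 1x1 conv:
#
#     up = LayeredResizeConv(dim=384, kernel_size=1)
#     lr_feats = torch.randn(1, 384, 14, 14)
#     img = torch.randn(1, 3, 224, 224)
#     assert up(lr_feats, img).shape == (1, 384, 224, 224)
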
class JBULearnedRange(torch.nn.Module):
    """One learned Joint Bilateral Upsampling stage: a Gaussian spatial kernel
    (learned sigma) is combined with a range kernel computed from a learned
    projection of the guidance, then applied to the bicubic-upsampled source
    with a spatially-varying (adaptive) convolution."""

    def __init__(self, guidance_dim, feat_dim, key_dim, scale=2, radius=3):
        super().__init__()
        self.scale = scale
        self.radius = radius
        self.diameter = self.radius * 2 + 1

        self.guidance_dim = guidance_dim
        self.key_dim = key_dim
        self.feat_dim = feat_dim

        self.range_temp = nn.Parameter(torch.tensor(0.0))
        self.range_proj = torch.nn.Sequential(
            torch.nn.Conv2d(guidance_dim, key_dim, 1, 1),
            torch.nn.GELU(),
            torch.nn.Dropout2d(.1),
            torch.nn.Conv2d(key_dim, key_dim, 1, 1),
        )

        self.fixup_proj = torch.nn.Sequential(
            torch.nn.Conv2d(guidance_dim + self.diameter ** 2, self.diameter ** 2, 1, 1),
            torch.nn.GELU(),
            torch.nn.Dropout2d(.1),
            torch.nn.Conv2d(self.diameter ** 2, self.diameter ** 2, 1, 1),
        )

        self.sigma_spatial = nn.Parameter(torch.tensor(1.0))

    def get_range_kernel(self, x):
        # Softmax similarity between each pixel's key and the keys in its
        # (diameter x diameter) neighborhood, at a learned temperature.
        GB, GC, GH, GW = x.shape
        proj_x = self.range_proj(x)
        proj_x_padded = F.pad(proj_x, pad=[self.radius] * 4, mode='reflect')
        queries = torch.nn.Unfold(self.diameter)(proj_x_padded) \
            .reshape((GB, self.key_dim, self.diameter * self.diameter, GH, GW)) \
            .permute(0, 1, 3, 4, 2)
        pos_temp = self.range_temp.exp().clamp_min(1e-4).clamp_max(1e4)
        return F.softmax(pos_temp * torch.einsum("bchwp,bchw->bphw", queries, proj_x), dim=1)

    def get_spatial_kernel(self, device):
        # Fixed-shape Gaussian over patch offsets with a learned bandwidth.
        dist_range = torch.linspace(-1, 1, self.diameter, device=device)
        x, y = torch.meshgrid(dist_range, dist_range, indexing="ij")
        patch = torch.cat([x.unsqueeze(0), y.unsqueeze(0)], dim=0)
        return torch.exp(- patch.square().sum(0) / (2 * self.sigma_spatial ** 2)) \
            .reshape(1, self.diameter * self.diameter, 1, 1)

    def forward(self, source, guidance):
        GB, GC, GH, GW = guidance.shape
        SB, SC, SH, SW = source.shape
        assert (SB == GB)

        spatial_kernel = self.get_spatial_kernel(source.device)
        range_kernel = self.get_range_kernel(guidance)

        combined_kernel = range_kernel * spatial_kernel
        combined_kernel /= combined_kernel.sum(1, keepdim=True).clamp_min(1e-7)

        # Small learned residual correction to the normalized JBU kernel.
        combined_kernel += .1 * self.fixup_proj(torch.cat([combined_kernel, guidance], dim=1))
        combined_kernel = combined_kernel.permute(0, 2, 3, 1) \
            .reshape(GB, GH, GW, self.diameter, self.diameter)

        hr_source = F.interpolate(source, (GH, GW), mode='bicubic', align_corners=False)
        hr_source_padded = F.pad(hr_source, pad=[self.radius] * 4, mode='reflect')

        # Spatially-varying convolution: a distinct kernel at every output pixel.
        return AdaptiveConv.apply(hr_source_padded, combined_kernel)

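# Ignoring the learned fixup residual, the kernel applied at output pixel p
# over neighborhood offsets o is a joint bilateral filter,
#
#     k(p, o) ∝ softmax_o( T * <q(p + o), q(p)> ) * exp( -|o|^2 / (2 * sigma_spatial^2) )
#
# normalized to sum to 1 over o, where q(.) = range_proj(guidance) and
# T = exp(range_temp). The spatial term is a Gaussian in pixel offset; the
# range term measures guidance similarity, so edges in the guidance sharpen
# the upsampled features.
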
class JBUStack(torch.nn.Module):
    """Four chained 2x learned-JBU stages (16x total), guided by the input
    image pooled to each stage's target resolution, with a small residual
    projection on the final features."""

    def __init__(self, feat_dim, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.up1 = JBULearnedRange(3, feat_dim, 32, radius=3)
        self.up2 = JBULearnedRange(3, feat_dim, 32, radius=3)
        self.up3 = JBULearnedRange(3, feat_dim, 32, radius=3)
        self.up4 = JBULearnedRange(3, feat_dim, 32, radius=3)
        self.fixup_proj = torch.nn.Sequential(
            torch.nn.Dropout2d(0.2),
            torch.nn.Conv2d(feat_dim, feat_dim, kernel_size=1))

    def upsample(self, source, guidance, up):
        # Pool the guidance to twice the source resolution for one 2x stage.
        _, _, h, w = source.shape
        small_guidance = F.adaptive_avg_pool2d(guidance, (h * 2, w * 2))
        upsampled = up(source, small_guidance)
        return upsampled

    def forward(self, source, guidance):
        source_2 = self.upsample(source, guidance, self.up1)
        source_4 = self.upsample(source_2, guidance, self.up2)
        source_8 = self.upsample(source_4, guidance, self.up3)
        source_16 = self.upsample(source_8, guidance, self.up4)
        return self.fixup_proj(source_16) * 0.1 + source_16

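# Usage sketch (illustrative; AdaptiveConv is a compiled CUDA extension, so
# run on GPU):
#
#     up = JBUStack(feat_dim=384).cuda()
#     lr_feats = torch.randn(1, 384, 14, 14, device="cuda")
#     img = torch.randn(1, 3, 224, 224, device="cuda")
#     assert up(lr_feats, img).shape == (1, 384, 224, 224)
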
class Bilinear(torch.nn.Module):
    """Baseline: plain bilinear interpolation of the features up to the
    guidance image's resolution."""

    def forward(self, feats, img):
        _, _, h, w = img.shape
        return F.interpolate(feats, (h, w), mode="bilinear")

def get_upsampler(upsampler, dim):
    if upsampler == 'bilinear':
        return Bilinear()
    elif upsampler == 'jbu_stack':
        return JBUStack(dim)
    elif upsampler == 'resize_conv':
        return LayeredResizeConv(dim, 1)
    elif upsampler == 'carafe':
        return CarafeUpsampler(dim, 1)
    elif upsampler == 'sapa':
        return SAPAUpsampler(dim_x=dim)
    elif upsampler == 'ifa':
        return IFA(dim)
    else:
        raise ValueError(f"Unknown upsampler {upsampler}")

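# Minimal end-to-end sketch (illustrative): the 'bilinear' path is pure
# PyTorch and runs anywhere; 'jbu_stack', 'sapa', and 'carafe' need their
# respective compiled extensions.
if __name__ == "__main__":
    upsampler = get_upsampler('bilinear', dim=384)
    feats = torch.randn(1, 384, 14, 14)
    image = torch.randn(1, 3, 224, 224)
    print(upsampler(feats, image).shape)  # torch.Size([1, 384, 224, 224])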