# featup/featup/layers.py
# Commit 7174f3a by abreza: "add featup codes" (3.19 kB).
import torch
def id_conv(dim, strength=.9):
    """Build a 1x1 convolution initialised close to the identity mapping.

    The kernel is a blend of the identity matrix and the layer's default
    random initialisation, ``strength * I + (1 - strength) * W_init``, and
    the bias is the default-initialised bias scaled by ``1 - strength``.

    Args:
        dim: Number of input channels, which is also the number of output
            channels.
        strength: Blend weight of the identity component. ``1`` yields an
            exact identity kernel with a zero bias; ``0`` keeps the default
            initialisation unchanged.

    Returns:
        The ``torch.nn.Conv2d(dim, dim, 1, padding="same")`` layer with its
        weight and bias re-initialised.
    """
    conv = torch.nn.Conv2d(dim, dim, 1, padding="same")
    # Re-initialise the existing parameters in place instead of wrapping new
    # tensors in ``Parameter`` and assigning them to ``.data``; ``no_grad``
    # keeps the initialisation out of the autograd graph.
    with torch.no_grad():
        weight = conv.weight
        # (dim, dim) identity expanded to the (dim, dim, 1, 1) kernel layout.
        identity = torch.eye(dim, device=weight.device).unsqueeze(-1).unsqueeze(-1)
        weight.copy_(identity * strength + weight * (1 - strength))
        conv.bias.mul_(1 - strength)
    return conv
class ImplicitFeaturizer(torch.nn.Module):
    """Encode pixel coordinates (and optionally colours) with sin/cos features.

    For every pixel the module takes its normalised ``(row, col)`` position in
    ``[-1, 1]`` and, when ``color_feats`` is set, the image channels. Each of
    these values is multiplied by ``n_freqs`` frequencies spaced
    logarithmically between ``exp(-2)`` and ``exp(10)``, and the sine and
    cosine of every product are returned as channels.

    Args:
        color_feats: Also encode the image channels and append the raw image
            to the output.
        n_freqs: Number of frequencies applied to every encoded value.
        learn_bias: Add a learnable phase offset (separately for the sine and
            the cosine branch) before the trigonometric functions.
        time_feats: Reserves one extra encoded channel in ``dim_multiplier``.
            NOTE(review): ``forward`` never adds a matching channel, so with
            this flag enabled the shapes in ``forward`` cannot line up for
            non-empty inputs - confirm the intended behaviour before use.
        *args, **kwargs: Forwarded to ``torch.nn.Module.__init__``.
    """

    def __init__(self, color_feats=True, n_freqs=10, learn_bias=False, time_feats=False, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.color_feats = color_feats
        self.time_feats = time_feats
        self.n_freqs = n_freqs
        self.learn_bias = learn_bias

        # Number of per-pixel values that get frequency-encoded:
        # 2 grid coordinates, plus 3 when colours are encoded (this assumes a
        # 3-channel image; other channel counts break the reshape in forward).
        self.dim_multiplier = 2
        if self.color_feats:
            self.dim_multiplier += 3
        if self.time_feats:
            self.dim_multiplier += 1

        if self.learn_bias:
            # Index 0 holds the sine-branch offsets, index 1 the cosine ones.
            self.biases = torch.nn.Parameter(torch.randn(2, self.dim_multiplier, n_freqs).to(torch.float32))

    def forward(self, original_image):
        """Return the positional (and colour) encoding of ``original_image``.

        Args:
            original_image: 4-D tensor unpacked as ``(b, c, h, w)``.

        Returns:
            Tensor of shape ``(b, 2 * n_freqs * dim_multiplier, h, w)``: all
            sine channels followed by all cosine channels. When
            ``color_feats`` is set, ``original_image`` is appended along the
            channel axis as well.
        """
        b, _, h, w = original_image.shape
        grid_h = torch.linspace(-1, 1, h, device=original_image.device)
        grid_w = torch.linspace(-1, 1, w, device=original_image.device)
        # Pin "ij" indexing explicitly: it is the behaviour this code relies
        # on (channel 0 varies along rows, channel 1 along columns), and
        # omitting the argument warns because the default is slated to change.
        grid = torch.stack(torch.meshgrid(grid_h, grid_w, indexing="ij")).unsqueeze(0)
        feats = torch.broadcast_to(grid, (b, grid.shape[1], h, w))

        if self.color_feats:
            feat_list = [feats, original_image]
        else:
            feat_list = [feats]

        # (b, 1, C, h, w) * (1, n_freqs, 1, 1, 1) -> (b, n_freqs, C, h, w)
        feats = torch.cat(feat_list, dim=1).unsqueeze(1)
        freqs = torch.exp(torch.linspace(-2, 10, self.n_freqs, device=original_image.device)) \
            .reshape(1, self.n_freqs, 1, 1, 1)
        feats = feats * freqs

        if self.learn_bias:
            # The stored (dim_multiplier, n_freqs) offsets are reinterpreted
            # (reshape, not transpose) as (n_freqs, dim_multiplier); kept
            # as-is so previously trained parameters keep their meaning.
            bias_shape = (1, self.n_freqs, self.dim_multiplier, 1, 1)
            sin_feats = feats + self.biases[0].reshape(bias_shape)
            cos_feats = feats + self.biases[1].reshape(bias_shape)
        else:
            sin_feats = feats
            cos_feats = feats

        # Fold the frequency axis into the channel axis (frequency-major).
        sin_feats = sin_feats.reshape(b, self.n_freqs * self.dim_multiplier, h, w)
        cos_feats = cos_feats.reshape(b, self.n_freqs * self.dim_multiplier, h, w)

        all_feats = [torch.sin(sin_feats), torch.cos(cos_feats)]
        if self.color_feats:
            all_feats.append(original_image)

        return torch.cat(all_feats, dim=1)
class MinMaxScaler(torch.nn.Module):
    """Rescale every channel of a 4-D tensor to the range ``[-0.5, 0.5]``.

    The minimum and maximum of each channel are taken over all other axes
    together (axis 0 and the two trailing axes), so every sample in the batch
    is scaled with the same per-channel statistics.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return ``(x - min) / max(max - min, 1e-4) - 0.5`` per channel.

        Args:
            x: 4-D tensor whose axis 1 is the channel axis.

        Returns:
            Tensor with the same shape as ``x``. Channels whose value range
            is below ``1e-4`` are divided by ``1e-4`` instead, which avoids a
            division by zero for constant channels.
        """
        n_channels = x.shape[1]
        stat_shape = (1, n_channels, 1, 1)

        # Bring channels to the front and collapse everything else so that a
        # single reduction covers the batch and both spatial axes.
        per_channel = x.permute(1, 0, 2, 3).reshape(n_channels, -1)
        lowest, _ = per_channel.min(dim=-1)
        highest, _ = per_channel.max(dim=-1)

        offset = lowest.reshape(stat_shape)
        span = highest.reshape(stat_shape) - offset
        safe_span = span.clamp_min(0.0001)

        shifted = x - offset
        return (shifted / safe_span) - .5
class ChannelNorm(torch.nn.Module):
    """Layer normalisation over the channel axis of a 4-D tensor.

    Wraps ``torch.nn.LayerNorm(dim)``, which normalises the last axis, and
    moves axis 1 to the end and back around the call.

    Args:
        dim: Size of axis 1 of the tensors passed to ``forward``.
        *args, **kwargs: Forwarded to ``torch.nn.Module.__init__``.
    """

    def __init__(self, dim, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.norm = torch.nn.LayerNorm(dim)

    def forward(self, x):
        """Normalise ``x`` across axis 1 and return it in the input layout.

        Args:
            x: 4-D tensor whose axis 1 has size ``dim``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # LayerNorm works on the trailing axis, so go channels-last first.
        channels_last = x.permute(0, 2, 3, 1)
        normalised = self.norm(channels_last)
        # Restore the original axis order.
        return normalised.permute(0, 3, 1, 2)