import torch
import torch.autograd.profiler as profiler
import torch.nn.functional as F
import torchvision
from torch import nn
from scenedino.common import util


class SpatialEncoder(nn.Module):
    """
    2D (spatial / pixel-aligned / local) image encoder.
    """

    def __init__(
        self,
        backbone="resnet34",
        pretrained=True,
        num_layers=4,
        index_interp="bilinear",
        index_padding="border",
        upsample_interp="bilinear",
        feature_scale=1.0,
        use_first_pool=True,
        norm_type="batch",
    ):
"""
:param backbone Backbone network. Either custom, in which case
model.custom_encoder.ConvEncoder is used OR resnet18/resnet34, in which case the relevant
model from torchvision is used
:param num_layers number of resnet layers to use, 1-5
:param pretrained Whether to use model weights pretrained on ImageNet
:param index_interp Interpolation to use for indexing
:param index_padding Padding mode to use for indexing, border | zeros | reflection
:param upsample_interp Interpolation to use for upscaling latent code
:param feature_scale factor to scale all latent by. Useful (<1) if image
is extremely large, to fit in memory.
:param use_first_pool if false, skips first maxpool layer to avoid downscaling image
features too much (ResNet only)
:param norm_type norm type to applied; pretrained model must use batch
"""
        super().__init__()

        if norm_type != "batch":
            assert not pretrained, "Pretrained weights are only available with batch norm"

        self.use_custom_resnet = backbone == "custom"
        self.feature_scale = feature_scale
        self.use_first_pool = use_first_pool
        norm_layer = util.get_norm_layer(norm_type)

        print("Using torchvision", backbone, "encoder")
        self.model = getattr(torchvision.models, backbone)(
            pretrained=pretrained, norm_layer=norm_layer
        )
        # Strip the classification head; only the convolutional feature maps are used.
        self.model.fc = nn.Sequential()
        self.model.avgpool = nn.Sequential()

        self.latent_size = [0, 64, 128, 256, 512, 1024][num_layers]
        self.num_layers = num_layers
        self.index_interp = index_interp
        self.index_padding = index_padding
        self.upsample_interp = upsample_interp
        self.register_buffer("latent", torch.empty(1, 1, 1, 1), persistent=False)
        self.register_buffer(
            "latent_scaling", torch.empty(2, dtype=torch.float32), persistent=False
        )
        self.scales = [0]
        # self.latent (B, L, H, W)

    def index(self, uv, cam_z=None, image_size=(), z_bounds=None):
        """
        Get pixel-aligned image features at 2D image coordinates.
        :param uv (B, N, 2) image points (x, y)
        :param cam_z ignored (for compatibility)
        :param image_size image size, either (width, height) or a single int;
            if not specified, assumes coords are in [-1, 1]
        :param z_bounds ignored (for compatibility)
        :return (B, L, N) where L is the latent size
        """
        with profiler.record_function("encoder_index"):
            if uv.shape[0] == 1 and self.latent.shape[0] > 1:
                uv = uv.expand(self.latent.shape[0], -1, -1)

            with profiler.record_function("encoder_index_pre"):
                if isinstance(image_size, int):
                    image_size = (image_size, image_size)
                if len(image_size) > 0:
                    if len(image_size) == 1:
                        image_size = (image_size[0], image_size[0])
                    # Map pixel coordinates into grid_sample's [-1, 1] range.
                    scale = self.latent_scaling / image_size
                    uv = uv * scale - 1.0

            uv = uv.unsqueeze(2)  # (B, N, 1, 2)
            samples = F.grid_sample(
                self.latent,
                uv,
                align_corners=True,
                mode=self.index_interp,
                padding_mode=self.index_padding,
            )
            return samples[:, :, :, 0]  # (B, C, N)
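
    # Worked example of the normalization in index() (a sketch, assuming a
    # 100x100 image and a 100x100 latent): latent_scaling = 100 / (100 - 1) * 2,
    # so scale = latent_scaling / 100 = 2 / 99 and a pixel coordinate u maps to
    # u * 2 / 99 - 1, sending u = 0 to -1 and u = 99 to +1, which is exactly
    # grid_sample's align_corners=True convention.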

    def forward(self, x):
        """
        Extract ResNet feature maps.
        :param x image (B, C, H, W)
        :return latent (B, latent_size, H/2, W/2) for ResNet backbones,
            since conv1 halves the resolution
        """
        if self.feature_scale != 1.0:
            x = F.interpolate(
                x,
                scale_factor=self.feature_scale,
                mode="bilinear" if self.feature_scale > 1.0 else "area",
                align_corners=True if self.feature_scale > 1.0 else None,
                recompute_scale_factor=True,
            )
        x = x.to(device=self.latent.device)

        if self.use_custom_resnet:
            self.latent = self.model(x)
        else:
            x = self.model.conv1(x)
            x = self.model.bn1(x)
            x = self.model.relu(x)

            latents = [x]
            if self.num_layers > 1:
                if self.use_first_pool:
                    x = self.model.maxpool(x)
                x = self.model.layer1(x)
                latents.append(x)
            if self.num_layers > 2:
                x = self.model.layer2(x)
                latents.append(x)
            if self.num_layers > 3:
                x = self.model.layer3(x)
                latents.append(x)
            if self.num_layers > 4:
                x = self.model.layer4(x)
                latents.append(x)

            self.latents = latents
            align_corners = None if self.index_interp == "nearest" else True
            latent_sz = latents[0].shape[-2:]
            # Upsample every feature map to the resolution of the first one and
            # concatenate along the channel dimension.
            for i in range(len(latents)):
                latents[i] = F.interpolate(
                    latents[i],
                    latent_sz,
                    mode=self.upsample_interp,
                    align_corners=align_corners,
                )
            self.latent = torch.cat(latents, dim=1)

        # latent_scaling converts pixel coordinates to grid_sample coordinates
        # in index(); see the worked example above.
        self.latent_scaling[0] = self.latent.shape[-1]
        self.latent_scaling[1] = self.latent.shape[-2]
        self.latent_scaling = self.latent_scaling / (self.latent_scaling - 1) * 2.0
        return [self.latent]

    @classmethod
    def from_conf(cls, conf):
        return cls(
            conf.get("backbone"),
            pretrained=conf.get("pretrained", True),
            num_layers=conf.get("num_layers", 4),
            index_interp=conf.get("index_interp", "bilinear"),
            index_padding=conf.get("index_padding", "border"),
            upsample_interp=conf.get("upsample_interp", "bilinear"),
            feature_scale=conf.get("feature_scale", 1.0),
            use_first_pool=conf.get("use_first_pool", True),
        )
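

# Minimal smoke test (a sketch, not part of the original module). It builds the
# encoder on random data and queries pixel-aligned features. Assumes scenedino
# is importable and a torchvision version that still accepts the `pretrained=`
# keyword; pretrained=False avoids a weight download.
if __name__ == "__main__":
    enc = SpatialEncoder(backbone="resnet34", pretrained=False, num_layers=4)
    images = torch.rand(2, 3, 64, 96)  # (B, C, H, W)
    latent = enc(images)[0]
    print("latent:", tuple(latent.shape))  # (2, 512, 32, 48); channels 64+64+128+256
    # Query 10 random points per image, given in pixel coordinates (x, y).
    uv = torch.rand(2, 10, 2) * torch.tensor([96.0, 64.0])
    feats = enc.index(uv, image_size=(96, 64))
    print("features:", tuple(feats.shape))  # (2, 512, 10)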