SceneDINO / scenedino /models /backbones /spatial_encoder.py
jev-aleks's picture
scenedino init
9e15541
import torch
import torch.autograd.profiler as profiler
import torch.nn.functional as F
import torchvision
from torch import nn
from scenedino.common import util
class SpatialEncoder(nn.Module):
    """
    2D (Spatial/Pixel-aligned/local) image encoder.

    Runs an image through the first stages of a torchvision backbone, upsamples
    every stage output to the spatial size of the first one and concatenates
    them along the channel axis. The result is cached on the module as
    ``self.latent`` so that :meth:`index` can sample it at 2D image points.
    """

    # Interpolation modes for which F.interpolate accepts ``align_corners``;
    # every other mode (nearest, area, ...) must be called with None.
    _ALIGNABLE_MODES = ("linear", "bilinear", "bicubic", "trilinear")

    # Backbone stages applied after the stem, in order. Stage k (1-based) is
    # only run when ``num_layers > k``.
    _RESNET_STAGES = ("layer1", "layer2", "layer3", "layer4")

    def __init__(
        self,
        backbone="resnet34",
        pretrained=True,
        num_layers=4,
        index_interp="bilinear",
        index_padding="border",
        upsample_interp="bilinear",
        feature_scale=1.0,
        use_first_pool=True,
        norm_type="batch",
    ):
        """
        :param backbone Name of a constructor in torchvision.models
        (e.g. resnet18 / resnet34).
        NOTE(review): "custom" only sets the use_custom_resnet flag; no custom
        encoder is built here, the name is still looked up in
        torchvision.models - confirm whether that path is still supported.
        :param pretrained Whether to use model weights pretrained on ImageNet
        :param num_layers number of resnet layers to use, 1-5
        :param index_interp Interpolation to use for indexing
        :param index_padding Padding mode to use for indexing, border | zeros | reflection
        :param upsample_interp Interpolation to use for upscaling latent code
        :param feature_scale factor to scale all latent by. Useful (<1) if image
        is extremely large, to fit in memory.
        :param use_first_pool if false, skips first maxpool layer to avoid downscaling image
        features too much (ResNet only)
        :param norm_type norm type to applied; pretrained model must use batch
        """
        super().__init__()

        if norm_type != "batch":
            assert not pretrained, "pretrained weights require norm_type='batch'"

        self.use_custom_resnet = backbone == "custom"
        self.feature_scale = feature_scale
        self.use_first_pool = use_first_pool
        norm_layer = util.get_norm_layer(norm_type)

        print("Using torchvision", backbone, "encoder")
        backbone_ctor = getattr(torchvision.models, backbone)
        self.model = backbone_ctor(pretrained=pretrained, norm_layer=norm_layer)
        # forward() never calls the classification head, so fc / avgpool are
        # replaced by empty modules.
        # Following 2 lines need to be uncommented for older configs
        self.model.fc = nn.Sequential()
        self.model.avgpool = nn.Sequential()

        # Channel count of the concatenated feature map, indexed by num_layers.
        # Presumably the cumulative stage widths of resnet18/34
        # (64, 64, 128, 256, 512) - TODO confirm for other backbones.
        self.latent_size = [0, 64, 128, 256, 512, 1024][num_layers]

        self.num_layers = num_layers
        self.index_interp = index_interp
        self.index_padding = index_padding
        self.upsample_interp = upsample_interp

        # Placeholders that forward() overwrites; non-persistent so they are
        # kept out of the state dict.
        # self.latent (B, L, H, W)
        self.register_buffer("latent", torch.empty(1, 1, 1, 1), persistent=False)
        self.register_buffer(
            "latent_scaling", torch.empty(2, dtype=torch.float32), persistent=False
        )
        self.scales = [0]

    def _normalize_image_size(self, image_size):
        """
        Bring ``image_size`` into a form ``self.latent_scaling`` can be divided by.
        :param image_size None, a number, a sequence of 0, 1 or 2 numbers, or a tensor
        :return None if no size was given (coords are already in [-1, 1]),
        otherwise a tensor; tensors passed in are returned unchanged
        """
        if image_size is None:
            return None
        if torch.is_tensor(image_size):
            return image_size if image_size.numel() > 0 else None
        if isinstance(image_size, (int, float)):
            image_size = (image_size, image_size)
        elif len(image_size) == 0:
            return None
        elif len(image_size) == 1:
            # A single entry is used for both width and height.
            image_size = (image_size[0], image_size[0])
        return torch.as_tensor(
            image_size,
            dtype=self.latent_scaling.dtype,
            device=self.latent_scaling.device,
        )

    def index(self, uv, cam_z=None, image_size=(), z_bounds=None):
        """
        Get pixel-aligned image features at 2D image coordinates
        :param uv (B, N, 2) image points (x,y)
        :param cam_z ignored (for compatibility)
        :param image_size image size, either (width, height) or single int.
        if not specified, assumes coords are in [-1, 1]
        :param z_bounds ignored (for compatibility)
        :return (B, L, N) L is latent size
        """
        with profiler.record_function("encoder_index"):
            # Share a single set of points across the whole cached batch.
            if uv.shape[0] == 1 and self.latent.shape[0] > 1:
                uv = uv.expand(self.latent.shape[0], -1, -1)

            with profiler.record_function("encoder_index_pre"):
                size = self._normalize_image_size(image_size)
                if size is not None:
                    # Map pixel coordinates to the [-1, 1] range grid_sample
                    # expects, using the scaling computed in forward().
                    uv = uv * (self.latent_scaling / size) - 1.0

            grid = uv.unsqueeze(2)  # (B, N, 1, 2)
            samples = F.grid_sample(
                self.latent,
                grid,
                align_corners=True,
                mode=self.index_interp,
                padding_mode=self.index_padding,
            )
            return samples[:, :, :, 0]  # (B, C, N)

    def forward(self, x):
        """
        For extracting ResNet's features.
        :param x image (B, C, H, W)
        :return single-element list holding the latent (B, latent_size, H, W);
        the same tensor is cached as self.latent for use by index()
        """
        if self.feature_scale != 1.0:
            upscaling = self.feature_scale > 1.0
            x = F.interpolate(
                x,
                scale_factor=self.feature_scale,
                mode="bilinear" if upscaling else "area",
                align_corners=True if upscaling else None,
                recompute_scale_factor=True,
            )
        x = x.to(device=self.latent.device)

        if self.use_custom_resnet:
            self.latent = self.model(x)
        else:
            # Stem
            x = self.model.conv1(x)
            x = self.model.bn1(x)
            x = self.model.relu(x)
            latents = [x]

            # Stage k runs only when num_layers > k; stages are looked up
            # lazily so unused ones are never touched.
            for threshold, stage_name in enumerate(self._RESNET_STAGES, start=1):
                if self.num_layers <= threshold:
                    break
                if threshold == 1 and self.use_first_pool:
                    x = self.model.maxpool(x)
                x = getattr(self.model, stage_name)(x)
                latents.append(x)

            # align_corners is only legal for the interpolating modes; passing
            # True together with e.g. "nearest" makes F.interpolate raise.
            if self.upsample_interp in self._ALIGNABLE_MODES:
                align_corners = True
            else:
                align_corners = None

            # Bring every stage output to the resolution of the stem output.
            target_size = latents[0].shape[-2:]
            latents = [
                F.interpolate(
                    feature_map,
                    target_size,
                    mode=self.upsample_interp,
                    align_corners=align_corners,
                )
                for feature_map in latents
            ]
            self.latents = latents
            self.latent = torch.cat(latents, dim=1)

        # Factor used by index() to turn pixel coordinates into [-1, 1]:
        # (W, H) / (W - 1, H - 1) * 2.
        height, width = self.latent.shape[-2:]
        latent_dims = torch.tensor(
            [width, height],
            dtype=self.latent_scaling.dtype,
            device=self.latent_scaling.device,
        )
        self.latent_scaling = latent_dims / (latent_dims - 1) * 2.0
        return [self.latent]

    @classmethod
    def from_conf(cls, conf):
        """
        Build an encoder from a config object exposing ``get(key[, default])``.
        Keys missing from the config fall back to the defaults listed below.
        ``norm_type`` is not read from the config, so the constructor default
        applies.
        :param conf config with the encoder settings; "backbone" has no fallback
        :return SpatialEncoder instance
        """
        return cls(
            conf.get("backbone"),
            pretrained=conf.get("pretrained", True),
            num_layers=conf.get("num_layers", 4),
            index_interp=conf.get("index_interp", "bilinear"),
            index_padding=conf.get("index_padding", "border"),
            upsample_interp=conf.get("upsample_interp", "bilinear"),
            feature_scale=conf.get("feature_scale", 1.0),
            use_first_pool=conf.get("use_first_pool", True),
        )