Spaces:
Running
on
Zero
Running
on
Zero
import torch | |
import torch.autograd.profiler as profiler | |
import torch.nn.functional as F | |
import torchvision | |
from torch import nn | |
from scenedino.common import util | |
class SpatialEncoder(nn.Module):
    """
    2D (Spatial/Pixel-aligned/local) image encoder.

    ``forward`` runs the backbone on an image batch, upsamples the collected
    feature maps to the resolution of the first one, concatenates them along
    the channel axis and caches the result in ``self.latent``. ``index`` then
    samples that cached map at 2D image coordinates.
    """

    def __init__(
        self,
        backbone="resnet34",
        pretrained=True,
        num_layers=4,
        index_interp="bilinear",
        index_padding="border",
        upsample_interp="bilinear",
        feature_scale=1.0,
        use_first_pool=True,
        norm_type="batch",
    ):
        """
        :param backbone Backbone network; the name is looked up as a constructor
        in torchvision.models (e.g. resnet18/resnet34).
        NOTE(review): "custom" only sets ``use_custom_resnet`` (forward then
        calls ``self.model(x)`` directly); no custom encoder is built here and
        the name is still looked up in torchvision.models - confirm whether a
        custom backbone is still meant to be supported.
        :param pretrained Whether to use model weights pretrained on ImageNet
        :param num_layers number of resnet layers to use, 1-5
        :param index_interp Interpolation to use for indexing
        :param index_padding Padding mode to use for indexing, border | zeros | reflection
        :param upsample_interp Interpolation to use for upscaling latent code
        :param feature_scale factor to scale all latent by. Useful (<1) if image
        is extremely large, to fit in memory.
        :param use_first_pool if false, skips first maxpool layer to avoid downscaling image
        features too much (ResNet only)
        :param norm_type norm type to applied; pretrained model must use batch
        """
        super().__init__()

        # Pretrained weights are only accepted together with batch norm.
        if norm_type != "batch":
            assert not pretrained

        self.use_custom_resnet = backbone == "custom"
        self.feature_scale = feature_scale
        self.use_first_pool = use_first_pool
        norm_layer = util.get_norm_layer(norm_type)

        print("Using torchvision", backbone, "encoder")
        self.model = getattr(torchvision.models, backbone)(
            pretrained=pretrained, norm_layer=norm_layer
        )
        # The classification head is replaced by empty modules; the original
        # note says these two lines are required for older configs.
        self.model.fc = nn.Sequential()
        self.model.avgpool = nn.Sequential()

        # Total channel count of the concatenated feature maps, by num_layers.
        self.latent_size = [0, 64, 128, 256, 512, 1024][num_layers]

        self.num_layers = num_layers
        self.index_interp = index_interp
        self.index_padding = index_padding
        self.upsample_interp = upsample_interp

        # Placeholders, overwritten by every forward() call; non-persistent so
        # they are not written to the state dict.
        self.register_buffer("latent", torch.empty(1, 1, 1, 1), persistent=False)
        self.register_buffer(
            "latent_scaling", torch.empty(2, dtype=torch.float32), persistent=False
        )
        self.scales = [0]
        # self.latent (B, L, H, W)

    def index(self, uv, cam_z=None, image_size=(), z_bounds=None):
        """
        Get pixel-aligned image features at 2D image coordinates
        :param uv (B, N, 2) image points (x,y)
        :param cam_z ignored (for compatibility)
        :param image_size image size, either (width, height) or a single number
        used for both axes (int, sequence or tensor).
        if not specified (empty or None), assumes coords are in [-1, 1]
        :param z_bounds ignored (for compatibility)
        :return (B, L, N) L is latent size
        """
        with profiler.record_function("encoder_index"):
            # A single set of points is shared across the whole latent batch.
            if uv.shape[0] == 1 and self.latent.shape[0] > 1:
                uv = uv.expand(self.latent.shape[0], -1, -1)

            with profiler.record_function("encoder_index_pre"):
                if image_size is None:
                    image_size = ()
                # Normalise ints, sequences and tensors to one tensor so the
                # division below is always tensor / tensor; a single value
                # broadcasts over both axes.
                size = torch.as_tensor(
                    image_size,
                    dtype=self.latent_scaling.dtype,
                    device=self.latent_scaling.device,
                )
                if size.numel() > 0:
                    # Map pixel coordinates to the [-1, 1] range of grid_sample.
                    scale = self.latent_scaling / size
                    uv = uv * scale - 1.0

            uv = uv.unsqueeze(2)  # (B, N, 1, 2)
            samples = F.grid_sample(
                self.latent,
                uv,
                align_corners=True,
                mode=self.index_interp,
                padding_mode=self.index_padding,
            )
            return samples[:, :, :, 0]  # (B, C, N)

    def forward(self, x):
        """
        For extracting ResNet's features.
        :param x image (B, C, H, W)
        :return one-element list holding the latent (B, latent_size, H', W'),
        where H', W' is the spatial size of the first feature map. The same
        tensor is cached in self.latent for later index() calls.
        """
        if self.feature_scale != 1.0:
            upscaling = self.feature_scale > 1.0
            x = F.interpolate(
                x,
                scale_factor=self.feature_scale,
                mode="bilinear" if upscaling else "area",
                align_corners=True if upscaling else None,
                recompute_scale_factor=True,
            )
        x = x.to(device=self.latent.device)

        if self.use_custom_resnet:
            self.latent = self.model(x)
        else:
            x = self.model.conv1(x)
            x = self.model.bn1(x)
            x = self.model.relu(x)

            latents = [x]
            if self.num_layers > 1:
                if self.use_first_pool:
                    x = self.model.maxpool(x)
                x = self.model.layer1(x)
                latents.append(x)
            if self.num_layers > 2:
                x = self.model.layer2(x)
                latents.append(x)
            if self.num_layers > 3:
                x = self.model.layer3(x)
                latents.append(x)
            if self.num_layers > 4:
                x = self.model.layer4(x)
                latents.append(x)

            # F.interpolate only accepts align_corners for the interpolating
            # modes; it must stay None for e.g. "nearest" or "area", otherwise
            # the call raises. The decision has to follow the mode actually
            # passed below (upsample_interp).
            interpolating = self.upsample_interp in (
                "linear",
                "bilinear",
                "bicubic",
                "trilinear",
            )
            align_corners = True if interpolating else None

            # Bring every feature map to the resolution of the first one.
            latent_sz = latents[0].shape[-2:]
            latents = [
                F.interpolate(
                    feature_map,
                    latent_sz,
                    mode=self.upsample_interp,
                    align_corners=align_corners,
                )
                for feature_map in latents
            ]
            # self.latents holds the upsampled per-layer maps.
            self.latents = latents
            self.latent = torch.cat(latents, dim=1)

        # (width, height) of the latent turned into the factor index() uses to
        # map pixel coordinates into [-1, 1]: size / (size - 1) * 2.
        # NOTE(review): a latent axis of length 1 divides by zero here.
        self.latent_scaling[0] = self.latent.shape[-1]
        self.latent_scaling[1] = self.latent.shape[-2]
        self.latent_scaling = self.latent_scaling / (self.latent_scaling - 1) * 2.0
        return [self.latent]

    @classmethod
    def from_conf(cls, conf):
        """
        Build an encoder from a config object.
        :param conf object exposing dict-style ``get(key, default)``; missing
        keys fall back to the constructor defaults
        :return new instance of ``cls``
        """
        return cls(
            conf.get("backbone", "resnet34"),
            pretrained=conf.get("pretrained", True),
            num_layers=conf.get("num_layers", 4),
            index_interp=conf.get("index_interp", "bilinear"),
            index_padding=conf.get("index_padding", "border"),
            upsample_interp=conf.get("upsample_interp", "bilinear"),
            feature_scale=conf.get("feature_scale", 1.0),
            use_first_pool=conf.get("use_first_pool", True),
            norm_type=conf.get("norm_type", "batch"),
        )