import torch
import torch.autograd.profiler as profiler
import torch.nn.functional as F
import torchvision
from torch import nn

from scenedino.common import util


class SpatialEncoder(nn.Module):
    """
    2D (Spatial/Pixel-aligned/local) image encoder.

    Runs a torchvision ResNet-style backbone, upsamples the outputs of the
    first ``num_layers`` stages to the conv1 resolution, concatenates them
    along channels into ``self.latent``, and samples that map at 2D image
    coordinates via ``index``.
    """

    def __init__(
        self,
        backbone="resnet34",
        pretrained=True,
        num_layers=4,
        index_interp="bilinear",
        index_padding="border",
        upsample_interp="bilinear",
        feature_scale=1.0,
        use_first_pool=True,
        norm_type="batch",
    ):
        """
        :param backbone Name of a torchvision model constructor, e.g.
            resnet18/resnet34. NOTE: "custom" only sets ``use_custom_resnet``;
            no custom encoder is constructed here, so the torchvision lookup
            below fails for it.
        :param pretrained Whether to use model weights pretrained on ImageNet
        :param num_layers number of resnet stages to use, 1-5
            (conv1, layer1, layer2, layer3, layer4)
        :param index_interp Interpolation mode passed to grid_sample in index()
        :param index_padding Padding mode to use for indexing,
            border | zeros | reflection
        :param upsample_interp Interpolation mode passed to F.interpolate when
            upsampling every stage to the conv1 resolution
        :param feature_scale factor to resize the input image by before
            encoding. Useful (<1) if image is extremely large, to fit in memory.
        :param use_first_pool if false, skips first maxpool layer to avoid
            downscaling image features too much (ResNet only)
        :param norm_type norm type to apply; pretrained model must use batch
        """
        super().__init__()

        if norm_type != "batch":
            # Pretrained torchvision weights are for BatchNorm layers.
            assert not pretrained, "pretrained weights require norm_type='batch'"

        self.use_custom_resnet = backbone == "custom"
        self.feature_scale = feature_scale
        self.use_first_pool = use_first_pool
        norm_layer = util.get_norm_layer(norm_type)

        print("Using torchvision", backbone, "encoder")
        self.model = getattr(torchvision.models, backbone)(
            pretrained=pretrained, norm_layer=norm_layer
        )
        # Replace the classification head with parameter-free identities;
        # forward() never calls them. NOTE(review): per the original comment
        # these are needed for older configs (presumably checkpoint
        # state_dict compatibility) — confirm before removing.
        self.model.fc = nn.Sequential()
        self.model.avgpool = nn.Sequential()

        # Cumulative channel count of conv1..layerN. Assumes BasicBlock
        # ResNets (resnet18/34); bottleneck variants have wider stages.
        self.latent_size = [0, 64, 128, 256, 512, 1024][num_layers]

        self.num_layers = num_layers
        self.index_interp = index_interp
        self.index_padding = index_padding
        self.upsample_interp = upsample_interp

        # Filled in by forward(); not saved in state_dict.
        self.register_buffer("latent", torch.empty(1, 1, 1, 1), persistent=False)
        # (x, y) factors used by index() to map pixel coords to grid coords.
        self.register_buffer(
            "latent_scaling", torch.empty(2, dtype=torch.float32), persistent=False
        )
        # NOTE(review): not used in this file; presumably the list of feature
        # scales exposed to callers — confirm.
        self.scales = [0]
        # self.latent (B, L, H, W)

    def _image_size_tensor(self, image_size):
        """
        Normalize ``image_size`` to a (2,) tensor matching ``latent_scaling``.

        Accepts a (width, height) sequence or tensor, or a single number
        (used for both axes). Returns None when no size was given (None or an
        empty sequence/tensor), meaning uv is already in grid coordinates.
        """
        if image_size is None:
            return None
        size = torch.as_tensor(
            image_size,
            dtype=self.latent_scaling.dtype,
            device=self.latent_scaling.device,
        ).flatten()
        if size.numel() == 0:
            return None
        if size.numel() == 1:
            size = size.expand(2)
        return size

    def index(self, uv, cam_z=None, image_size=(), z_bounds=None):
        """
        Get pixel-aligned image features at 2D image coordinates.

        :param uv (B, N, 2) image points (x, y). A batch size of 1 is expanded
            to the latent's batch size.
        :param cam_z ignored (for compatibility)
        :param image_size image size, either (width, height), a single number
            used for both, or empty/None. If empty, assumes coords are in [-1, 1].
        :param z_bounds ignored (for compatibility)
        :return (B, L, N) L is latent size
        """
        with profiler.record_function("encoder_index"):
            if uv.shape[0] == 1 and self.latent.shape[0] > 1:
                uv = uv.expand(self.latent.shape[0], -1, -1)

            with profiler.record_function("encoder_index_pre"):
                size = self._image_size_tensor(image_size)
                if size is not None:
                    # Rescale pixel coords so that uv == 0 maps to -1 in
                    # grid_sample's normalized coordinate system.
                    scale = self.latent_scaling / size
                    uv = uv * scale - 1.0

            uv = uv.unsqueeze(2)  # (B, N, 1, 2)
            samples = F.grid_sample(
                self.latent,
                uv,
                align_corners=True,
                mode=self.index_interp,
                padding_mode=self.index_padding,
            )
            return samples[:, :, :, 0]  # (B, C, N)

    def forward(self, x):
        """
        Extract ResNet features and store them as the latent map.

        Side effects: sets ``self.latent``, ``self.latents`` (the upsampled
        per-stage features) and ``self.latent_scaling`` for use by index().

        :param x image (B, C, H, W)
        :return list with one tensor, the concatenated latent (B, C, H', W'),
            where (H', W') is the spatial size of the conv1 feature map
        """
        if self.feature_scale != 1.0:
            x = F.interpolate(
                x,
                scale_factor=self.feature_scale,
                mode="bilinear" if self.feature_scale > 1.0 else "area",
                align_corners=True if self.feature_scale > 1.0 else None,
                recompute_scale_factor=True,
            )
        x = x.to(device=self.latent.device)

        if self.use_custom_resnet:
            # NOTE(review): appears unreachable — __init__ never builds a
            # custom encoder for backbone == "custom".
            self.latent = self.model(x)
        else:
            x = self.model.conv1(x)
            x = self.model.bn1(x)
            x = self.model.relu(x)

            latents = [x]
            if self.num_layers > 1:
                if self.use_first_pool:
                    x = self.model.maxpool(x)
                x = self.model.layer1(x)
                latents.append(x)
            if self.num_layers > 2:
                x = self.model.layer2(x)
                latents.append(x)
            if self.num_layers > 3:
                x = self.model.layer3(x)
                latents.append(x)
            if self.num_layers > 4:
                x = self.model.layer4(x)
                latents.append(x)

            # F.interpolate only accepts align_corners for interpolating
            # modes; it raises for nearest / nearest-exact / area.
            align_corners = (
                None
                if self.upsample_interp in ("nearest", "nearest-exact", "area")
                else True
            )
            latent_sz = latents[0].shape[-2:]
            latents = [
                F.interpolate(
                    feat,
                    latent_sz,
                    mode=self.upsample_interp,
                    align_corners=align_corners,
                )
                for feat in latents
            ]
            self.latents = latents
            self.latent = torch.cat(latents, dim=1)

        # Per-axis factor 2 * size / (size - 1), x (width) first; consumed by
        # index() together with the caller's image size.
        self.latent_scaling[0] = self.latent.shape[-1]
        self.latent_scaling[1] = self.latent.shape[-2]
        self.latent_scaling = self.latent_scaling / (self.latent_scaling - 1) * 2.0
        return [self.latent]

    @classmethod
    def from_conf(cls, conf):
        """
        Build an encoder from a config mapping (anything with ``.get``).

        Every key is optional; missing keys fall back to the constructor's
        defaults.
        """
        return cls(
            conf.get("backbone", "resnet34"),
            pretrained=conf.get("pretrained", True),
            num_layers=conf.get("num_layers", 4),
            index_interp=conf.get("index_interp", "bilinear"),
            index_padding=conf.get("index_padding", "border"),
            upsample_interp=conf.get("upsample_interp", "bilinear"),
            feature_scale=conf.get("feature_scale", 1.0),
            use_first_pool=conf.get("use_first_pool", True),
            norm_type=conf.get("norm_type", "batch"),
        )