import torch
import torchvision
from torch import nn


class ImageEncoder(nn.Module):
    """
    Global image encoder: a torchvision ResNet trunk pooled to a single
    latent vector per image.
    """

    def __init__(self, backbone="resnet34", pretrained=True, latent_size=128):
        """
        :param backbone Backbone network. Assumes it is resnet* e.g. resnet34 | resnet50
        :param pretrained Whether to use model pretrained on ImageNet
        :param latent_size Output latent dimension; a Linear head is added
            only when this differs from the ResNet feature width (512)
        """
        super().__init__()
        self.model = getattr(torchvision.models, backbone)(pretrained=pretrained)
        # Remove the classifier head; features are taken after avgpool.
        self.model.fc = nn.Sequential()
        # Holds the latent from the most recent forward pass, shape (B, L).
        # persistent=False keeps it out of the state_dict.
        self.register_buffer("latent", torch.empty(1, 1), persistent=False)
        self.latent_size = latent_size
        if latent_size != 512:
            self.fc = nn.Linear(512, latent_size)

    def index(self, uv, cam_z=None, image_size=(), z_bounds=()):
        """
        Params other than uv are ignored (kept for interface compatibility).
        :param uv (B, N, 2) only used for shape
        :return latent vector (B, L, N) — the stored latent broadcast over N
        """
        n_pts = uv.shape[1]
        return self.latent.unsqueeze(-1).expand(-1, -1, n_pts)

    def forward(self, x):
        """
        Extract the ResNet global feature for a batch of images.
        :param x image (B, C, H, W)
        :return latent (B, latent_size)
        """
        net = self.model
        # Move input onto whatever device the module lives on.
        feat = x.to(device=self.latent.device)
        # Run the standard ResNet stages in order, through global avg-pool.
        stages = (
            net.conv1,
            net.bn1,
            net.relu,
            net.maxpool,
            net.layer1,
            net.layer2,
            net.layer3,
            net.layer4,
            net.avgpool,
        )
        for stage in stages:
            feat = stage(feat)
        feat = torch.flatten(feat, 1)
        if self.latent_size != 512:
            feat = self.fc(feat)
        # Cache for later index() calls.  (B, latent_size)
        self.latent = feat
        return self.latent

    @classmethod
    def from_conf(cls, conf):
        """Alternate constructor: build the encoder from a config object."""
        return cls(
            conf.get_string("backbone"),
            pretrained=conf.get_bool("pretrained", True),
            latent_size=conf.get_int("latent_size", 128),
        )