# Author: jev-aleks
# Commit: 9e15541 ("scenedino init")
import torch
import torchvision
from torch import nn
class ImageEncoder(nn.Module):
    """
    Global image encoder.

    Runs a torchvision ResNet backbone through global average pooling and,
    when ``latent_size`` differs from the backbone's pooled feature width,
    projects the pooled feature with a Linear layer (``self.fc``). The most
    recent output of :meth:`forward` is cached in the non-persistent buffer
    ``self.latent`` so that :meth:`index` can broadcast it.
    """

    def __init__(self, backbone="resnet34", pretrained=True, latent_size=128):
        """
        :param backbone Name of a torchvision ResNet constructor,
            e.g. resnet34 | resnet50. Assumes a ResNet-style model exposing
            conv1 / bn1 / relu / maxpool / layer1-4 / avgpool / fc.
        :param pretrained Whether to use model pretrained on ImageNet
        :param latent_size Size of the output latent vector. If it differs
            from the backbone's pooled feature width, a Linear projection
            ``self.fc`` is created.
        """
        super().__init__()
        self.model = getattr(torchvision.models, backbone)(pretrained=pretrained)
        # The pooled feature width depends on the backbone (512 for
        # resnet18/34, 2048 for resnet50 and deeper). Read it from the
        # classifier head before discarding it instead of hard-coding 512.
        self.feature_size = self.model.fc.in_features
        # Replace the classification head with an empty Sequential (identity).
        self.model.fc = nn.Sequential()
        # Cache for the latest forward() output, (B, latent_size).
        # persistent=False keeps it out of the state_dict.
        self.register_buffer("latent", torch.empty(1, 1), persistent=False)
        self.latent_size = latent_size
        if latent_size != self.feature_size:
            self.fc = nn.Linear(self.feature_size, latent_size)

    def index(self, uv, cam_z=None, image_size=(), z_bounds=()):
        """
        Broadcast the cached global latent to every query point.

        All params except ``uv`` are ignored (kept for interface compatibility).

        :param uv (B, N, 2) query coordinates; only ``uv.shape[1]`` is used
        :return latent vector (B, L, N), an expanded view of ``self.latent``
            (no copy). Call :meth:`forward` first so the cache is populated.
        """
        return self.latent.unsqueeze(-1).expand(-1, -1, uv.shape[1])

    def forward(self, x):
        """
        For extracting ResNet's features.

        :param x image (B, C, H, W)
        :return latent (B, latent_size); also stored in ``self.latent``
        """
        # Follow the device of the buffer, i.e. wherever the module lives.
        x = x.to(device=self.latent.device)
        x = self.model.conv1(x)
        x = self.model.bn1(x)
        x = self.model.relu(x)
        x = self.model.maxpool(x)
        x = self.model.layer1(x)
        x = self.model.layer2(x)
        x = self.model.layer3(x)
        x = self.model.layer4(x)
        x = self.model.avgpool(x)
        x = torch.flatten(x, 1)  # (B, feature_size)
        # self.fc exists exactly when the widths differ (see __init__).
        if self.latent_size != self.feature_size:
            x = self.fc(x)
        self.latent = x  # (B, latent_size)
        return self.latent

    @classmethod
    def from_conf(cls, conf):
        """
        Build an encoder from a config object.

        ``conf`` must provide get_string / get_bool / get_int (presumably a
        pyhocon ConfigTree; verify against callers). ``backbone`` is required.
        ``pretrained`` defaults to True and ``latent_size`` defaults to 128.
        """
        return cls(
            conf.get_string("backbone"),
            pretrained=conf.get_bool("pretrained", True),
            latent_size=conf.get_int("latent_size", 128),
        )