Spaces:
Running
on
Zero
Running
on
Zero
import torch | |
import torchvision | |
from torch import nn | |
class ImageEncoder(nn.Module):
    """
    Global image encoder.

    Runs a torchvision ResNet backbone (classification head removed) over an
    image and caches the resulting per-image latent vector in ``self.latent``
    so that ``index`` can broadcast it to every query point afterwards.
    """

    def __init__(self, backbone="resnet34", pretrained=True, latent_size=128):
        """
        :param backbone Backbone network. Assumes it is resnet*
        e.g. resnet34 | resnet50
        :param pretrained Whether to use model pretrained on ImageNet
        :param latent_size Size L of the output latent vector. A linear
        projection is added whenever it differs from the backbone's pooled
        feature width.
        """
        super().__init__()
        self.model = getattr(torchvision.models, backbone)(pretrained=pretrained)
        # Read the pooled feature width from the stock classifier before
        # discarding it, instead of assuming 512 (which only holds for the
        # BasicBlock variants such as resnet18/34; resnet50+ use 2048).
        feature_size = self.model.fc.in_features
        # An empty Sequential is an identity op: drops the classifier head.
        self.model.fc = nn.Sequential()
        # Non-persistent placeholder: follows the module across devices but
        # is not written to the state_dict. Overwritten on every forward().
        self.register_buffer("latent", torch.empty(1, 1), persistent=False)
        # self.latent (B, L)
        self.latent_size = latent_size
        if latent_size != feature_size:
            self.fc = nn.Linear(feature_size, latent_size)

    def index(self, uv, cam_z=None, image_size=(), z_bounds=()):
        """
        Broadcast the latent cached by the last forward() call to N points.
        Params other than uv are ignored (compatibility)
        :param uv (B, N, 2) only used for shape
        :return latent vector (B, L, N)
        """
        # expand() creates a broadcast view; no data is copied.
        return self.latent.unsqueeze(-1).expand(-1, -1, uv.shape[1])

    def forward(self, x):
        """
        For extracting ResNet's features.
        The result is also stored in self.latent for later index() calls.
        :param x image (B, C, H, W)
        :return latent (B, latent_size)
        """
        # The latent buffer tracks the module's device, so use it as the
        # reference for where the input has to live.
        x = x.to(device=self.latent.device)
        x = self.model.conv1(x)
        x = self.model.bn1(x)
        x = self.model.relu(x)
        x = self.model.maxpool(x)
        x = self.model.layer1(x)
        x = self.model.layer2(x)
        x = self.model.layer3(x)
        x = self.model.layer4(x)
        x = self.model.avgpool(x)
        x = torch.flatten(x, 1)
        # Project only when the pooled width is not already the requested
        # latent size (mirrors the condition under which __init__ builds fc).
        if x.shape[1] != self.latent_size:
            x = self.fc(x)
        self.latent = x  # (B, latent_size)
        return self.latent

    @classmethod
    def from_conf(cls, conf):
        """
        Alternate constructor from a config object.
        :param conf object exposing get_string / get_bool / get_int;
        "backbone" is required, "pretrained" defaults to True and
        "latent_size" to 128
        :return new encoder instance
        """
        return cls(
            conf.get_string("backbone"),
            pretrained=conf.get_bool("pretrained", True),
            latent_size=conf.get_int("latent_size", 128),
        )