Spaces:
Running
Running
| from pathlib import Path | |
| import subprocess | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torchvision.models as models | |
| from scipy.io import loadmat | |
| from ..utils.base_model import BaseModel | |
| from .. import logger | |
# Tolerance used when checking that input pixel values lie within [0, 1].
EPS = 1e-6
class NetVLADLayer(nn.Module):
    """Aggregate local descriptors into a single VLAD-style global descriptor.

    Args:
        input_dim: dimension D of each local descriptor.
        K: number of cluster centers.
        score_bias: whether the 1x1 convolution producing the soft-assignment
            scores has a bias term.
        intranorm: whether to L2-normalize each cluster's residual vector
            before the final global normalization.
    """

    def __init__(self, input_dim=512, K=64, score_bias=False, intranorm=True):
        super().__init__()
        # 1x1 convolution mapping every local descriptor to K assignment scores.
        self.score_proj = nn.Conv1d(
            input_dim, K, kernel_size=1, bias=score_bias
        )
        # Learnable cluster centers, stored as D x K.
        self.centers = nn.parameter.Parameter(torch.empty(input_dim, K))
        nn.init.xavier_uniform_(self.centers)
        self.intranorm = intranorm
        self.output_dim = input_dim * K

    def forward(self, x):
        """Compute the aggregated descriptor.

        Args:
            x: local descriptors of shape B x D x N.

        Returns:
            Tensor of shape B x (D * K), L2-normalized along dimension 1.
        """
        batch_size = x.size(0)
        # Soft assignment of every descriptor to the K clusters: B x K x N.
        assignment = F.softmax(self.score_proj(x), dim=1)
        # Residual of every descriptor w.r.t. every center: B x D x K x N.
        residuals = x[:, :, None, :] - self.centers[None, :, :, None]
        # Assignment-weighted sum of residuals over the N descriptors: B x D x K.
        vlad = (assignment[:, None, :, :] * residuals).sum(dim=-1)
        if self.intranorm:
            # From the official MATLAB implementation.
            vlad = F.normalize(vlad, dim=1)
        flat = vlad.view(batch_size, -1)
        return F.normalize(flat, dim=1)
class NetVLAD(BaseModel):
    """Global image descriptor: VGG16 backbone followed by a NetVLAD layer.

    The weights are read from the MATLAB checkpoints listed in ``dir_models``.
    The selected checkpoint is downloaded on first use and cached under
    ``<torch.hub.get_dir()>/netvlad``.

    Configuration keys:
        model_name: key of ``dir_models`` selecting the checkpoint.
        whiten: if true, a learned linear projection to 4096 dimensions is
            applied on top of the NetVLAD descriptor.
    """

    default_conf = {"model_name": "VGG16-NetVLAD-Pitts30K", "whiten": True}
    required_inputs = ["image"]

    # Models exported using
    # https://github.com/uzh-rpg/netvlad_tf_open/blob/master/matlab/net_class2struct.m.
    dir_models = {
        "VGG16-NetVLAD-Pitts30K": "https://cvg-data.inf.ethz.ch/hloc/netvlad/Pitts30K_struct.mat",
        "VGG16-NetVLAD-TokyoTM": "https://cvg-data.inf.ethz.ch/hloc/netvlad/TokyoTM_struct.mat",
    }

    def _init(self, conf):
        """Build the network and load the MATLAB weights selected by ``conf``.

        Args:
            conf: mapping with the keys ``model_name`` and ``whiten``.

        Raises:
            AssertionError: if ``conf["model_name"]`` is not in ``dir_models``.
        """
        assert conf["model_name"] in self.dir_models, (
            f"Unknown NetVLAD model: {conf['model_name']}"
        )

        # Download the checkpoint.
        checkpoint = Path(
            torch.hub.get_dir(), "netvlad", conf["model_name"] + ".mat"
        )
        if not checkpoint.exists():
            checkpoint.parent.mkdir(exist_ok=True, parents=True)
            link = self.dir_models[conf["model_name"]]
            logger.info(f"Downloading the NetVLAD model from `{link}`.")
            # Use the PyTorch downloader instead of shelling out to `wget`:
            # it does not need an external binary, and it writes to a
            # temporary file that is only moved to `checkpoint` on success,
            # so an interrupted download never leaves a truncated checkpoint
            # that a later run would mistake for a valid one.
            torch.hub.download_url_to_file(link, str(checkpoint))

        # Create the network.
        # Remove classification head.
        backbone = list(models.vgg16().children())[0]
        # Remove last ReLU + MaxPool2d.
        self.backbone = nn.Sequential(*list(backbone.children())[:-2])

        self.netvlad = NetVLADLayer()

        if conf["whiten"]:
            self.whiten = nn.Linear(self.netvlad.output_dim, 4096)

        # Parse MATLAB weights using https://github.com/uzh-rpg/netvlad_tf_open
        mat = loadmat(checkpoint, struct_as_record=False, squeeze_me=True)

        # CNN weights.
        # The backbone modules and the MATLAB layers are walked in lockstep;
        # only the convolutions carry weights to copy.
        for layer, mat_layer in zip(
            self.backbone.children(), mat["net"].layers
        ):
            if isinstance(layer, nn.Conv2d):
                w = mat_layer.weights[0]  # Shape: S x S x IN x OUT
                b = mat_layer.weights[1]  # Shape: OUT
                # Prepare for PyTorch - enforce float32 and right shape.
                # w should have shape: OUT x IN x S x S
                # b should have shape: OUT
                w = torch.tensor(w).float().permute([3, 2, 0, 1])
                b = torch.tensor(b).float()
                # Update layer weights.
                layer.weight = nn.Parameter(w)
                layer.bias = nn.Parameter(b)

        # NetVLAD weights.
        score_w = mat["net"].layers[30].weights[0]  # D x K
        # centers are stored as opposite in official MATLAB code
        center_w = -mat["net"].layers[30].weights[1]  # D x K
        # Prepare for PyTorch - make sure it is float32 and has right shape.
        # score_w should have shape K x D x 1
        # center_w should have shape D x K
        score_w = torch.tensor(score_w).float().permute([1, 0]).unsqueeze(-1)
        center_w = torch.tensor(center_w).float()
        # Update layer weights.
        self.netvlad.score_proj.weight = nn.Parameter(score_w)
        self.netvlad.centers = nn.Parameter(center_w)

        # Whitening weights.
        if conf["whiten"]:
            w = mat["net"].layers[33].weights[0]  # Shape: 1 x 1 x IN x OUT
            b = mat["net"].layers[33].weights[1]  # Shape: OUT
            # Prepare for PyTorch - make sure it is float32 and has right shape
            w = torch.tensor(w).float().squeeze().permute([1, 0])  # OUT x IN
            b = torch.tensor(b.squeeze()).float()  # Shape: OUT
            # Update layer weights.
            self.whiten.weight = nn.Parameter(w)
            self.whiten.bias = nn.Parameter(b)

        # Preprocessing parameters.
        # The mean is read from the checkpoint metadata; the std of ones
        # makes the division in `_forward` a no-op.
        self.preprocess = {
            "mean": mat["net"].meta.normalization.averageImage[0, 0],
            "std": np.array([1, 1, 1], dtype=np.float32),
        }

    def _forward(self, data):
        """Compute the global descriptor of a batch of images.

        Args:
            data: mapping whose ``"image"`` entry is a tensor with 3 channels
                along dimension 1 and values in [0, 1] (up to ``EPS``).

        Returns:
            Dict with the key ``"global_descriptor"`` holding the
            L2-normalized descriptors, one row per image.

        Raises:
            AssertionError: if the image does not have 3 channels or its
                values fall outside [0, 1].
        """
        image = data["image"]
        assert image.shape[1] == 3
        assert image.min() >= -EPS and image.max() <= 1 + EPS
        image = torch.clamp(image * 255, 0.0, 255.0)  # Input should be 0-255.
        mean = self.preprocess["mean"]
        std = self.preprocess["std"]
        image = image - image.new_tensor(mean).view(1, -1, 1, 1)
        image = image / image.new_tensor(std).view(1, -1, 1, 1)

        # Feature extraction.
        descriptors = self.backbone(image)
        b, c, _, _ = descriptors.size()
        # Flatten the spatial grid into a list of local descriptors.
        descriptors = descriptors.view(b, c, -1)

        # NetVLAD layer.
        descriptors = F.normalize(descriptors, dim=1)  # Pre-normalization.
        desc = self.netvlad(descriptors)

        # Whiten if needed.
        # `self.whiten` only exists when it was requested in the configuration.
        if hasattr(self, "whiten"):
            desc = self.whiten(desc)
            desc = F.normalize(desc, dim=1)  # Final L2 normalization.

        return {"global_descriptor": desc}