from typing import List
import torch
import torch.nn as nn
import torch.nn.functional as F


def MLP(channels: List[int], do_bn: bool = False) -> nn.Module:
    """ Multi-layer perceptron """
    n = len(channels)
    layers = []
    for i in range(1, n):
        layers.append(nn.Linear(channels[i - 1], channels[i]))
        if i < (n - 1):
            if do_bn:
                layers.append(nn.BatchNorm1d(channels[i]))
            layers.append(nn.ReLU())
    return nn.Sequential(*layers)
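
# Example (illustrative, assumed channel sizes): MLP([3, 32, 64], do_bn=True) expands to
#     nn.Sequential(
#         nn.Linear(3, 32), nn.BatchNorm1d(32), nn.ReLU(),
#         nn.Linear(32, 64))
# i.e. the final layer carries no normalization or activation.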


def MLP_no_ReLU(channels: List[int], do_bn: bool = False) -> nn.Module:
    """ Multi-layer perceptron without ReLU activations between layers """
    n = len(channels)
    layers = []
    for i in range(1, n):
        layers.append(nn.Linear(channels[i - 1], channels[i]))
        if i < (n - 1):
            if do_bn:
                layers.append(nn.BatchNorm1d(channels[i]))
    return nn.Sequential(*layers)


class KeypointEncoder(nn.Module):
    """ Encoding of geometric properties using MLP """

    def __init__(self, keypoint_dim: int, feature_dim: int, layers: List[int],
                 dropout: bool = False, p: float = 0.1) -> None:
        super().__init__()
        self.encoder = MLP([keypoint_dim] + layers + [feature_dim])
        self.use_dropout = dropout
        self.dropout = nn.Dropout(p=p)

    def forward(self, kpts):
        if self.use_dropout:
            return self.dropout(self.encoder(kpts))
        return self.encoder(kpts)


class NormalEncoder(nn.Module):
    """ Encoding of normal vectors using MLP (no intermediate ReLU) """

    def __init__(self, normal_dim: int, feature_dim: int, layers: List[int],
                 dropout: bool = False, p: float = 0.1) -> None:
        super().__init__()
        self.encoder = MLP_no_ReLU([normal_dim] + layers + [feature_dim])
        self.use_dropout = dropout
        self.dropout = nn.Dropout(p=p)

    def forward(self, normals):
        if self.use_dropout:
            return self.dropout(self.encoder(normals))
        return self.encoder(normals)
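
# Note: with do_bn=False (the default used above), MLP_no_ReLU stacks Linear layers
# with no nonlinearity in between, so the normal encoding is mathematically a single
# affine map of the input normals; the extra layers only change its parameterization.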


class DescriptorEncoder(nn.Module):
    """ Encoding of visual descriptor using MLP """

    def __init__(self, feature_dim: int, layers: List[int], dropout: bool = False, p: float = 0.1) -> None:
        super().__init__()
        self.encoder = MLP([feature_dim] + layers + [feature_dim])
        self.use_dropout = dropout
        self.dropout = nn.Dropout(p=p)

    def forward(self, descs):
        residual = descs
        if self.use_dropout:
            return residual + self.dropout(self.encoder(descs))
        return residual + self.encoder(descs)
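
# Example (illustrative, assumed sizes): with feature_dim=128 and layers=[256, 256],
# DescriptorEncoder refines descriptors with a residual MLP,
#     descs_out = descs + MLP([128, 256, 256, 128])(descs)
# so the input and output descriptor dimensions always match.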


class AFTAttention(nn.Module):
    """ Attention-free attention """

    def __init__(self, d_model: int, dropout: bool = False, p: float = 0.1) -> None:
        super().__init__()
        self.dim = d_model
        self.query = nn.Linear(d_model, d_model)
        self.key = nn.Linear(d_model, d_model)
        self.value = nn.Linear(d_model, d_model)
        self.proj = nn.Linear(d_model, d_model)
        # self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
        self.use_dropout = dropout
        self.dropout = nn.Dropout(p=p)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Expects a 2D input of shape [num_keypoints, d_model]; the .T below
        # transposes the whole tensor.
        residual = x
        q = self.query(x)
        k = self.key(x)
        v = self.value(x)
        # q = torch.sigmoid(q)
        # Softmax over the keypoint dimension: transpose, softmax along the last
        # axis, transpose back.
        k = k.T
        k = torch.softmax(k, dim=-1)
        k = k.T
        # Global context vector shared by all keypoints, modulated by the query.
        kv = (k * v).sum(dim=-2, keepdim=True)
        x = q * kv
        x = self.proj(x)
        if self.use_dropout:
            x = self.dropout(x)
        x += residual
        # x = self.layer_norm(x)
        return x
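
# The block above appears to follow the "AFT-Simple" variant of the Attention Free
# Transformer (Zhai et al., 2021), Y_t = sigmoid(Q_t) * sum_t'(softmax(K)_t' * V_t'),
# computed in time linear in the number of keypoints; here the sigmoid gate on Q is
# commented out, so the shared context kv is modulated only by the raw query q.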


class PositionwiseFeedForward(nn.Module):
    def __init__(self, feature_dim: int, dropout: bool = False, p: float = 0.1) -> None:
        super().__init__()
        self.mlp = MLP([feature_dim, feature_dim * 2, feature_dim])
        # self.layer_norm = nn.LayerNorm(feature_dim, eps=1e-6)
        self.use_dropout = dropout
        self.dropout = nn.Dropout(p=p)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        x = self.mlp(x)
        if self.use_dropout:
            x = self.dropout(x)
        x += residual
        # x = self.layer_norm(x)
        return x


class AttentionalLayer(nn.Module):
    def __init__(self, feature_dim: int, dropout: bool = False, p: float = 0.1):
        super().__init__()
        self.attn = AFTAttention(feature_dim, dropout=dropout, p=p)
        self.ffn = PositionwiseFeedForward(feature_dim, dropout=dropout, p=p)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.attn(x)
        x = self.ffn(x)
        return x


class AttentionalNN(nn.Module):
    def __init__(self, feature_dim: int, layer_num: int, dropout: bool = False, p: float = 0.1) -> None:
        super().__init__()
        self.layers = nn.ModuleList([
            AttentionalLayer(feature_dim, dropout=dropout, p=p)
            for _ in range(layer_num)])

    def forward(self, desc: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            desc = layer(desc)
        return desc


class FeatureBooster(nn.Module):
    default_config = {
        'descriptor_dim': 128,
        'keypoint_encoder': [32, 64, 128],
        'Attentional_layers': 3,
        'last_activation': 'relu',
        'l2_normalization': True,
        'output_dim': 128
    }

    def __init__(self, config, dropout=False, p=0.1, use_kenc=True, use_normal=True, use_cross=True):
        super().__init__()
        self.config = {**self.default_config, **config}
        self.use_kenc = use_kenc
        self.use_cross = use_cross
        self.use_normal = use_normal

        if use_kenc:
            self.kenc = KeypointEncoder(self.config['keypoint_dim'], self.config['descriptor_dim'],
                                        self.config['keypoint_encoder'], dropout=dropout)

        if use_normal:
            self.nenc = NormalEncoder(self.config['normal_dim'], self.config['descriptor_dim'],
                                      self.config['normal_encoder'], dropout=dropout)

        if self.config.get('descriptor_encoder', False):
            self.denc = DescriptorEncoder(self.config['descriptor_dim'],
                                          self.config['descriptor_encoder'], dropout=dropout)
        else:
            self.denc = None

        if self.use_cross:
            self.attn_proj = AttentionalNN(feature_dim=self.config['descriptor_dim'],
                                           layer_num=self.config['Attentional_layers'], dropout=dropout)

        # self.final_proj = nn.Linear(self.config['descriptor_dim'], self.config['output_dim'])
        self.use_dropout = dropout
        self.dropout = nn.Dropout(p=p)
        # self.layer_norm = nn.LayerNorm(self.config['descriptor_dim'], eps=1e-6)

        if self.config.get('last_activation', False):
            if self.config['last_activation'].lower() == 'relu':
                self.last_activation = nn.ReLU()
            elif self.config['last_activation'].lower() == 'sigmoid':
                self.last_activation = nn.Sigmoid()
            elif self.config['last_activation'].lower() == 'tanh':
                self.last_activation = nn.Tanh()
            else:
                raise Exception('Unsupported activation "%s".' % self.config['last_activation'])
        else:
            self.last_activation = None

    def forward(self, desc, kpts, normals):
        ## Self boosting
        # Descriptor MLP encoder
        if self.denc is not None:
            desc = self.denc(desc)
        # Geometric MLP encoder
        if self.use_kenc:
            desc = desc + self.kenc(kpts)
            if self.use_dropout:
                desc = self.dropout(desc)
        # Normal-vector MLP encoder
        if self.use_normal:
            desc = desc + self.nenc(normals)
            if self.use_dropout:
                desc = self.dropout(desc)

        ## Cross boosting
        # Multi-layer Transformer network.
        if self.use_cross:
            # desc = self.attn_proj(self.layer_norm(desc))
            desc = self.attn_proj(desc)

        ## Post processing
        # Final MLP projection
        # desc = self.final_proj(desc)
        if self.last_activation is not None:
            desc = self.last_activation(desc)
        # L2 normalization
        if self.config['l2_normalization']:
            desc = F.normalize(desc, dim=-1)
        return desc
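
# Illustrative sketch of the config keys FeatureBooster reads (values are assumptions
# matching the test tensors below, not the actual t1_featureboost_config):
#     example_config = {
#         'descriptor_dim': 64,        # desc.shape[-1]
#         'keypoint_dim': 65,          # kpts.shape[-1]
#         'normal_dim': 3,             # normals.shape[-1]
#         'keypoint_encoder': [32, 64, 64],
#         'normal_encoder': [16, 32, 64],
#         'descriptor_encoder': [128, 128],
#         'Attentional_layers': 3,
#         'last_activation': 'relu',
#         'l2_normalization': True,
#         'output_dim': 64,
#     }
#     fb_net = FeatureBooster(example_config)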


if __name__ == "__main__":
    from config import t1_featureboost_config

    fb_net = FeatureBooster(t1_featureboost_config)
    descs = torch.randn([1900, 64])
    kpts = torch.randn([1900, 65])
    normals = torch.randn([1900, 3])
    descs_refine = fb_net(descs, kpts, normals)
    print(descs_refine.shape)