from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F


def MLP(channels: List[int], do_bn: bool = False) -> nn.Module:
    """ Multi-layer perceptron: Linear layers with optional BatchNorm and ReLU
    after every layer except the last. """
    n = len(channels)
    layers = []
    for i in range(1, n):
        layers.append(nn.Linear(channels[i - 1], channels[i]))
        if i < (n - 1):
            if do_bn:
                layers.append(nn.BatchNorm1d(channels[i]))
            layers.append(nn.ReLU())
    return nn.Sequential(*layers)


def MLP_no_ReLU(channels: List[int], do_bn: bool = False) -> nn.Module:
    """ Multi-layer perceptron without ReLU: Linear layers with optional
    BatchNorm between them. """
    n = len(channels)
    layers = []
    for i in range(1, n):
        layers.append(nn.Linear(channels[i - 1], channels[i]))
        if i < (n - 1):
            if do_bn:
                layers.append(nn.BatchNorm1d(channels[i]))
    return nn.Sequential(*layers)
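
# Example (illustrative only, not part of the original module):
#     MLP([3, 32, 64], do_bn=True)
# builds
#     nn.Sequential(Linear(3, 32), BatchNorm1d(32), ReLU(), Linear(32, 64)),
# i.e. BatchNorm/ReLU follow every layer except the last, while
#     MLP_no_ReLU([3, 32, 64])
# reduces to Linear(3, 32) -> Linear(32, 64).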


class KeypointEncoder(nn.Module):
    """ Encoding of geometric properties using MLP """
    def __init__(self, keypoint_dim: int, feature_dim: int, layers: List[int],
                 dropout: bool = False, p: float = 0.1) -> None:
        super().__init__()
        self.encoder = MLP([keypoint_dim] + layers + [feature_dim])
        self.use_dropout = dropout
        self.dropout = nn.Dropout(p=p)

    def forward(self, kpts):
        if self.use_dropout:
            return self.dropout(self.encoder(kpts))
        return self.encoder(kpts)


class NormalEncoder(nn.Module):
    """ Encoding of surface normals using an MLP without ReLU """
    def __init__(self, normal_dim: int, feature_dim: int, layers: List[int],
                 dropout: bool = False, p: float = 0.1) -> None:
        super().__init__()
        self.encoder = MLP_no_ReLU([normal_dim] + layers + [feature_dim])
        self.use_dropout = dropout
        self.dropout = nn.Dropout(p=p)

    def forward(self, normals):
        if self.use_dropout:
            return self.dropout(self.encoder(normals))
        return self.encoder(normals)


class DescriptorEncoder(nn.Module):
    """ Encoding of visual descriptors using MLP """
    def __init__(self, feature_dim: int, layers: List[int], dropout: bool = False, p: float = 0.1) -> None:
        super().__init__()
        self.encoder = MLP([feature_dim] + layers + [feature_dim])
        self.use_dropout = dropout
        self.dropout = nn.Dropout(p=p)

    def forward(self, descs):
        # Residual connection around the descriptor MLP.
        residual = descs
        if self.use_dropout:
            return residual + self.dropout(self.encoder(descs))
        return residual + self.encoder(descs)


class AFTAttention(nn.Module):
    """ Attention-free attention """
    def __init__(self, d_model: int, dropout: bool = False, p: float = 0.1) -> None:
        super().__init__()
        self.dim = d_model
        self.query = nn.Linear(d_model, d_model)
        self.key = nn.Linear(d_model, d_model)
        self.value = nn.Linear(d_model, d_model)
        self.proj = nn.Linear(d_model, d_model)
        # self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
        self.use_dropout = dropout
        self.dropout = nn.Dropout(p=p)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        q = self.query(x)
        k = self.key(x)
        v = self.value(x)
        # q = torch.sigmoid(q)
        # Softmax over the keypoint (sequence) dimension; equivalent to the original
        # transpose -> softmax(dim=-1) -> transpose for 2D input, and also valid for
        # batched input.
        k = torch.softmax(k, dim=-2)
        # Global context: weighted sum of the values over all keypoints.
        kv = (k * v).sum(dim=-2, keepdim=True)
        x = q * kv
        x = self.proj(x)
        if self.use_dropout:
            x = self.dropout(x)
        x = x + residual
        # x = self.layer_norm(x)
        return x
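
# For reference, the layer above follows the "AFT-simple" form of the Attention
# Free Transformer: out_t = q_t * sum_s softmax(K)_s * v_s, with the softmax taken
# over the sequence (keypoint) dimension. In the original AFT formulation q is
# additionally passed through a sigmoid; that step is left commented out above.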


class PositionwiseFeedForward(nn.Module):
    def __init__(self, feature_dim: int, dropout: bool = False, p: float = 0.1) -> None:
        super().__init__()
        self.mlp = MLP([feature_dim, feature_dim * 2, feature_dim])
        # self.layer_norm = nn.LayerNorm(feature_dim, eps=1e-6)
        self.use_dropout = dropout
        self.dropout = nn.Dropout(p=p)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        x = self.mlp(x)
        if self.use_dropout:
            x = self.dropout(x)
        x = x + residual
        # x = self.layer_norm(x)
        return x


class AttentionalLayer(nn.Module):
    def __init__(self, feature_dim: int, dropout: bool = False, p: float = 0.1):
        super().__init__()
        self.attn = AFTAttention(feature_dim, dropout=dropout, p=p)
        self.ffn = PositionwiseFeedForward(feature_dim, dropout=dropout, p=p)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.attn(x)
        x = self.ffn(x)
        return x


class AttentionalNN(nn.Module):
    def __init__(self, feature_dim: int, layer_num: int, dropout: bool = False, p: float = 0.1) -> None:
        super().__init__()
        self.layers = nn.ModuleList([
            AttentionalLayer(feature_dim, dropout=dropout, p=p)
            for _ in range(layer_num)])

    def forward(self, desc: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            desc = layer(desc)
        return desc


class FeatureBooster(nn.Module):
    default_config = {
        'descriptor_dim': 128,
        'keypoint_encoder': [32, 64, 128],
        'Attentional_layers': 3,
        'last_activation': 'relu',
        'l2_normalization': True,
        'output_dim': 128
    }
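
    # Besides the defaults above, __init__ also reads 'keypoint_dim', 'normal_dim',
    # 'normal_encoder' and (optionally) 'descriptor_encoder' from the user-supplied
    # config, so the caller's config dict must provide those keys.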

    def __init__(self, config, dropout=False, p=0.1, use_kenc=True, use_normal=True, use_cross=True):
        super().__init__()
        self.config = {**self.default_config, **config}
        self.use_kenc = use_kenc
        self.use_cross = use_cross
        self.use_normal = use_normal

        if use_kenc:
            self.kenc = KeypointEncoder(self.config['keypoint_dim'], self.config['descriptor_dim'],
                                        self.config['keypoint_encoder'], dropout=dropout)

        if use_normal:
            self.nenc = NormalEncoder(self.config['normal_dim'], self.config['descriptor_dim'],
                                      self.config['normal_encoder'], dropout=dropout)

        if self.config.get('descriptor_encoder', False):
            self.denc = DescriptorEncoder(self.config['descriptor_dim'],
                                          self.config['descriptor_encoder'], dropout=dropout)
        else:
            self.denc = None

        if self.use_cross:
            self.attn_proj = AttentionalNN(feature_dim=self.config['descriptor_dim'],
                                           layer_num=self.config['Attentional_layers'], dropout=dropout)

        # self.final_proj = nn.Linear(self.config['descriptor_dim'], self.config['output_dim'])

        self.use_dropout = dropout
        self.dropout = nn.Dropout(p=p)
        # self.layer_norm = nn.LayerNorm(self.config['descriptor_dim'], eps=1e-6)

        if self.config.get('last_activation', False):
            if self.config['last_activation'].lower() == 'relu':
                self.last_activation = nn.ReLU()
            elif self.config['last_activation'].lower() == 'sigmoid':
                self.last_activation = nn.Sigmoid()
            elif self.config['last_activation'].lower() == 'tanh':
                self.last_activation = nn.Tanh()
            else:
                raise ValueError('Unsupported activation "%s".' % self.config['last_activation'])
        else:
            self.last_activation = None

    def forward(self, desc, kpts, normals):
        ## Self boosting
        # Descriptor MLP encoder
        if self.denc is not None:
            desc = self.denc(desc)
        # Geometric (keypoint) MLP encoder
        if self.use_kenc:
            desc = desc + self.kenc(kpts)
            if self.use_dropout:
                desc = self.dropout(desc)
        # Surface-normal feature encoder
        if self.use_normal:
            desc = desc + self.nenc(normals)
            if self.use_dropout:
                desc = self.dropout(desc)

        ## Cross boosting
        # Multi-layer Transformer network.
        if self.use_cross:
            # desc = self.attn_proj(self.layer_norm(desc))
            desc = self.attn_proj(desc)

        ## Post processing
        # Final MLP projection
        # desc = self.final_proj(desc)
        if self.last_activation is not None:
            desc = self.last_activation(desc)
        # L2 normalization
        if self.config['l2_normalization']:
            desc = F.normalize(desc, dim=-1)

        return desc


if __name__ == "__main__":
    from config import t1_featureboost_config

    fb_net = FeatureBooster(t1_featureboost_config)
    descs = torch.randn([1900, 64])
    kpts = torch.randn([1900, 65])
    normals = torch.randn([1900, 3])
    descs_refine = fb_net(descs, kpts, normals)
    print(descs_refine.shape)
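
# A minimal self-contained smoke test for reference (illustrative; the config values
# below are assumptions and are not taken from config.t1_featureboost_config):
#
#     example_config = {
#         'descriptor_dim': 64,
#         'keypoint_dim': 65,
#         'normal_dim': 3,
#         'keypoint_encoder': [32, 64, 64],
#         'normal_encoder': [8, 16, 32],
#         'descriptor_encoder': [64, 64],
#         'Attentional_layers': 3,
#         'last_activation': 'relu',
#         'l2_normalization': True,
#         'output_dim': 64,
#     }
#     net = FeatureBooster(example_config)
#     boosted = net(torch.randn(1900, 64), torch.randn(1900, 65), torch.randn(1900, 3))
#     assert boosted.shape == (1900, 64)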