Realcat's picture
add: liftfeat
13760e8
"""
"LiftFeat: 3D Geometry-Aware Local Feature Matching"
"""
import numpy as np
import os
import torch
from torch import nn
import torch.nn.functional as F
import tqdm
import math
import cv2
import sys
sys.path.append('/home/yepeng_liu/code_python/laiwenpeng/LiftFeat')
from utils.featurebooster import FeatureBooster
from utils.config import featureboost_config
# from models.model_dfb import LiftFeatModel
# from models.interpolator import InterpolateSparse2d
# from third_party.config import featureboost_config
"""
Foundational helper functions: NMS, keypoint selection/refinement and descriptor sampling.
"""
def simple_nms(scores, radius):
    """Suppress non-maximal scores with a max-pooling based NMS.

    Points that tie with a neighbour for the local maximum are not
    suppressed against each other.

    Args:
        scores: score heatmap of shape ``(B, H, W)``.
        radius: integer NMS radius; the pooling window is ``2 * radius + 1``.

    Returns:
        A tensor shaped like ``scores`` in which every suppressed entry is 0.
    """
    window = 2 * radius + 1

    def _pool(t):
        return torch.nn.functional.max_pool2d(
            t, kernel_size=window, stride=1, padding=radius
        )

    zero_map = torch.zeros_like(scores)
    keep = scores == _pool(scores)
    for _ in range(2):
        # Blank out the neighbourhoods of the current maxima, then accept new
        # maxima that appear outside those neighbourhoods.
        covered = _pool(keep.float()) > 0
        remaining = torch.where(covered, zero_map, scores)
        fresh = remaining == _pool(remaining)
        keep = keep | (fresh & ~covered)
    return torch.where(keep, scores, zero_map)
def top_k_keypoints(keypoints, scores, k):
    """Keep the ``k`` highest-scoring keypoints.

    When ``k`` is at least the number of keypoints, the inputs are returned
    unchanged. Otherwise the selected keypoints and their scores are returned
    in descending score order.
    """
    if k < len(keypoints):
        scores, order = torch.topk(scores, k, dim=0, sorted=True)
        keypoints = keypoints[order]
    return keypoints, scores
def sample_k_keypoints(keypoints, scores, k):
    """Randomly draw ``k`` keypoints without replacement, weighted by score.

    ``scores`` serve as unnormalized sampling weights for
    ``torch.multinomial``. When ``k`` is at least the number of keypoints the
    inputs are returned unchanged.
    """
    if k < len(keypoints):
        picked = torch.multinomial(scores, k, replacement=False)
        keypoints, scores = keypoints[picked], scores[picked]
    return keypoints, scores
def soft_argmax_refinement(keypoints, scores, radius: int):
    """Refine integer keypoint locations with a local soft-argmax.

    Each keypoint is shifted by the score-weighted mean ``(dy, dx)`` offset
    inside a ``(2 * radius + 1)``-wide window centred on it.

    Args:
        keypoints: sequence with one integer tensor per batch item, each of
            shape ``(N, 2)`` and indexing ``scores[i]`` as ``(row, col)``.
        scores: score maps of shape ``(B, H, W)``.
        radius: half-width of the refinement window.

    Returns:
        List of float tensors holding the refined ``(row, col)`` locations.
    """
    window = 2 * radius + 1
    maps = scores[:, None]
    # avg-pool with divisor 1 is a sum-pool: total score in each window.
    window_sum = torch.nn.functional.avg_pool2d(
        maps, window, 1, radius, divisor_override=1
    )
    offsets = torch.arange(-radius, radius + 1).to(scores)
    weight_x = offsets[None].expand(window, -1)[None, None]
    weight_y = weight_x.transpose(2, 3)
    # First moments of the scores along x and y within each window.
    moment_x = torch.nn.functional.conv2d(maps, weight_x, padding=radius)
    moment_y = torch.nn.functional.conv2d(maps, weight_y, padding=radius)
    mean_offset = (
        torch.stack([moment_y[:, 0], moment_x[:, 0]], -1)
        / window_sum[:, 0, :, :, None]
    )
    return [
        kpts.float() + mean_offset[b][tuple(kpts.t())]
        for b, kpts in enumerate(keypoints)
    ]
# Legacy (broken) sampling of the descriptors
def sample_descriptors(keypoints, descriptors, s):
    """Bilinearly sample L2-normalized descriptors at keypoint positions.

    Legacy variant kept for backward compatibility. It uses an inconsistent
    pixel-to-grid mapping: pixel ``s/2 - 0.5`` maps to the first cell and
    ``w*s - s/2 - 0.5 + (s/2 - 0.5)`` to the last, with ``align_corners=True``.
    ``sample_descriptors_fix_sampling`` is the corrected version.

    Args:
        keypoints: pixel coordinates in ``(x, y)`` order, reshapeable to
            ``(B, N, 2)`` (presumably already ``(B, N, 2)``).
        descriptors: dense descriptor map of shape ``(B, C, H, W)``.
        s: stride between the image and the descriptor map.

    Returns:
        ``(B, C, N)`` tensor of unit-norm descriptors.
    """
    b, c, h, w = descriptors.shape
    shifted = keypoints - s / 2 + 0.5
    extent = torch.tensor(
        [(w * s - s / 2 - 0.5), (h * s - s / 2 - 0.5)],
    ).to(shifted)[None]
    grid = (shifted / extent) * 2 - 1  # [0, 1] -> [-1, 1]
    if torch.__version__ >= "1.3":
        sample_kwargs = {"align_corners": True}
    else:
        sample_kwargs = {}
    sampled = torch.nn.functional.grid_sample(
        descriptors, grid.view(b, 1, -1, 2), mode="bilinear", **sample_kwargs
    )
    return torch.nn.functional.normalize(sampled.reshape(b, c, -1), p=2, dim=1)
# The original keypoint sampling is incorrect. We patch it here but
# keep the original one above for legacy.
def sample_descriptors_fix_sampling(keypoints, descriptors, s: int = 8):
    """Interpolate descriptors at keypoint locations.

    Keypoint pixel coordinates ``(x, y)`` are divided by the full image size
    ``(w * s, h * s)`` and mapped to ``[-1, 1]`` for ``grid_sample`` with
    ``align_corners=False``.

    Args:
        keypoints: pixel coordinates in ``(x, y)`` order, reshapeable to
            ``(B, N, 2)``.
        descriptors: dense descriptor map of shape ``(B, C, H, W)``.
        s: stride between the image and the descriptor map.

    Returns:
        ``(B, C, N)`` tensor of L2-normalized descriptors.
    """
    b, c, h, w = descriptors.shape
    image_size = keypoints.new_tensor([w, h]) * s
    grid = (keypoints / image_size) * 2 - 1  # [0, 1] -> [-1, 1]
    sampled = torch.nn.functional.grid_sample(
        descriptors, grid.view(b, 1, -1, 2), mode="bilinear", align_corners=False
    )
    return torch.nn.functional.normalize(sampled.reshape(b, c, -1), p=2, dim=1)
class UpsampleLayer(nn.Module):
    """Double the spatial resolution and halve the channel count.

    Bilinear 2x upsampling is followed by a 3x3 conv
    (``in_channels -> in_channels // 2``), BatchNorm and LeakyReLU(0.1).
    """

    def __init__(self, in_channels):
        super().__init__()
        out_channels = in_channels // 2
        # 3x3 conv that reduces the channel count while extracting features.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        # Batch normalization of the conv output.
        self.bn = nn.BatchNorm2d(out_channels)
        # LeakyReLU activation (negative slope 0.1).
        self.leaky_relu = nn.LeakyReLU(0.1)

    def forward(self, x):
        upsampled = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=False)
        return self.leaky_relu(self.bn(self.conv(upsampled)))
class KeypointHead(nn.Module):
    """Keypoint detection head.

    Five ``BaseLayer`` blocks (in_channels -> 32 -> 32 -> 64 -> 64 -> 128)
    are followed by a 3x3 conv to ``out_channels`` and a BatchNorm. No final
    activation is applied, so the output is a raw logit map.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.layer1 = BaseLayer(in_channels, 32)
        self.layer2 = BaseLayer(32, 32)
        self.layer3 = BaseLayer(32, 64)
        self.layer4 = BaseLayer(64, 64)
        self.layer5 = BaseLayer(64, 128)
        self.conv = nn.Conv2d(128, out_channels, kernel_size=3, stride=1, padding=1)
        # Normalize the head's actual output channels. This was previously
        # hard-coded to 65, which broke any other ``out_channels`` value.
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.layer5(x)
        return self.bn(self.conv(x))
class DescriptorHead(nn.Module):
    """Dense descriptor head.

    Four ``BaseLayer`` blocks (in_channels -> 32 -> 32 -> 64 -> out_channels)
    are stacked; only the first applies a ReLU. The output is returned
    without normalization.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        blocks = [
            BaseLayer(in_channels, 32),
            BaseLayer(32, 32, activation=False),
            BaseLayer(32, 64, activation=False),
            BaseLayer(64, out_channels, activation=False),
        ]
        self.layer = nn.Sequential(*blocks)

    def forward(self, x):
        return self.layer(x)
class HeatmapHead(nn.Module):
    """Two conv-BN-LeakyReLU stages followed by a sigmoid.

    The output has ``out_channels`` channels with values in ``[0, 1]`` and the
    same spatial size as the input.
    """

    def __init__(self, in_channels, mid_channels, out_channels):
        super().__init__()
        self.convHa = nn.Conv2d(in_channels, mid_channels, kernel_size=3, stride=1, padding=1)
        self.bnHa = nn.BatchNorm2d(mid_channels)
        self.convHb = nn.Conv2d(mid_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.bnHb = nn.BatchNorm2d(out_channels)
        self.leaky_relu = nn.LeakyReLU(0.1)

    def forward(self, x):
        for conv, bn in ((self.convHa, self.bnHa), (self.convHb, self.bnHb)):
            x = self.leaky_relu(bn(conv(x)))
        return torch.sigmoid(x)
class DepthHead(nn.Module):
    """Upsample features 8x into a 3-channel, per-pixel unit-vector map.

    There are three stages, each doubling the resolution. Every stage
    concatenates a plain bilinear upsample of its input with a learned
    ``UpsampleLayer`` branch. It then fuses the two with conv-BN-LeakyReLU.
    Channels go ``in -> in/2 -> in/4 -> 3``. The final map is L2-normalized
    over channels.
    """

    def __init__(self, in_channels):
        super().__init__()
        half, quarter, eighth = in_channels // 2, in_channels // 4, in_channels // 8
        self.upsampleDa = UpsampleLayer(in_channels)
        self.upsampleDb = UpsampleLayer(half)
        self.upsampleDc = UpsampleLayer(quarter)
        self.convDepa = nn.Conv2d(half + in_channels, half, kernel_size=3, stride=1, padding=1)
        self.bnDepa = nn.BatchNorm2d(half)
        self.convDepb = nn.Conv2d(quarter + half, quarter, kernel_size=3, stride=1, padding=1)
        self.bnDepb = nn.BatchNorm2d(quarter)
        self.convDepc = nn.Conv2d(eighth + quarter, 3, kernel_size=3, stride=1, padding=1)
        self.bnDepc = nn.BatchNorm2d(3)
        self.leaky_relu = nn.LeakyReLU(0.1)

    def forward(self, x):
        stages = (
            (self.upsampleDa, self.convDepa, self.bnDepa),
            (self.upsampleDb, self.convDepb, self.bnDepb),
            (self.upsampleDc, self.convDepc, self.bnDepc),
        )
        for upsample, conv, bn in stages:
            skip = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
            x = torch.cat([skip, upsample(x)], dim=1)
            x = self.leaky_relu(bn(conv(x)))
        return F.normalize(x, p=2, dim=1)
class BaseLayer(nn.Module):
    """Conv2d plus non-affine BatchNorm2d, optionally followed by ReLU.

    The modules live in ``self.layer`` (an ``nn.Sequential``): index 0 is the
    conv, index 1 the BatchNorm, and index 2 the in-place ReLU when
    ``activation`` is true.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False, activation=True):
        super().__init__()
        modules = [
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias),
            nn.BatchNorm2d(out_channels, affine=False),
        ]
        if activation:
            modules.append(nn.ReLU(inplace=True))
        self.layer = nn.Sequential(*modules)

    def forward(self, x):
        return self.layer(x)
class LiftFeatSPModel(nn.Module):
    """LiftFeat network.

    A grayscale conv backbone with a two-level top-down fusion feeds three
    dense heads: keypoint logits, descriptors, and a 3-channel depth/normal
    map. ``forward2`` then refines the dense descriptors with a
    ``FeatureBooster`` that also consumes the keypoint logits and normals.
    """

    # NOTE(review): ``default_conf`` is not read anywhere in this class. It
    # looks carried over from a SuperPoint-style extractor and is kept for
    # callers that may inspect it.
    default_conf = {
        "has_detector": True,
        "has_descriptor": True,
        "descriptor_dim": 64,
        # Inference
        "sparse_outputs": True,
        "dense_outputs": False,
        "nms_radius": 4,
        "refinement_radius": 0,
        "detection_threshold": 0.005,
        "max_num_keypoints": -1,
        "max_num_keypoints_val": None,
        "force_num_keypoints": False,
        "randomize_keypoints_training": False,
        "remove_borders": 4,
        "legacy_sampling": True,  # True to use the old broken sampling
    }

    def __init__(self, featureboost_config, use_kenc=False, use_normal=True, use_cross=True):
        """Build all sub-modules.

        Args:
            featureboost_config: configuration passed to ``FeatureBooster``.
            use_kenc, use_normal, use_cross: flags passed unchanged to
                ``FeatureBooster``.
        """
        super().__init__()
        # NOTE(review): not used by the methods of this class.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.descriptor_dim = 64
        self.norm = nn.InstanceNorm2d(1)
        self.relu = nn.ReLU(inplace=True)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        c1, c2, c3, c4, c5 = 24, 24, 64, 64, 128
        # Backbone: five conv pairs, each followed by 2x max-pooling.
        self.conv1a = nn.Conv2d(1, c1, kernel_size=3, stride=1, padding=1)
        self.conv1b = nn.Conv2d(c1, c1, kernel_size=3, stride=1, padding=1)
        self.conv2a = nn.Conv2d(c1, c2, kernel_size=3, stride=1, padding=1)
        self.conv2b = nn.Conv2d(c2, c2, kernel_size=3, stride=1, padding=1)
        self.conv3a = nn.Conv2d(c2, c3, kernel_size=3, stride=1, padding=1)
        self.conv3b = nn.Conv2d(c3, c3, kernel_size=3, stride=1, padding=1)
        self.conv4a = nn.Conv2d(c3, c4, kernel_size=3, stride=1, padding=1)
        self.conv4b = nn.Conv2d(c4, c4, kernel_size=3, stride=1, padding=1)
        self.conv5a = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
        self.conv5b = nn.Conv2d(c5, c5, kernel_size=3, stride=1, padding=1)
        # Top-down fusion: upsample the coarser map (halving its channels)
        # and merge it with the next finer one.
        self.upsample4 = UpsampleLayer(c4)
        self.upsample5 = UpsampleLayer(c5)
        self.conv_fusion45 = nn.Conv2d(c5 // 2 + c4, c4, kernel_size=3, stride=1, padding=1)
        self.conv_fusion34 = nn.Conv2d(c4 // 2 + c3, c3, kernel_size=3, stride=1, padding=1)
        # detector; 65 channels are presumably 8x8 cell positions + 1 dustbin
        # (SuperPoint convention) — TODO confirm.
        self.keypoint_head = KeypointHead(in_channels=c3, out_channels=65)
        # descriptor
        self.descriptor_head = DescriptorHead(in_channels=c3, out_channels=self.descriptor_dim)
        # NOTE: a HeatmapHead (reliability map) used to be attached here; it
        # is currently disabled.
        # depth / normal map
        self.depth_head = DepthHead(c3)
        # NOTE(review): ``fine_matcher`` is not used by any forward method of
        # this class; it is kept so existing checkpoints still load.
        self.fine_matcher = nn.Sequential(
            nn.Linear(128, 512),
            nn.BatchNorm1d(512, affine=False),
            nn.ReLU(inplace=True),
            nn.Linear(512, 512),
            nn.BatchNorm1d(512, affine=False),
            nn.ReLU(inplace=True),
            nn.Linear(512, 512),
            nn.BatchNorm1d(512, affine=False),
            nn.ReLU(inplace=True),
            nn.Linear(512, 512),
            nn.BatchNorm1d(512, affine=False),
            nn.ReLU(inplace=True),
            nn.Linear(512, 64),
        )
        # feature_booster
        self.feature_boost = FeatureBooster(featureboost_config, use_kenc=use_kenc, use_cross=use_cross, use_normal=use_normal)

    def feature_extract(self, x):
        """Run the backbone and return the 1/8, 1/16 and 1/32 scale maps.

        Returns:
            ``(x3, x4, x5)`` with ``c3``, ``c4`` and ``c5`` channels
            respectively.
        """
        x1 = self.relu(self.conv1a(x))
        x1 = self.relu(self.conv1b(x1))
        x1 = self.pool(x1)
        x2 = self.relu(self.conv2a(x1))
        x2 = self.relu(self.conv2b(x2))
        x2 = self.pool(x2)
        x3 = self.relu(self.conv3a(x2))
        x3 = self.relu(self.conv3b(x3))
        x3 = self.pool(x3)
        x4 = self.relu(self.conv4a(x3))
        x4 = self.relu(self.conv4b(x4))
        x4 = self.pool(x4)
        x5 = self.relu(self.conv5a(x4))
        x5 = self.relu(self.conv5b(x5))
        x5 = self.pool(x5)
        return x3, x4, x5

    def fuse_multi_features(self, x3, x4, x5):
        """Fuse the three backbone scales into one ``c3``-channel 1/8 map."""
        # upsample x5 to the x4 scale and fuse
        x5 = self.upsample5(x5)
        x4 = torch.cat([x4, x5], dim=1)
        x4 = self.conv_fusion45(x4)
        # upsample the fused x4 to the x3 scale and fuse
        x4 = self.upsample4(x4)
        x3 = torch.cat([x3, x4], dim=1)
        x = self.conv_fusion34(x3)
        return x

    def _unfold2d(self, x, ws=2):
        """Unfold ``x`` into non-overlapping ``ws x ws`` windows along channels.

        ``(B, C, H, W) -> (B, C * ws**2, H // ws, W // ws)``. Output channel
        ``c * ws**2 + i * ws + j`` holds input channel ``c`` at row ``i`` and
        column ``j`` of each window.
        """
        B, C, H, W = x.shape
        x = x.unfold(2, ws, ws).unfold(3, ws, ws).reshape(B, C, H // ws, W // ws, ws**2)
        return x.permute(0, 1, 4, 2, 3).reshape(B, -1, H // ws, W // ws)

    def forward1(self, x):
        """Run the backbone and the three dense heads.

        Args:
            x: ``(B, C, H, W)`` grayscale or RGB image batch; channels are
                averaged to one. H and W should be multiples of 32 so the
                skip concatenations in ``fuse_multi_features`` line up.

        Returns:
            des_map: ``(B, 64, H/8, W/8)`` dense descriptor map.
            keypoint_map: ``(B, 65, H/8, W/8)`` keypoint logit map.
            d_feats: ``(B, 3, H, W)`` per-pixel 3-vector map, L2-normalized
                over channels (consumed as ``normals`` by ``forward2``).
        """
        # Grayscale conversion and instance normalization need no gradients.
        with torch.no_grad():
            x = x.mean(dim=1, keepdim=True)
            x = self.norm(x)

        x3, x4, x5 = self.feature_extract(x)
        # features fusion
        x = self.fuse_multi_features(x3, x4, x5)
        # keypoint
        keypoint_map = self.keypoint_head(x)
        # descriptor
        des_map = self.descriptor_head(x)
        # depth / normal map
        d_feats = self.depth_head(x)
        return des_map, keypoint_map, d_feats

    def forward2(self, descs, kpts, normals):
        """Refine the dense descriptors with the feature booster.

        Each 8x8 patch of ``normals`` is unfolded into one vector so it aligns
        with one cell of the 1/8-resolution descriptor and keypoint maps. All
        three maps are then flattened to one row per cell and passed to
        ``self.feature_boost``.

        NOTE(review): ``squeeze(0)`` followed by ``permute(1, 2, 0)`` only
        works for a batch size of 1.

        Args:
            descs: ``(1, D, H/8, W/8)`` descriptor map from ``forward1``.
            kpts: ``(1, 65, H/8, W/8)`` keypoint logit map from ``forward1``.
            normals: ``(1, 3, H, W)`` normal map from ``forward1``.

        Returns:
            The output of ``FeatureBooster`` for the ``H/8 * W/8`` rows,
            presumably refined descriptors — TODO confirm shape.
        """
        normals_feat = self._unfold2d(normals, ws=8)
        normals_v = normals_feat.squeeze(0).permute(1, 2, 0).reshape(-1, normals_feat.shape[1])
        descs_v = descs.squeeze(0).permute(1, 2, 0).reshape(-1, descs.shape[1])
        kpts_v = kpts.squeeze(0).permute(1, 2, 0).reshape(-1, kpts.shape[1])
        descs_refine = self.feature_boost(descs_v, kpts_v, normals_v)
        return descs_refine

    def forward(self, x):
        """Run ``forward1`` then ``forward2``.

        Returns:
            ``(descs_refine, des_map, keypoint_map, d_feats)``.
        """
        M1, K1, D1 = self.forward1(x)
        descs_refine = self.forward2(M1, K1, D1)
        return descs_refine, M1, K1, D1
if __name__ == "__main__":
    # Smoke test: run the (untrained) network on the bundled reference image
    # and print the descriptor-map shape.
    img_path = os.path.join(os.path.dirname(__file__), '../assert/ref.jpg')
    img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        # cv2.imread returns None instead of raising when the file is missing
        # or unreadable; fail here with a clear message.
        raise FileNotFoundError(f"Could not read image: {img_path}")
    # 800 x 608 (W x H): both are multiples of 32, as the backbone expects.
    img = cv2.resize(img, (800, 608))
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    img = torch.from_numpy(img).unsqueeze(0).unsqueeze(0).float() / 255.0
    img = img.to(device)
    liftfeat_sp = LiftFeatSPModel(featureboost_config).to(device)
    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        des_map, keypoint_map, d_feats = liftfeat_sp.forward1(img)
        des_fine = liftfeat_sp.forward2(des_map, keypoint_map, d_feats)
    print(des_map.shape)