import torch
import torch.nn as nn

eps = 1e-8
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def sinkhorn(M, r, c, iteration):
    # Sinkhorn normalization: starting from a row-softmax of the scores,
    # alternately rescale rows and columns so the marginals approach r and c.
    p = torch.softmax(M, dim=-1)
    u = torch.ones_like(r)
    v = torch.ones_like(c)
    for _ in range(iteration):
        u = r / ((p * v.unsqueeze(-2)).sum(-1) + eps)
        v = c / ((p * u.unsqueeze(-1)).sum(-2) + eps)
    p = p * u.unsqueeze(-1) * v.unsqueeze(-2)
    return p


def sink_algorithm(M, dustbin, iteration):
    # Append a learnable dustbin row and column for unmatched points, then
    # run Sinkhorn with marginals that allow mass to flow into the dustbin.
    M = torch.cat([M, dustbin.expand([M.shape[0], M.shape[1], 1])], dim=-1)
    M = torch.cat([M, dustbin.expand([M.shape[0], 1, M.shape[2]])], dim=-2)
    r = torch.ones([M.shape[0], M.shape[1] - 1], device=device)
    r = torch.cat([r, torch.ones([M.shape[0], 1], device=device) * M.shape[1]], dim=-1)
    c = torch.ones([M.shape[0], M.shape[2] - 1], device=device)
    c = torch.cat([c, torch.ones([M.shape[0], 1], device=device) * M.shape[2]], dim=-1)
    p = sinkhorn(M, r, c, iteration)
    return p
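

# --- Illustration only (not part of the original file) ---
# A minimal sketch of driving the Sinkhorn layer above on a random square
# score matrix; `_demo_sink_algorithm` and all shapes/values here are demo
# assumptions, not the original training or inference setup.
def _demo_sink_algorithm():
    M = torch.rand(1, 5, 5, device=device)        # b*n*m similarity scores
    dustbin = torch.tensor(1.5, device=device)    # score assigned to unmatched points
    p = sink_algorithm(M, dustbin, iteration=10)  # b*(n+1)*(m+1) assignment matrix
    # For real (non-dustbin) rows the marginals should be close to 1.
    print(p.shape, p[:, :-1, :].sum(-1))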


def seeding(
    nn_index1,
    nn_index2,
    x1,
    x2,
    topk,
    match_score,
    confbar,
    nms_radius,
    use_mc=True,
    test=False,
):
    # Select seed correspondences: mutual-nearest-neighbour check, spatial
    # NMS in both images, a confidence bar, then top-k selection.
    # apply mutual check before nms
    if use_mc:
        mask_not_mutual = nn_index2.gather(dim=-1, index=nn_index1) != torch.arange(
            nn_index1.shape[1], device=device
        )
        match_score[mask_not_mutual] = -1
    # NMS: pairwise position distances in image 1 and (reordered) image 2
    pos_dismat1 = (
        (
            (x1.norm(p=2, dim=-1) ** 2).unsqueeze_(-1)
            + (x1.norm(p=2, dim=-1) ** 2).unsqueeze_(-2)
            - 2 * (x1 @ x1.transpose(1, 2))
        )
        .abs_()
        .sqrt_()
    )
    x2 = x2.gather(index=nn_index1.unsqueeze(-1).expand(-1, -1, 2), dim=1)
    pos_dismat2 = (
        (
            (x2.norm(p=2, dim=-1) ** 2).unsqueeze_(-1)
            + (x2.norm(p=2, dim=-1) ** 2).unsqueeze_(-2)
            - 2 * (x2 @ x2.transpose(1, 2))
        )
        .abs_()
        .sqrt_()
    )
    radius1, radius2 = nms_radius * pos_dismat1.mean(
        dim=(1, 2), keepdim=True
    ), nms_radius * pos_dismat2.mean(dim=(1, 2), keepdim=True)
    nms_mask = (pos_dismat1 >= radius1) & (pos_dismat2 >= radius2)
    # suppress matches that are not a local score maximum within the radius
    mask_not_local_max = (
        match_score.unsqueeze(-1) >= match_score.unsqueeze(-2)
    ) | nms_mask
    mask_not_local_max = ~(mask_not_local_max.min(dim=-1).values)
    match_score[mask_not_local_max] = -1
    # confidence bar
    match_score[match_score < confbar] = -1
    mask_survive = match_score > 0
    if test:
        topk = min(mask_survive.sum(dim=1)[0] + 2, topk)
    _, topindex = torch.topk(match_score, topk, dim=-1)  # b*k
    seed_index1, seed_index2 = topindex, nn_index1.gather(index=topindex, dim=-1)
    return seed_index1, seed_index2
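

# --- Illustration only (not part of the original file) ---
# A hedged sketch of how seeding() might be called: build nearest-neighbour
# indices and an inverse ratio score from random normalized descriptors,
# then pick seeds. `_demo_seeding` and every shape/threshold below are
# assumptions made for the demo.
def _demo_seeding():
    b, n, m = 1, 128, 128
    x1 = torch.rand(b, n, 2, device=device)  # keypoint positions, image 1
    x2 = torch.rand(b, m, 2, device=device)  # keypoint positions, image 2
    d1 = torch.nn.functional.normalize(torch.rand(b, n, 64, device=device), dim=-1)
    d2 = torch.nn.functional.normalize(torch.rand(b, m, 64, device=device), dim=-1)
    dismat = (2 - 2 * torch.matmul(d1, d2.transpose(1, 2))).sqrt_()
    values, nn_index = torch.topk(dismat, k=2, largest=False, dim=-1, sorted=True)
    nn_index2 = torch.min(dismat, dim=1).indices  # b*m: NN in image 1 per image-2 point
    score, nn_index1 = values[:, :, 1] / values[:, :, 0], nn_index[:, :, 0]
    return seeding(nn_index1, nn_index2, x1, x2, topk=32, match_score=score,
                   confbar=1.1, nms_radius=0.1)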


class PointCN(nn.Module):
    # Pointwise residual block: two 1x1 convolutions with instance/batch
    # normalization and a 1x1-conv shortcut for the channel change.
    def __init__(self, channels, out_channels):
        nn.Module.__init__(self)
        self.shot_cut = nn.Conv1d(channels, out_channels, kernel_size=1)
        self.conv = nn.Sequential(
            nn.InstanceNorm1d(channels, eps=1e-3),
            nn.SyncBatchNorm(channels),
            nn.ReLU(),
            nn.Conv1d(channels, channels, kernel_size=1),
            nn.InstanceNorm1d(channels, eps=1e-3),
            nn.SyncBatchNorm(channels),
            nn.ReLU(),
            nn.Conv1d(channels, out_channels, kernel_size=1),
        )

    def forward(self, x):
        return self.conv(x) + self.shot_cut(x)
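

# --- Illustration only (not part of the original file) ---
# A quick shape check for PointCN; eval() makes SyncBatchNorm use running
# statistics instead of a distributed sync. `_demo_pointcn` is a
# hypothetical helper, and the shapes are demo assumptions.
def _demo_pointcn():
    net = PointCN(channels=128, out_channels=128).to(device).eval()
    x = torch.rand(1, 128, 500, device=device)  # b*c*n features
    return net(x)                               # residual output: b*128*n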


class attention_propagantion(nn.Module):
    # Multi-head attention propagation: desc1 (queries) attends to desc2
    # (keys/values), and the attended message is fused back into desc1.
    def __init__(self, channel, head):
        nn.Module.__init__(self)
        self.head = head
        self.head_dim = channel // head
        self.query_filter, self.key_filter, self.value_filter = (
            nn.Conv1d(channel, channel, kernel_size=1),
            nn.Conv1d(channel, channel, kernel_size=1),
            nn.Conv1d(channel, channel, kernel_size=1),
        )
        self.mh_filter = nn.Conv1d(channel, channel, kernel_size=1)
        self.cat_filter = nn.Sequential(
            nn.Conv1d(2 * channel, 2 * channel, kernel_size=1),
            nn.SyncBatchNorm(2 * channel),
            nn.ReLU(),
            nn.Conv1d(2 * channel, channel, kernel_size=1),
        )

    def forward(self, desc1, desc2, weight_v=None):
        # desc1(q) attend to desc2(k,v)
        batch_size = desc1.shape[0]
        query, key, value = (
            self.query_filter(desc1).view(batch_size, self.head, self.head_dim, -1),
            self.key_filter(desc2).view(batch_size, self.head, self.head_dim, -1),
            self.value_filter(desc2).view(batch_size, self.head, self.head_dim, -1),
        )
        if weight_v is not None:
            # optional per-element confidence weighting of the values
            value = value * weight_v.view(batch_size, 1, 1, -1)
        score = torch.softmax(
            torch.einsum("bhdn,bhdm->bhnm", query, key) / self.head_dim**0.5, dim=-1
        )
        add_value = torch.einsum("bhnm,bhdm->bhdn", score, value).reshape(
            batch_size, self.head_dim * self.head, -1
        )
        add_value = self.mh_filter(add_value)
        desc1_new = desc1 + self.cat_filter(torch.cat([desc1, add_value], dim=1))
        return desc1_new
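

# --- Illustration only (not part of the original file) ---
# A minimal sketch of one cross-attention call: desc1 queries attend to
# desc2. `_demo_attention` and all shapes here are demo assumptions.
def _demo_attention():
    att = attention_propagantion(channel=128, head=4).to(device).eval()
    desc1 = torch.rand(1, 128, 100, device=device)  # b*c*n
    desc2 = torch.rand(1, 128, 120, device=device)  # b*c*m
    return att(desc1, desc2)                        # updated desc1: b*128*100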


class hybrid_block(nn.Module):
    # Seeded attention block: pool all points into the seed clusters,
    # exchange information between the two clusters, then broadcast the
    # result back to all points, weighted by a per-seed confidence.
    def __init__(self, channel, head):
        nn.Module.__init__(self)
        self.head = head
        self.channel = channel
        self.attention_block_down = attention_propagantion(channel, head)
        self.cluster_filter = nn.Sequential(
            nn.Conv1d(2 * channel, 2 * channel, kernel_size=1),
            nn.SyncBatchNorm(2 * channel),
            nn.ReLU(),
            nn.Conv1d(2 * channel, 2 * channel, kernel_size=1),
        )
        self.cross_filter = attention_propagantion(channel, head)
        self.confidence_filter = PointCN(2 * channel, 1)
        self.attention_block_self = attention_propagantion(channel, head)
        self.attention_block_up = attention_propagantion(channel, head)

    def forward(self, desc1, desc2, seed_index1, seed_index2):
        cluster1, cluster2 = desc1.gather(
            dim=-1, index=seed_index1.unsqueeze(1).expand(-1, self.channel, -1)
        ), desc2.gather(
            dim=-1, index=seed_index2.unsqueeze(1).expand(-1, self.channel, -1)
        )
        # pooling
        cluster1, cluster2 = self.attention_block_down(
            cluster1, desc1
        ), self.attention_block_down(cluster2, desc2)
        concate_cluster = self.cluster_filter(torch.cat([cluster1, cluster2], dim=1))
        # filtering
        cluster1, cluster2 = self.cross_filter(
            concate_cluster[:, : self.channel], concate_cluster[:, self.channel :]
        ), self.cross_filter(
            concate_cluster[:, self.channel :], concate_cluster[:, : self.channel]
        )
        cluster1, cluster2 = self.attention_block_self(
            cluster1, cluster1
        ), self.attention_block_self(cluster2, cluster2)
        # unpooling
        seed_weight = self.confidence_filter(torch.cat([cluster1, cluster2], dim=1))
        seed_weight = torch.sigmoid(seed_weight).squeeze(1)
        desc1_new, desc2_new = self.attention_block_up(
            desc1, cluster1, seed_weight
        ), self.attention_block_up(desc2, cluster2, seed_weight)
        return desc1_new, desc2_new, seed_weight
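

# --- Illustration only (not part of the original file) ---
# A sketch of one hybrid-block round-trip with random seed indices; the
# shapes and seed count are assumptions, and `_demo_hybrid_block` is a
# hypothetical helper.
def _demo_hybrid_block():
    block = hybrid_block(channel=128, head=4).to(device).eval()
    desc1 = torch.rand(1, 128, 100, device=device)
    desc2 = torch.rand(1, 128, 120, device=device)
    seed1 = torch.randint(0, 100, (1, 32), device=device)  # b*k seeds in image 1
    seed2 = torch.randint(0, 120, (1, 32), device=device)  # b*k seeds in image 2
    d1, d2, w = block(desc1, desc2, seed1, seed2)
    return d1, d2, w  # refined descriptors and per-seed confidence (b*k)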


class matcher(nn.Module):
    # Full matcher: seed initial correspondences, refine descriptors with a
    # stack of seeded (hybrid) attention blocks, optionally reseed at given
    # layers, and read out an assignment matrix with Sinkhorn.
    def __init__(self, config):
        nn.Module.__init__(self)
        self.seed_top_k = config.seed_top_k
        self.conf_bar = config.conf_bar
        self.seed_radius_coe = config.seed_radius_coe
        self.use_score_encoding = config.use_score_encoding
        self.detach_iter = config.detach_iter
        self.seedlayer = config.seedlayer
        self.layer_num = config.layer_num
        self.sink_iter = config.sink_iter
        # MLP lifting keypoint positions (and optionally detection scores)
        # to the descriptor dimension.
        self.position_encoder = nn.Sequential(
            nn.Conv1d(3, 32, kernel_size=1)
            if config.use_score_encoding
            else nn.Conv1d(2, 32, kernel_size=1),
            nn.SyncBatchNorm(32),
            nn.ReLU(),
            nn.Conv1d(32, 64, kernel_size=1),
            nn.SyncBatchNorm(64),
            nn.ReLU(),
            nn.Conv1d(64, 128, kernel_size=1),
            nn.SyncBatchNorm(128),
            nn.ReLU(),
            nn.Conv1d(128, 256, kernel_size=1),
            nn.SyncBatchNorm(256),
            nn.ReLU(),
            nn.Conv1d(256, config.net_channels, kernel_size=1),
        )
        self.hybrid_block = nn.Sequential(
            *[
                hybrid_block(config.net_channels, config.head)
                for _ in range(config.layer_num)
            ]
        )
        self.final_project = nn.Conv1d(
            config.net_channels, config.net_channels, kernel_size=1
        )
        self.dustbin = nn.Parameter(torch.tensor(1.5, dtype=torch.float32))
        # if reseeding
        if len(config.seedlayer) != 1:
            self.mid_dustbin = nn.ParameterDict(
                {
                    str(i): nn.Parameter(torch.tensor(2, dtype=torch.float32))
                    for i in config.seedlayer[1:]
                }
            )
            self.mid_final_project = nn.Conv1d(
                config.net_channels, config.net_channels, kernel_size=1
            )

    def forward(self, data, test_mode=True):
        x1, x2, desc1, desc2 = (
            data["x1"][:, :, :2],
            data["x2"][:, :, :2],
            data["desc1"],
            data["desc2"],
        )
        desc1, desc2 = torch.nn.functional.normalize(
            desc1, dim=-1
        ), torch.nn.functional.normalize(desc2, dim=-1)
        if test_mode:
            encode_x1, encode_x2 = data["x1"], data["x2"]
        else:
            encode_x1, encode_x2 = data["aug_x1"], data["aug_x2"]
        # preparation: descriptor distance matrix and nearest neighbours
        desc_dismat = (2 - 2 * torch.matmul(desc1, desc2.transpose(1, 2))).sqrt_()
        values, nn_index = torch.topk(
            desc_dismat, k=2, largest=False, dim=-1, sorted=True
        )
        nn_index2 = torch.min(desc_dismat, dim=1).indices.squeeze(1)
        inverse_ratio_score, nn_index1 = (
            values[:, :, 1] / values[:, :, 0],
            nn_index[:, :, 0],
        )  # get inverse ratio score
        # initial seeding
        seed_index1, seed_index2 = seeding(
            nn_index1,
            nn_index2,
            x1,
            x2,
            self.seed_top_k[0],
            inverse_ratio_score,
            self.conf_bar[0],
            self.seed_radius_coe,
            test=test_mode,
        )
        # position encoding
        desc1, desc2 = desc1.transpose(1, 2), desc2.transpose(1, 2)
        if not self.use_score_encoding:
            encode_x1, encode_x2 = encode_x1[:, :, :2], encode_x2[:, :, :2]
        encode_x1, encode_x2 = encode_x1.transpose(1, 2), encode_x2.transpose(1, 2)
        x1_pos_embedding, x2_pos_embedding = self.position_encoder(
            encode_x1
        ), self.position_encoder(encode_x2)
        aug_desc1, aug_desc2 = x1_pos_embedding + desc1, x2_pos_embedding + desc2
        seed_weight_tower, mid_p_tower, seed_index_tower, nn_index_tower = (
            [],
            [],
            [],
            [],
        )
        seed_index_tower.append(torch.stack([seed_index1, seed_index2], dim=-1))
        nn_index_tower.append(nn_index1)
        seed_para_index = 0
        for i in range(self.layer_num):
            # mid seeding
            if i in self.seedlayer and i != 0:
                seed_para_index += 1
                aug_desc1, aug_desc2 = self.mid_final_project(
                    aug_desc1
                ), self.mid_final_project(aug_desc2)
                M = torch.matmul(aug_desc1.transpose(1, 2), aug_desc2)
                p = sink_algorithm(
                    M, self.mid_dustbin[str(i)], self.sink_iter[seed_para_index - 1]
                )
                mid_p_tower.append(p)
                # rematching with p
                values, nn_index = torch.topk(p[:, :-1, :-1], k=1, dim=-1)
                nn_index2 = torch.max(p[:, :-1, :-1], dim=1).indices.squeeze(1)
                p_match_score, nn_index1 = values[:, :, 0], nn_index[:, :, 0]
                # reseeding
                seed_index1, seed_index2 = seeding(
                    nn_index1,
                    nn_index2,
                    x1,
                    x2,
                    self.seed_top_k[seed_para_index],
                    p_match_score,
                    self.conf_bar[seed_para_index],
                    self.seed_radius_coe,
                    test=test_mode,
                )
                seed_index_tower.append(
                    torch.stack([seed_index1, seed_index2], dim=-1)
                )
                nn_index_tower.append(nn_index1)
            if not test_mode and data["step"] < self.detach_iter:
                aug_desc1, aug_desc2 = aug_desc1.detach(), aug_desc2.detach()
            aug_desc1, aug_desc2, seed_weight = self.hybrid_block[i](
                aug_desc1, aug_desc2, seed_index1, seed_index2
            )
            seed_weight_tower.append(seed_weight)
        aug_desc1, aug_desc2 = self.final_project(aug_desc1), self.final_project(
            aug_desc2
        )
        cmat = torch.matmul(aug_desc1.transpose(1, 2), aug_desc2)
        p = sink_algorithm(cmat, self.dustbin, self.sink_iter[-1])
        # seed_weight_tower: l*b*k
        # seed_index_tower: l*b*k*2
        # nn_index_tower: seed_l*b
        return {
            "p": p,
            "seed_conf": seed_weight_tower,
            "seed_index": seed_index_tower,
            "mid_p": mid_p_tower,
            "nn_index": nn_index_tower,
        }
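

# --- Illustration only (not part of the original file) ---
# End-to-end sketch: build a matcher from a hypothetical config and run it
# on random keypoints/descriptors in test mode. Every field value below is
# an assumption made for the demo; the original configuration may differ.
if __name__ == "__main__":
    from types import SimpleNamespace

    config = SimpleNamespace(
        seed_top_k=[64],        # seeds kept at the (single) seeding stage
        conf_bar=[1.11],        # confidence bar on the inverse ratio score
        seed_radius_coe=0.01,   # NMS radius as a fraction of the mean distance
        use_score_encoding=False,
        detach_iter=0,          # only used during training
        seedlayer=[0],          # seed once, before the first block
        layer_num=2,
        sink_iter=[10],         # Sinkhorn iterations for the final OT layer
        net_channels=128,       # must match the descriptor dimension
        head=4,
    )
    net = matcher(config).to(device).eval()
    b, n, m, d = 1, 256, 256, 128
    data = {
        "x1": torch.rand(b, n, 3, device=device),  # x, y, detection score
        "x2": torch.rand(b, m, 3, device=device),
        "desc1": torch.rand(b, n, d, device=device),
        "desc2": torch.rand(b, m, d, device=device),
    }
    with torch.no_grad():
        out = net(data, test_mode=True)
    print(out["p"].shape)  # b*(n+1)*(m+1) assignment matrix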