# NOTE: this file was captured from a Hugging Face Space page; the page
# header ("Spaces: Runtime error") is preserved here as a comment so the
# file remains valid Python.
import torch
import torch.nn as nn
import time

# Numerical-stability constant: keeps the Sinkhorn row/column rescaling in
# sinkhorn() from dividing by zero.
eps = 1e-8
def sinkhorn(M, r, c, iteration, eps=1e-8):
    """Sinkhorn normalisation of a score matrix toward target marginals.

    Args:
        M: ``(batch, n, m)`` score/logit matrix.
        r: ``(batch, n)`` target row marginals.
        c: ``(batch, m)`` target column marginals.
        iteration: number of Sinkhorn scaling iterations.
        eps: small constant guarding against division by zero.  Defaults to
            1e-8, the same value as the module-level ``eps`` the original
            relied on, so behaviour is unchanged for existing callers.

    Returns:
        ``(batch, n, m)`` transport plan whose row sums approach ``r`` and
        column sums approach ``c``.
    """
    p = torch.softmax(M, dim=-1)
    u = torch.ones_like(r)
    v = torch.ones_like(c)
    for _ in range(iteration):
        # Alternately rescale rows and columns toward the target marginals.
        u = r / ((p * v.unsqueeze(-2)).sum(-1) + eps)
        v = c / ((p * u.unsqueeze(-1)).sum(-2) + eps)
    # Apply the accumulated scalings once, after the final iteration.
    p = p * u.unsqueeze(-1) * v.unsqueeze(-2)
    return p
def sink_algorithm(M, dustbin, iteration):
    """Augment a score matrix with a dustbin row/column and run Sinkhorn.

    Args:
        M: ``(batch, n, m)`` matching score matrix.
        dustbin: scalar tensor used as the score of the "unmatched" bin.
        iteration: number of Sinkhorn iterations.

    Returns:
        ``(batch, n + 1, m + 1)`` assignment matrix; the extra row/column
        absorbs points left unmatched.
    """
    # Create the marginals on M's device instead of hard-coding "cuda",
    # which crashed on CPU and mismatched non-CUDA inputs.
    device = M.device
    M = torch.cat([M, dustbin.expand([M.shape[0], M.shape[1], 1])], dim=-1)
    M = torch.cat([M, dustbin.expand([M.shape[0], 1, M.shape[2]])], dim=-2)
    # Each real point has unit mass; the dustbin entry gets mass equal to
    # the (augmented) size of the opposite side, as in the original code.
    r = torch.ones([M.shape[0], M.shape[1] - 1], device=device)
    r = torch.cat([r, torch.ones([M.shape[0], 1], device=device) * M.shape[1]], dim=-1)
    c = torch.ones([M.shape[0], M.shape[2] - 1], device=device)
    c = torch.cat([c, torch.ones([M.shape[0], 1], device=device) * M.shape[2]], dim=-1)
    p = sinkhorn(M, r, c, iteration)
    return p
class attention_block(nn.Module):
    """Multi-head self/cross attention with a residual MLP merge.

    With ``type == "self"`` each feature map attends to itself; with
    ``type == "cross"`` each attends to the other.  Features are laid out
    channel-first: ``(batch, channels, num_points)``.
    """

    def __init__(self, channels, head, type):
        assert type == "self" or type == "cross", "invalid attention type"
        nn.Module.__init__(self)
        self.head = head
        self.type = type
        self.head_dim = channels // head
        # 1x1 convolutions act as per-point linear projections.
        self.query_filter = nn.Conv1d(channels, channels, kernel_size=1)
        self.key_filter = nn.Conv1d(channels, channels, kernel_size=1)
        self.value_filter = nn.Conv1d(channels, channels, kernel_size=1)
        # Fuses [features, attention message] back to `channels` dims.
        self.attention_filter = nn.Sequential(
            nn.Conv1d(2 * channels, 2 * channels, kernel_size=1),
            nn.SyncBatchNorm(2 * channels),
            nn.ReLU(),
            nn.Conv1d(2 * channels, channels, kernel_size=1),
        )
        self.mh_filter = nn.Conv1d(channels, channels, kernel_size=1)

    def _project(self, fea):
        # Project to Q/K/V and split channels into (head_dim, head) groups.
        b = fea.shape[0]
        q = self.query_filter(fea).view(b, self.head_dim, self.head, -1)
        k = self.key_filter(fea).view(b, self.head_dim, self.head, -1)
        v = self.value_filter(fea).view(b, self.head_dim, self.head, -1)
        return q, k, v

    def _attend(self, query, key, value):
        # Scaled dot-product attention over the point dimension.
        logits = torch.einsum("bdhn,bdhm->bhnm", query, key) / self.head_dim**0.5
        weights = torch.softmax(logits, dim=-1)
        return torch.einsum("bhnm,bdhm->bdhn", weights, value)

    def forward(self, fea1, fea2):
        batch_size = fea1.shape[0]
        n, m = fea1.shape[2], fea2.shape[2]
        q1, k1, v1 = self._project(fea1)
        q2, k2, v2 = self._project(fea2)
        if self.type == "self":
            msg1 = self._attend(q1, k1, v1)
            msg2 = self._attend(q2, k2, v2)
        else:
            # Cross attention: queries from one image, keys/values from the other.
            msg1 = self._attend(q1, k2, v2)
            msg2 = self._attend(q2, k1, v1)
        # Re-merge the heads, then mix across channels.
        msg1 = self.mh_filter(
            msg1.contiguous().view(batch_size, self.head * self.head_dim, n)
        )
        msg2 = self.mh_filter(
            msg2.contiguous().view(batch_size, self.head * self.head_dim, m)
        )
        # Residual update from the concatenated [feature, message] pair.
        fea1 = fea1 + self.attention_filter(torch.cat([fea1, msg1], dim=1))
        fea2 = fea2 + self.attention_filter(torch.cat([fea2, msg2], dim=1))
        return fea1, fea2
class matcher(nn.Module):
    """Attentional keypoint matcher.

    Positionally encodes keypoints, refines descriptors with alternating
    self-/cross-attention layers, and converts the descriptor similarity
    matrix into a soft assignment with a dustbin-augmented Sinkhorn.

    ``config`` must provide: ``use_score_encoding``, ``layer_num``,
    ``sink_iter`` (indexable), ``net_channels`` and ``head``.
    """

    def __init__(self, config):
        nn.Module.__init__(self)
        self.use_score_encoding = config.use_score_encoding
        self.layer_num = config.layer_num
        self.sink_iter = config.sink_iter
        # MLP lifting 2-D keypoint positions (optionally + detection score,
        # hence 3 input channels) to the descriptor dimension.
        self.position_encoder = nn.Sequential(
            nn.Conv1d(3, 32, kernel_size=1)
            if config.use_score_encoding
            else nn.Conv1d(2, 32, kernel_size=1),
            nn.SyncBatchNorm(32),
            nn.ReLU(),
            nn.Conv1d(32, 64, kernel_size=1),
            nn.SyncBatchNorm(64),
            nn.ReLU(),
            nn.Conv1d(64, 128, kernel_size=1),
            nn.SyncBatchNorm(128),
            nn.ReLU(),
            nn.Conv1d(128, 256, kernel_size=1),
            nn.SyncBatchNorm(256),
            nn.ReLU(),
            nn.Conv1d(256, config.net_channels, kernel_size=1),
        )
        # Learnable dustbin score.  FIX: was created with device="cuda",
        # which crashed on CPU-only machines at construction time; created
        # on the default device instead, it follows the module through
        # .to()/.cuda() like every other parameter.
        self.dustbin = nn.Parameter(torch.tensor(1, dtype=torch.float32))
        self.self_attention_block = nn.Sequential(
            *[
                attention_block(config.net_channels, config.head, "self")
                for _ in range(config.layer_num)
            ]
        )
        self.cross_attention_block = nn.Sequential(
            *[
                attention_block(config.net_channels, config.head, "cross")
                for _ in range(config.layer_num)
            ]
        )
        self.final_project = nn.Conv1d(
            config.net_channels, config.net_channels, kernel_size=1
        )

    def forward(self, data, test_mode=True):
        """Compute the soft assignment between two keypoint sets.

        Args:
            data: dict with descriptors ``desc1``/``desc2`` of shape
                ``(batch, num_points, channels)`` and keypoint arrays
                ``x1``/``x2`` (or ``aug_x1``/``aug_x2`` when training);
                assumed layout ``(batch, num_points, 2 or 3)`` — the last
                column being a detection score (TODO confirm with caller).
            test_mode: if True use ``x1``/``x2``, otherwise the augmented
                keypoints ``aug_x1``/``aug_x2``.

        Returns:
            dict with ``"p"``: ``(batch, n + 1, m + 1)`` assignment matrix
            (dustbin-augmented).
        """
        desc1, desc2 = data["desc1"], data["desc2"]
        # L2-normalise descriptors, then switch to channel-first for Conv1d.
        desc1 = torch.nn.functional.normalize(desc1, dim=-1).transpose(1, 2)
        desc2 = torch.nn.functional.normalize(desc2, dim=-1).transpose(1, 2)
        if test_mode:
            encode_x1, encode_x2 = data["x1"], data["x2"]
        else:
            encode_x1, encode_x2 = data["aug_x1"], data["aug_x2"]
        if not self.use_score_encoding:
            # Drop the score column, keep only the 2-D coordinates.
            encode_x1, encode_x2 = encode_x1[:, :, :2], encode_x2[:, :, :2]
        encode_x1, encode_x2 = encode_x1.transpose(1, 2), encode_x2.transpose(1, 2)
        aug_desc1 = self.position_encoder(encode_x1) + desc1
        aug_desc2 = self.position_encoder(encode_x2) + desc2
        # Alternate self- and cross-attention refinement layers.
        for i in range(self.layer_num):
            aug_desc1, aug_desc2 = self.self_attention_block[i](aug_desc1, aug_desc2)
            aug_desc1, aug_desc2 = self.cross_attention_block[i](aug_desc1, aug_desc2)
        aug_desc1 = self.final_project(aug_desc1)
        aug_desc2 = self.final_project(aug_desc2)
        # Pairwise descriptor similarities -> soft assignment via Sinkhorn.
        desc_mat = torch.matmul(aug_desc1.transpose(1, 2), aug_desc2)
        p = sink_algorithm(desc_mat, self.dustbin, self.sink_iter[0])
        return {"p": p}