import matplotlib
# Select a non-interactive backend before pyplot is imported so figures can be
# written to disk on headless machines.
matplotlib.use('Agg')
import matplotlib.pyplot as plt

import numpy as np
import cv2

from sys import exit

import torch
import torch.nn.functional as F

from lib.utils import (
    grid_positions,
    upscale_positions,
    downscale_positions,
    savefig,
    imshow_image
)
from lib.exceptions import NoGradientError, EmptyTensorError


def loss_function(
        model, batch, device, margin=1, safe_radius=4, scaling_steps=3,
        plot=False, plot_path=None
):
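    """Detection-and-description training loss.

    For each ground-truth correspondence, the hardest negative descriptor
    lying outside ``safe_radius`` is mined in both images; the resulting
    margin ranking term is weighted by the soft detection scores of the two
    matched locations and averaged over all valid image pairs in the batch.
    """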
    output = model({
        'image1': batch['image1'].to(device),
        'image2': batch['image2'].to(device)
    })

    loss = torch.tensor(np.array([0], dtype=np.float32), device=device)
    has_grad = False

    n_valid_samples = 0
    for idx_in_batch in range(batch['image1'].size(0)):
        # Network output
        dense_features1 = output['dense_features1'][idx_in_batch]
        c, h1, w1 = dense_features1.size()
        scores1 = output['scores1'][idx_in_batch].view(-1)

        dense_features2 = output['dense_features2'][idx_in_batch]
        _, h2, w2 = dense_features2.size()
        scores2 = output['scores2'][idx_in_batch]

        all_descriptors1 = F.normalize(dense_features1.view(c, -1), dim=0)
        descriptors1 = all_descriptors1

        all_descriptors2 = F.normalize(dense_features2.view(c, -1), dim=0)

        fmap_pos1 = grid_positions(h1, w1, device)

        pos1 = batch['pos1'][idx_in_batch].to(device)
        pos2 = batch['pos2'][idx_in_batch].to(device)
        # Keep only the feature-map cells that have a ground-truth correspondence
        ids = idsAlign(pos1, device, h1, w1)

        fmap_pos1 = fmap_pos1[:, ids]
        descriptors1 = descriptors1[:, ids]
        scores1 = scores1[ids]

        # Skip the pair if not enough GT correspondences are available
        if ids.size(0) < 128:
            continue
        # Descriptors at the corresponding positions
        fmap_pos2 = torch.round(
            downscale_positions(pos2, scaling_steps=scaling_steps)
        ).long()
        descriptors2 = F.normalize(
            dense_features2[:, fmap_pos2[0, :], fmap_pos2[1, :]],
            dim=0
        )
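        # Since the descriptors are L2-normalized, the squared Euclidean
        # distance of a matched pair equals 2 - 2 * <d1, d2>.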
        positive_distance = 2 - 2 * (
            descriptors1.t().unsqueeze(1) @ descriptors2.t().unsqueeze(2)
        ).squeeze()
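        # Hard-negative mining in image 2: for every correspondence, take the
        # closest descriptor whose position lies farther than safe_radius from
        # the true match (the +10 offset masks out in-radius candidates).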
        all_fmap_pos2 = grid_positions(h2, w2, device)
        position_distance = torch.max(
            torch.abs(
                fmap_pos2.unsqueeze(2).float() -
                all_fmap_pos2.unsqueeze(1)
            ),
            dim=0
        )[0]
        is_out_of_safe_radius = position_distance > safe_radius
        distance_matrix = 2 - 2 * (descriptors1.t() @ all_descriptors2)
        negative_distance2 = torch.min(
            distance_matrix + (1 - is_out_of_safe_radius.float()) * 10.,
            dim=1
        )[0]
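        # Symmetric hard-negative mining in image 1.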
        all_fmap_pos1 = grid_positions(h1, w1, device)
        position_distance = torch.max(
            torch.abs(
                fmap_pos1.unsqueeze(2).float() -
                all_fmap_pos1.unsqueeze(1)
            ),
            dim=0
        )[0]
        is_out_of_safe_radius = position_distance > safe_radius
        distance_matrix = 2 - 2 * (descriptors2.t() @ all_descriptors1)
        negative_distance1 = torch.min(
            distance_matrix + (1 - is_out_of_safe_radius.float()) * 10.,
            dim=1
        )[0]
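        # Margin ranking term: the positive distance should be smaller than
        # the hardest negative distance (from either image) by at least `margin`.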
        diff = positive_distance - torch.min(
            negative_distance1, negative_distance2
        )

        scores2 = scores2[fmap_pos2[0, :], fmap_pos2[1, :]]
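        # Weight each correspondence by the product of its soft detection
        # scores, normalized over the pair so the loss scale stays comparable.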
        loss = loss + (
            torch.sum(scores1 * scores2 * F.relu(margin + diff)) /
            torch.sum(scores1 * scores2)
        )

        has_grad = True
        n_valid_samples += 1

        if plot and batch['batch_idx'] % batch['log_interval'] == 0:
            drawTraining(
                batch['image1'], batch['image2'], pos1, pos2, batch,
                idx_in_batch, output, save=True, plot_path=plot_path
            )

    if not has_grad:
        raise NoGradientError

    loss = loss / n_valid_samples

    return loss


def idsAlign(pos1, device, h1, w1):
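    """Convert downscaled (row, col) keypoint positions into flat indices
    into the h1 x w1 feature-map grid (index = row * w1 + col)."""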
    # Note: the downscaling factor is hard-coded to match the default
    # scaling_steps=3 used by loss_function.
    pos1D = downscale_positions(pos1, scaling_steps=3)
    row = pos1D[0, :]
    col = pos1D[1, :]
    ids = []
    for i in range(row.shape[0]):
        index = w1 * row[i] + col[i]
        ids.append(index)
    ids = torch.round(torch.Tensor(ids)).long().to(device)
    return ids
| def drawTraining(image1, image2, pos1, pos2, batch, idx_in_batch, output, save=False, plot_path="train_viz"): | |
    pos1_aux = pos1.cpu().numpy()
    pos2_aux = pos2.cpu().numpy()

    k = pos1_aux.shape[1]
    col = np.random.rand(k, 3)
    n_sp = 4
    plt.figure()
    plt.subplot(1, n_sp, 1)
    im1 = imshow_image(
        image1[0].cpu().numpy(),
        preprocessing=batch['preprocessing']
    )
    plt.imshow(im1)
    plt.scatter(
        pos1_aux[1, :], pos1_aux[0, :],
        s=0.25**2, c=col, marker=',', alpha=0.5
    )
    plt.axis('off')
    plt.subplot(1, n_sp, 2)
    plt.imshow(
        output['scores1'][idx_in_batch].data.cpu().numpy(),
        cmap='Reds'
    )
    plt.axis('off')
    plt.subplot(1, n_sp, 3)
    im2 = imshow_image(
        image2[0].cpu().numpy(),
        preprocessing=batch['preprocessing']
    )
    plt.imshow(im2)
    plt.scatter(
        pos2_aux[1, :], pos2_aux[0, :],
        s=0.25**2, c=col, marker=',', alpha=0.5
    )
    plt.axis('off')
    plt.subplot(1, n_sp, 4)
    plt.imshow(
        output['scores2'][idx_in_batch].data.cpu().numpy(),
        cmap='Reds'
    )
    plt.axis('off')
    if save:
        savefig(plot_path + '/%s.%02d.%02d.%d.png' % (
            'train' if batch['train'] else 'valid',
            batch['epoch_idx'],
            batch['batch_idx'] // batch['log_interval'],
            idx_in_batch
        ), dpi=300)
    else:
        plt.show()
    plt.close()
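    # Side-by-side correspondence view: every 5th match is drawn as a circle
    # in each image and connected by a line across the concatenated pair.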
    # imshow_image returns RGB; swap channels to BGR for OpenCV drawing/saving
    im1 = cv2.cvtColor(im1, cv2.COLOR_RGB2BGR)
    im2 = cv2.cvtColor(im2, cv2.COLOR_RGB2BGR)
    for i in range(0, pos1_aux.shape[1], 5):
        im1 = cv2.circle(im1, (int(pos1_aux[1, i]), int(pos1_aux[0, i])), 1, (0, 0, 255), 2)
    for i in range(0, pos2_aux.shape[1], 5):
        im2 = cv2.circle(im2, (int(pos2_aux[1, i]), int(pos2_aux[0, i])), 1, (0, 0, 255), 2)
    im3 = cv2.hconcat([im1, im2])
    for i in range(0, pos1_aux.shape[1], 5):
        im3 = cv2.line(
            im3,
            (int(pos1_aux[1, i]), int(pos1_aux[0, i])),
            (int(pos2_aux[1, i]) + im1.shape[1], int(pos2_aux[0, i])),
            (0, 255, 0), 1
        )
    if save:
        cv2.imwrite(plot_path + '/%s.%02d.%02d.%d.png' % (
            'train_corr' if batch['train'] else 'valid',
            batch['epoch_idx'],
            batch['batch_idx'] // batch['log_interval'],
            idx_in_batch
        ), im3)
    else:
        cv2.imshow('Image', im3)
        cv2.waitKey(0)