Spaces:
Running
on
Zero
Running
on
Zero
import argparse | |
import os | |
from os.path import join | |
import sys | |
import numpy as np | |
import cv2 | |
import torch | |
from matplotlib import pyplot as plt | |
from tqdm import tqdm | |
from gluestick import batch_to_np, numpy_image_to_torch, GLUESTICK_ROOT | |
from gluestick.drawing import plot_images, plot_lines, plot_color_line_matches, plot_keypoints, plot_matches | |
from line_matching.two_view_pipeline import TwoViewPipeline | |
from scalelsd.base import show, WireframeGraph | |
from scalelsd.ssl.datasets.transforms.homographic_transforms import sample_homography | |
from kornia.geometry import warp_perspective,transform_points | |
class HADConfig:
    """Settings for random homography sampling (homographic adaptation).

    All values are plain class attributes; callers read them through the
    module-level ``had_cfg`` instance created right after this class.
    """

    # Iteration / border settings (not read anywhere else in this script)
    num_iter: int = 1
    valid_border_margin: int = 3

    # Which components of the random homography are enabled
    translation: bool = True
    rotation: bool = True
    scale: bool = True
    perspective: bool = True

    # Magnitudes used when sampling the enabled components
    scaling_amplitude: float = 0.2
    perspective_amplitude_x: float = 0.2
    perspective_amplitude_y: float = 0.2

    allow_artifacts: bool = False
    patch_ratio: float = 0.85


# Shared configuration instance used by the sampling helpers below
had_cfg = HADConfig()
def sample_homographics(height, width, device=None):
    """Sample one random homography for an image of the given size.

    All sampling options (enabled components, amplitudes, patch ratio,
    artifact policy) are read from the module-level ``had_cfg``.

    Args:
        height: image height, passed to ``sample_homography`` via ``shape``.
        width: image width, passed to ``sample_homography`` via ``shape``.
        device: torch device for the returned tensors. ``None`` (default)
            selects CUDA when available and the CPU otherwise.

    Returns:
        dict with ``'h.1'``: the sampled homography as a float32 tensor with a
        leading batch dimension of 1 (presumably shape (1, 3, 3) — confirm
        against scalelsd's ``sample_homography``), and ``'ih.1'``: its inverse.
    """
    if device is None:
        # Previously hard-coded to .cuda(), which crashed on CPU-only machines.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Element 0 of the returned sequence is used as the homography matrix.
    homography = sample_homography(
        shape=(height, width),
        perspective=had_cfg.perspective,
        scaling=had_cfg.scale,
        rotation=had_cfg.rotation,
        translation=had_cfg.translation,
        scaling_amplitude=had_cfg.scaling_amplitude,
        perspective_amplitude_x=had_cfg.perspective_amplitude_x,
        perspective_amplitude_y=had_cfg.perspective_amplitude_y,
        patch_ratio=had_cfg.patch_ratio,
        allow_artifacts=had_cfg.allow_artifacts,
    )[0]

    homography = torch.from_numpy(np.asarray(homography)[None]).float().to(device)
    return {
        'h.1': homography,
        'ih.1': torch.inverse(homography),
    }
def trans_image_with_homograpy(image, *, visualize=True):
    """Warp an image with a randomly sampled homography.

    Args:
        image: numpy image of shape (H, W, C) or (H, W); presumably uint8 in
            [0, 255], since the result is cast back to uint8.
        visualize: when True (default, as before), display the warped image
            with matplotlib (``plt.show()`` may block on interactive backends).

    Returns:
        The warped image as a uint8 numpy array with the same shape as ``image``.
    """
    h, w = image.shape[:2]
    homography = sample_homographics(height=h, width=w)['h.1']

    # Accept 2-D grayscale input by giving it a temporary channel axis.
    is_gray = image.ndim == 2
    hwc = np.ascontiguousarray(image[..., None] if is_gray else image)

    # Put the image on the homography's device instead of assuming CUDA.
    image_t = torch.from_numpy(hwc).float().permute(2, 0, 1)[None].to(homography.device)
    warped_t = warp_perspective(image_t, homography, (h, w))

    # Round before the uint8 cast: a bare astype() truncates interpolated values.
    warped = (
        warped_t[0].permute(1, 2, 0).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
    )
    if is_gray:
        warped = warped[..., 0]

    if visualize:
        plt.imshow(warped, cmap='gray' if is_gray else None)
        plt.show()
    return warped
def _read_image(path, flags):
    """Read an image with OpenCV, raising instead of silently returning None."""
    image = cv2.imread(path, flags)
    if image is None:
        # cv2.imread reports a missing/unreadable file by returning None.
        raise FileNotFoundError(f'Could not read image: {path}')
    return image


def _build_pipeline_conf(max_pts, max_lines):
    """Return the TwoViewPipeline config (wireframe extractor + GlueStick matcher)."""
    return {
        'name': 'two_view_pipeline',
        'use_lines': True,
        'extractor': {
            'name': 'wireframe',
            'sp_params': {
                'force_num_keypoints': False,
                'max_num_keypoints': max_pts,
            },
            'wireframe_params': {
                'merge_points': True,
                'merge_line_endpoints': True,
            },
            'max_n_lines': max_lines,
        },
        'matcher': {
            'name': 'gluestick',
            'weights': str(GLUESTICK_ROOT / 'resources' / 'weights' / 'checkpoint_GlueStick_MD.tar'),
            'trainable': False,
        },
        'ground_truth': {
            'from_pose_depth': False,
        },
    }


def _save_wireframe(painter, image_path, lines, fig_file):
    """Draw ``lines`` over the image at ``image_path`` and save it to ``fig_file``."""
    outputs = {'lines_pred': lines.reshape(-1, 4)}
    with show.image_canvas(image_path, fig_file=fig_file) as ax:
        painter.draw_wireframe(ax, outputs, edge_color='midnightblue', vertex_color='deeppink')


def main():
    """Detect and match line segments between two images and save visualizations.

    Command-line arguments:
        -img1: path of the first image.
        -img2: path of the second image. When omitted, it is synthesized by
            warping img1 with a random homography and saved to
            ``temp_output/matching_results/warped_image.png``.
        --max_pts / --max_lines: limits written into the extractor config.
        --model: parsed but not used by this function (see the note below).

    Writes to ``temp_output/matching_results``: ``det.png`` (detected lines),
    ``mat.png`` (line matches), and ``det1.png`` / ``det2.png`` (wireframe overlays).

    Raises:
        ValueError: if neither image path is given.
        FileNotFoundError: if an input image cannot be read.
    """
    parser = argparse.ArgumentParser(
        prog='GlueStick Demo',
        description='Demo app to show the point and line matches obtained by GlueStick')
    parser.add_argument('-img1', default='assets/figs/sa_1119229.jpg')
    parser.add_argument('-img2', default=None)
    parser.add_argument('--max_pts', type=int, default=1000)
    parser.add_argument('--max_lines', type=int, default=300)
    # NOTE(review): --model is parsed but never read in this script — confirm
    # whether it should be forwarded to the extractor / pipeline config.
    parser.add_argument('--model', type=str, default='models/paper-sa1b-997pkgs-model.pt')
    args = parser.parse_args()

    if args.img1 is None:
        if args.img2 is None:
            raise ValueError("Input at least one path of image1 or image2")
        # Only img2 was given: use it as the reference and synthesize its pair below.
        args.img1, args.img2 = args.img2, None

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    conf = _build_pipeline_conf(args.max_pts, args.max_lines)
    pipeline_model = TwoViewPipeline(conf).to(device).eval()
    pipeline_model.extractor.update_conf(None)

    saveto = 'temp_output/matching_results'
    os.makedirs(saveto, exist_ok=True)

    if args.img2 is None:
        # Synthesize the second view by warping img1 with a random homography.
        image1 = cv2.cvtColor(_read_image(args.img1, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)
        image2 = trans_image_with_homograpy(image1)
        warped_path = f'{saveto}/warped_image.png'
        # image2 is RGB but cv2.imwrite expects BGR. Without converting back, the
        # saved file has swapped R/B channels, which corrupts the grayscale image
        # matched below and the colours of the det2.png overlay.
        cv2.imwrite(warped_path, cv2.cvtColor(image2, cv2.COLOR_RGB2BGR))
        args.img2 = warped_path

    # Read each image once as grayscale; reused for inference and for plotting.
    gray0 = _read_image(args.img1, cv2.IMREAD_GRAYSCALE)
    gray1 = _read_image(args.img2, cv2.IMREAD_GRAYSCALE)
    x = {
        'image0': numpy_image_to_torch(gray0).to(device)[None],
        'image1': numpy_image_to_torch(gray1).to(device)[None],
    }
    # Pure inference: don't build an autograd graph.
    with torch.no_grad():
        pred = batch_to_np(pipeline_model(x))

    # line_matches0[i] is the index into lines1 matched to lines0[i], or -1 if unmatched.
    line_seg0, line_seg1 = pred["lines0"], pred["lines1"]
    line_matches = pred["line_matches0"]
    valid_matches = line_matches != -1
    matched_lines0 = line_seg0[valid_matches]
    matched_lines1 = line_seg1[line_matches[valid_matches]]

    img0 = cv2.cvtColor(gray0, cv2.COLOR_GRAY2BGR)
    img1 = cv2.cvtColor(gray1, cv2.COLOR_GRAY2BGR)

    plot_images([img0, img1], dpi=200, pad=2.0)
    plot_lines([line_seg0, line_seg1], ps=4, lw=2)
    plt.gcf().canvas.manager.set_window_title('Detected Lines')
    plt.savefig(f'{saveto}/det.png')

    plot_images([img0, img1], dpi=200, pad=2.0)
    plot_color_line_matches([matched_lines0, matched_lines1], lw=3)
    plt.gcf().canvas.manager.set_window_title('Line Matches')
    plt.savefig(f'{saveto}/mat.png')

    show.Canvas.white_overlay = 1  # "whitebg" flag, as in the original script
    painter = show.painters.HAWPainter()
    _save_wireframe(painter, args.img1, line_seg0, f'{saveto}/det1.png')
    _save_wireframe(painter, args.img2, line_seg1, f'{saveto}/det2.png')


if __name__ == '__main__':
    main()