# Gradio demo for ScaleLSD + GlueStick line matching (a Hugging Face Space running on ZeroGPU).
import argparse
import os
from os.path import join
import sys

import numpy as np
import cv2
import torch
from matplotlib import pyplot as plt
from tqdm import tqdm
import gradio as gr
import random

from gluestick import batch_to_np, numpy_image_to_torch, GLUESTICK_ROOT
from gluestick.drawing import plot_images, plot_lines, plot_color_line_matches, plot_keypoints, plot_matches
from scalelsd.ssl.models.detector import ScaleLSD
from scalelsd.base import show, WireframeGraph
from scalelsd.ssl.datasets.transforms.homographic_transforms import sample_homography
from scalelsd.ssl.misc.train_utils import fix_seeds
from line_matching.two_view_pipeline import TwoViewPipeline
from kornia.geometry import warp_perspective, transform_points
class HADConfig:
    num_iter = 1
    valid_border_margin = 3
    translation = True
    rotation = True
    scale = True
    perspective = True
    scaling_amplitude = 0.2
    perspective_amplitude_x = 0.2
    perspective_amplitude_y = 0.2
    allow_artifacts = False
    patch_ratio = 0.85


had_cfg = HADConfig()
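
# HADConfig mirrors the homography-adaptation settings consumed by
# sample_homography below: the boolean flags toggle which transform families
# are sampled, and the amplitudes bound their strength. A minimal sketch of
# the resulting object (shapes assumed from the homographic_transforms API):
#
#   H = sample_homographics(height=480, width=640)
#   H['h.1'].shape    # (1, 3, 3) forward homography
#   H['ih.1'].shape   # (1, 3, 3) its inverse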
# Evaluation config
default_conf = {
    'name': 'two_view_pipeline',
    'use_lines': True,
    'extractor': {
        'name': 'wireframe',
        'sp_params': {
            'force_num_keypoints': False,
            'max_num_keypoints': 2048,
        },
        'wireframe_params': {
            'merge_points': True,
            'merge_line_endpoints': True,
            # 'merge_line_endpoints': False,
        },
        'max_n_lines': 512,
    },
    'matcher': {
        'name': 'gluestick',
        'weights': str(GLUESTICK_ROOT / 'resources' / 'weights' / 'checkpoint_GlueStick_MD.tar'),
        'trainable': False,
    },
    'ground_truth': {
        'from_pose_depth': False,
    },
}
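
# As configured above, the two-view pipeline runs the 'wireframe' extractor
# (keypoints plus line segments) on each image, then matches both with
# GlueStick using the pretrained checkpoint resolved from GLUESTICK_ROOT.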
# Title for the Gradio interface
_TITLE = 'ScaleLSD-GlueStick Line Matching'
MAX_SEED = 1000

# Run on GPU when available, falling back to CPU so the homography helpers
# below do not require CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def sample_homographics(height, width):
    """Sample a random homography (and its inverse) for an image of the given size."""

    def scale_homography(H, stride):
        # Rescale a homography to a feature map downsampled by `stride`
        # (kept from the original script; unused in this demo).
        H_scaled = H.clone()
        H_scaled[:, :, 2, :2] *= stride
        H_scaled[:, :, :2, 2] /= stride
        return H_scaled

    homographic = sample_homography(
        shape=(height, width),
        perspective=had_cfg.perspective,
        scaling=had_cfg.scale,
        rotation=had_cfg.rotation,
        translation=had_cfg.translation,
        scaling_amplitude=had_cfg.scaling_amplitude,
        perspective_amplitude_x=had_cfg.perspective_amplitude_x,
        perspective_amplitude_y=had_cfg.perspective_amplitude_y,
        patch_ratio=had_cfg.patch_ratio,
        allow_artifacts=had_cfg.allow_artifacts,
    )[0]
    homographic = torch.from_numpy(homographic[None]).float().to(device)
    homographic_inv = torch.inverse(homographic)
    H = {
        'h.1': homographic,
        'ih.1': homographic_inv,
    }
    return H
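
# Illustrative use of the sampled homography with kornia's warp_perspective
# (hypothetical tensor; this is exactly what trans_image_with_homography does):
#
#   img = torch.rand(1, 3, 480, 640, device=device)
#   H = sample_homographics(480, 640)
#   warped = warp_perspective(img, H['h.1'], (480, 640))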
def trans_image_with_homography(image):
    """Warp an image with a random homography to synthesize a second view."""
    h, w = image.shape[:2]
    H = sample_homographics(height=h, width=w)
    image_warped = warp_perspective(torch.Tensor(image).permute(2, 0, 1)[None].to(device), H['h.1'], (h, w))
    image_warped_ = image_warped[0].permute(1, 2, 0).cpu().numpy().astype(np.uint8)
    return image_warped_
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Return a random seed if requested, otherwise the given seed."""
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    return seed


def stop_run():
    """Restore the Run button and hide the Stop button."""
    return (
        gr.update(value="Run", variant="primary", visible=True),
        gr.update(visible=False),
    )


def clear_image2():
    # Returning None clears the image component.
    return None
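
# These callbacks are wired up in demo() below: stop_run only resets the
# buttons, while stop_btn.click(..., cancels=[run_event]) cancels the run
# itself, and input_image1.change(clear_image2, ...) clears the second image.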
def process_image(
    input_image1='assets/figs/sa_1119229.jpg',
    input_image2=None,
    model_name='scalelsd-vitbase-v1-train-sa1b.pt',
    save_name='temp',
    threshold=5,
    junction_threshold_hm=0.008,
    num_junctions_inference=4096,
    width=512,
    height=512,
    line_width=2,
    juncs_size=4,
    whitebg=1.0,
    draw_junctions_only=False,
    use_lsd=False,
    use_nms=False,
    edge_color='midnightblue',
    vertex_color='deeppink',
    output_format='png',
    seed=0,
    randomize_seed=False,
):
    """Core processing function: detect line structures in both views and match them."""
    # Note: line_width, juncs_size, draw_junctions_only, and output_format are
    # accepted from the UI but not yet applied to the renderings.

    # Set the random seed.
    seed = int(randomize_seed_fn(seed, randomize_seed))
    fix_seeds(seed)

    conf = {
        'model_name': model_name,
        'threshold': threshold,
        'junction_threshold_hm': junction_threshold_hm,
        'num_junctions_inference': num_junctions_inference,
        'use_lsd': use_lsd,
        'use_nms': use_nms,
        'width': width,
        'height': height,
    }
    pipeline_model = TwoViewPipeline(default_conf).to(device).eval()
    pipeline_model.extractor.update_conf(conf)

    saveto = 'temp_output/matching_results'
    os.makedirs(saveto, exist_ok=True)

    # Gradio delivers RGB arrays; cv2.imwrite expects BGR.
    image1 = cv2.cvtColor(input_image1, cv2.COLOR_RGB2BGR)
    cv2.imwrite(f'{saveto}/image.png', image1)
    input_image1 = f'{saveto}/image.png'
    if input_image2 is None:
        # No second view supplied: synthesize one by warping the first image.
        image2 = trans_image_with_homography(image1)
    else:
        image2 = cv2.cvtColor(input_image2, cv2.COLOR_RGB2BGR)
    cv2.imwrite(f'{saveto}/image2.png', image2)
    input_image2 = f'{saveto}/image2.png'

    gray0 = cv2.imread(input_image1, 0)
    gray1 = cv2.imread(input_image2, 0)
    torch_gray0, torch_gray1 = numpy_image_to_torch(gray0), numpy_image_to_torch(gray1)
    torch_gray0, torch_gray1 = torch_gray0.to(device)[None], torch_gray1.to(device)[None]
    x = {'image0': torch_gray0, 'image1': torch_gray1}
    pred = pipeline_model(x)
    pred = batch_to_np(pred)

    kp0, kp1 = pred["keypoints0"], pred["keypoints1"]
    m0 = pred["matches0"]
    line_seg0, line_seg1 = pred["lines0"], pred["lines1"]
    line_matches = pred["line_matches0"]

    # Keep only valid keypoint matches (-1 marks "no match").
    valid_matches = m0 != -1
    match_indices = m0[valid_matches]
    matched_kps0 = kp0[valid_matches]
    matched_kps1 = kp1[match_indices]

    # Keep only valid line matches.
    valid_matches = line_matches != -1
    match_indices = line_matches[valid_matches]
    matched_lines0 = line_seg0[valid_matches]
    matched_lines1 = line_seg1[match_indices]

    img0, img1 = cv2.cvtColor(gray0, cv2.COLOR_GRAY2BGR), cv2.cvtColor(gray1, cv2.COLOR_GRAY2BGR)

    # Detected lines in both views.
    det_file = f'{saveto}/{save_name}_det.png'
    plot_images([img0, img1], dpi=200, pad=2.0)
    plot_lines([line_seg0, line_seg1], ps=4, lw=2)
    plt.gcf().canvas.manager.set_window_title('Detected Lines')
    plt.savefig(det_file)
    plt.close()
    det_image = cv2.imread(det_file)[:, :, ::-1]

    # Color-coded line matches between the two views.
    mat_file = f'{saveto}/{save_name}_mat.png'
    plot_images([img0, img1], dpi=200, pad=2.0)
    plot_color_line_matches([matched_lines0, matched_lines1], lw=3)
    plt.gcf().canvas.manager.set_window_title('Line Matches')
    plt.savefig(mat_file)
    plt.close()
    mat_image = cv2.imread(mat_file)[:, :, ::-1]

    # Per-view wireframe renderings.
    show.Canvas.white_overlay = whitebg
    painter = show.painters.HAWPainter()

    fig_file = f'{saveto}/{save_name}_det1.png'
    outputs = {'lines_pred': line_seg0.reshape(-1, 4)}
    with show.image_canvas(input_image1, fig_file=fig_file) as ax:
        painter.draw_wireframe(ax, outputs, edge_color=edge_color, vertex_color=vertex_color)
    det1_image = cv2.imread(fig_file)[:, :, ::-1]

    fig_file = f'{saveto}/{save_name}_det2.png'
    outputs = {'lines_pred': line_seg1.reshape(-1, 4)}
    with show.image_canvas(input_image2, fig_file=fig_file) as ax:
        painter.draw_wireframe(ax, outputs, edge_color=edge_color, vertex_color=vertex_color)
    det2_image = cv2.imread(fig_file)[:, :, ::-1]

    return image2[:, :, ::-1], mat_image, det_image, det1_image, det2_image, mat_file, det_file
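
# A headless invocation sketch (hypothetical image path; the function expects
# RGB numpy arrays, as produced by gr.Image(type="numpy")):
#
#   img = cv2.cvtColor(cv2.imread('assets/figs/sa_1119229.jpg'), cv2.COLOR_BGR2RGB)
#   image2, mat_img, det_img, det1, det2, mat_path, det_path = process_image(
#       input_image1=img, input_image2=None, seed=42)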
def demo():
    """Create the Gradio demo interface."""
    css = """
    #col-container {
        margin: 0 auto;
        max-width: 800px;
    }
    """
    with gr.Blocks(css=css, title=_TITLE) as demo:
        with gr.Column(elem_id="col-container"):
            gr.Markdown(f'# {_TITLE}')
            gr.Markdown("Detect and match wireframe structures in images using the ScaleLSD model")
            pid = gr.State()

            figs_root = "assets/mat_figs"
            example_single = [os.path.join(figs_root, 'single', iname) for iname in os.listdir(figs_root + '/single')]
            example_pairs = [[img, None] for img in example_single]
            example_pairs += [
                [os.path.join(figs_root, 'pairs', f'ref_{i}.png'),
                 os.path.join(figs_root, 'pairs', f'tgt_{i}.png')]
                for i in [10, 72, 76, 95, 149, 151]
            ]

            with gr.Row():
                input_image1 = gr.Image(example_pairs[0][0], label="Input Image1", type="numpy")
                input_image2 = gr.Image(label="Input Image2", type="numpy")
            with gr.Row():
                mat_images = gr.Image(label="Matching Results")
            with gr.Row():
                det_images = gr.Image(label="Detection Results")
            with gr.Row():
                det_image1 = gr.Image(label="Detection1")
                det_image2 = gr.Image(label="Detection2")
            with gr.Row():
                run_btn = gr.Button(value="Run", variant="primary")
                stop_btn = gr.Button(value="Stop", variant="stop", visible=False)
            with gr.Row():
                mat_file = gr.File(label="Download Matching Result", type="filepath")
                det_file = gr.File(label="Download Detection Result", type="filepath")

            with gr.Accordion("Advanced Settings", open=True):
                with gr.Row():
                    model_name = gr.Dropdown(
                        [ckpt for ckpt in os.listdir('models') if ckpt.endswith('.pt')],
                        value='scalelsd-vitbase-v1-train-sa1b.pt',
                        label="Model Selection"
                    )
                with gr.Row():
                    save_name = gr.Textbox('temp_output', label="Save Name", placeholder="Name for saving output files")
                with gr.Row():
                    with gr.Column():
                        threshold = gr.Number(10, label="Line Threshold")
                        junction_threshold_hm = gr.Number(0.008, label="Junction Threshold")
                        num_junctions_inference = gr.Number(1024, label="Max Number of Junctions")
                        width = gr.Number(512, label="Input Width")
                        height = gr.Number(512, label="Input Height")
                    with gr.Column():
                        draw_junctions_only = gr.Checkbox(False, label="Show Junctions Only")
                        use_lsd = gr.Checkbox(False, label="Use LSD-Rectifier")
                        use_nms = gr.Checkbox(True, label="Use NMS")
                        output_format = gr.Dropdown(
                            ['png', 'jpg', 'pdf'],
                            value='png',
                            label="Output Format"
                        )
                        whitebg = gr.Slider(0.0, 1.0, value=1.0, label="White Background Opacity")
                        line_width = gr.Number(2, label="Line Width")
                        juncs_size = gr.Number(8, label="Junctions Size")
                with gr.Row():
                    edge_color = gr.Dropdown(
                        ['orange', 'midnightblue', 'red', 'green'],
                        value='midnightblue',
                        label="Edge Color"
                    )
                    vertex_color = gr.Dropdown(
                        ['cyan', 'deeppink', 'yellow', 'purple'],
                        value='deeppink',
                        label="Vertex Color"
                    )
                with gr.Row():
                    randomize_seed = gr.Checkbox(False, label="Randomize Seed")
                    seed = gr.Slider(0, MAX_SEED, value=42, step=1, label="Seed")

            gr.Examples(
                examples=example_pairs,
                inputs=[input_image1, input_image2]
            )

        # Start event handlers
        run_event = run_btn.click(
            fn=process_image,
            inputs=[
                input_image1,
                input_image2,
                model_name,
                save_name,
                threshold,
                junction_threshold_hm,
                num_junctions_inference,
                width,
                height,
                line_width,
                juncs_size,
                whitebg,
                draw_junctions_only,
                use_lsd,
                use_nms,
                edge_color,
                vertex_color,
                output_format,
                seed,
                randomize_seed,
            ],
            outputs=[input_image2, mat_images, det_images, det_image1, det_image2, mat_file, det_file],
        )

        # Stop event handler: cancel the in-flight run and restore the buttons.
        stop_btn.click(
            fn=stop_run,
            outputs=[run_btn, stop_btn],
            cancels=[run_event],
            queue=False,
        )

        # When image1 changes, clear image2 so a fresh pair is formed.
        input_image1.change(
            fn=clear_image2,
            outputs=input_image2,
        )

    return demo
if __name__ == "__main__":
    # Launch the app.
    app = demo()
    app.launch()