# ScaleLSD / app.py
# Author: Nan Xue — commit 98dd04e ("update"), 8.86 kB
import torch
import cv2
import os
import gradio as gr
import numpy as np
import random
from pathlib import Path
import json
# Heading and browser-tab title used by the Gradio interface.
_TITLE = (
    'Gradio Demo of ScaleLSD for Structured Representation of Images'
)
# Inclusive upper bound for seeds drawn by randomize_seed_fn (random.randint).
MAX_SEED: int = 1000
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Pick the seed to use for a run.

    Returns a fresh value from ``random.randint(0, MAX_SEED)`` when
    ``randomize_seed`` is truthy; otherwise returns ``seed`` unchanged.
    """
    return random.randint(0, MAX_SEED) if randomize_seed else seed
def stop_run():
    """Reset the Run/Stop button pair.

    Returns:
        A pair of ``gr.update`` objects for ``(run_btn, stop_btn)``. The Run
        button is made visible again as a primary "Run" button, and the Stop
        button is hidden.
    """
    run_button_update = gr.update(value="Run", variant="primary", visible=True)
    stop_button_update = gr.update(visible=False)
    return run_button_update, stop_button_update
def _load_grayscale(input_image):
    """Return a single-channel image from a numpy array or an image-file path.

    Raises:
        gr.Error: if no image was given or the path cannot be read.
    """
    if input_image is None:
        raise gr.Error("Please provide an input image.")
    if isinstance(input_image, np.ndarray):
        # gr.Image(type="numpy") normally yields RGB, but accept already-gray
        # and RGBA arrays instead of crashing in cvtColor.
        if input_image.ndim == 2:
            return input_image
        if input_image.shape[2] == 1:
            return input_image[:, :, 0]
        if input_image.shape[2] == 4:
            return cv2.cvtColor(input_image, cv2.COLOR_RGBA2GRAY)
        return cv2.cvtColor(input_image, cv2.COLOR_RGB2GRAY)
    image = cv2.imread(input_image, cv2.IMREAD_GRAYSCALE)
    if image is None:
        raise gr.Error(f"Could not read image: {input_image}")
    return image


def _draw_detections(image_canvas, painter, input_image, outputs, fig_file,
                     draw_junctions_only, edge_color, vertex_color):
    """Render the detections over ``input_image`` and save them to ``fig_file``."""
    with image_canvas(input_image, fig_file=fig_file) as ax:
        if draw_junctions_only:
            painter.draw_junctions(ax, outputs)
        else:
            painter.draw_wireframe(
                ax, outputs, edge_color=edge_color, vertex_color=vertex_color
            )


def process_image(
    input_image,
    model_name='scalelsd-vitbase-v2-train-sa1b.pt',
    save_name='temp_output',
    threshold=10,
    junction_threshold_hm=0.008,
    num_junctions_inference=512,
    width=512,
    height=512,
    line_width=2,
    juncs_size=4,
    whitebg=0.0,
    draw_junctions_only=False,
    use_lsd=False,
    use_nms=False,
    edge_color='orange',
    vertex_color='Cyan',
    output_format='png',
    seed=0,
    randomize_seed=False,
):
    """Run ScaleLSD on one image and export the rendering and wireframe JSON.

    Args:
        input_image: RGB numpy array (from ``gr.Image(type="numpy")``) or a
            path to an image file.
        model_name: checkpoint file name inside the ``models/`` directory.
        save_name: base name for the files written to ``temp_output/``. Only
            its final path component is used.
        threshold: confidence threshold given to the painter when drawing.
        junction_threshold_hm: value assigned to ``model.junction_threshold_hm``.
        num_junctions_inference: value assigned to
            ``model.num_junctions_inference`` (cast to int).
        width, height: size (cast to int) the grayscale image is resized to
            before inference.
        line_width, juncs_size: painter line width and junction marker size.
        whitebg: if > 0, used as ``show.Canvas.white_overlay`` while rendering
            this request. The previous value is restored afterwards.
        draw_junctions_only: draw only junctions instead of the full wireframe.
        use_lsd, use_nms: forwarded to the model through its ``meta`` dict.
        edge_color, vertex_color: wireframe colours.
        output_format: extension of the downloadable rendering. A PNG is
            always rendered as well for the on-screen preview.
        seed, randomize_seed: passed to ``randomize_seed_fn``; the result is
            then passed to ``fix_seeds``.

    Returns:
        ``(preview, json_file, fig_file)``:
            preview: the PNG rendering with channels reversed from
                ``cv2.imread``'s BGR order.
            json_file: path to the JSON wireframe.
            fig_file: path to the rendering in ``output_format``.

    Raises:
        gr.Error: missing/unreadable input, no checkpoint, or a failed render.
    """
    from scalelsd.base import show, WireframeGraph
    from scalelsd.ssl.misc.train_utils import fix_seeds, load_scalelsd_model

    if not model_name:
        raise gr.Error("No model checkpoint selected.")
    image = _load_grayscale(input_image)

    # Seed the RNGs so that runs with identical settings are reproducible.
    seed = int(randomize_seed_fn(seed, randomize_seed))
    fix_seeds(seed)

    # The checkpoint is loaded on every call.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = load_scalelsd_model(os.path.join("models", model_name), device)
    model.junction_threshold_hm = junction_threshold_hm
    # gr.Number may deliver floats; this is a count, so make it an int.
    model.num_junctions_inference = int(num_junctions_inference)

    ori_height, ori_width = image.shape[:2]
    # cv2.resize requires integer dsize.
    image_resized = cv2.resize(image, (int(width), int(height)))
    image_tensor = torch.from_numpy(image_resized).float() / 255.0
    # Put the input on the same device the model was loaded onto.
    image_tensor = image_tensor[None, None].to(device)

    meta = {
        'width': ori_width,
        'height': ori_height,
        'filename': '',
        'use_lsd': use_lsd,
        'use_nms': use_nms,
    }
    with torch.no_grad():
        outputs, _ = model(image_tensor, meta)
    outputs = outputs[0]

    painter = show.painters.HAWPainter()
    painter.confidence_threshold = threshold
    painter.line_width = line_width
    painter.marker_size = juncs_size

    temp_folder = "temp_output"
    os.makedirs(temp_folder, exist_ok=True)
    # save_name comes from a free-text box: keep only the final path component
    # so user input cannot write outside temp_folder.
    base_name = Path(save_name or 'temp_output').name or 'temp_output'
    png_file = f"{temp_folder}/{base_name}.png"
    fig_file = png_file

    # show.Canvas.white_overlay is class-level state shared by every request.
    # Set it only for this render and restore it, so one request's overlay
    # does not leak into a later request that asks for none.
    unset = object()
    saved_overlay = unset
    if whitebg > 0.0:
        saved_overlay = getattr(show.Canvas, 'white_overlay', unset)
        show.Canvas.white_overlay = whitebg
    try:
        _draw_detections(show.image_canvas, painter, input_image, outputs,
                         png_file, draw_junctions_only, edge_color, vertex_color)
        if output_format != 'png':
            fig_file = f"{temp_folder}/{base_name}.{output_format}"
            _draw_detections(show.image_canvas, painter, input_image, outputs,
                             fig_file, draw_junctions_only, edge_color, vertex_color)
    finally:
        if whitebg > 0.0:
            if saved_overlay is unset:
                delattr(show.Canvas, 'white_overlay')
            else:
                show.Canvas.white_overlay = saved_overlay

    result_image = cv2.imread(png_file)
    if result_image is None:
        raise gr.Error(f"Failed to read the rendered image {png_file}.")

    json_file = f"{temp_folder}/{base_name}.json"
    indices = WireframeGraph.xyxy2indices(outputs['juncs_pred'], outputs['lines_pred'])
    wireframe = WireframeGraph(
        outputs['juncs_pred'],
        outputs['juncs_score'],
        indices,
        outputs['lines_score'],
        outputs['width'],
        outputs['height'],
    )
    with open(json_file, 'w', encoding='utf-8') as f:
        json.dump(wireframe.jsonize(), f)

    return result_image[:, :, ::-1], json_file, fig_file
def run_demo():
    """Build the Gradio Blocks interface for ScaleLSD wireframe detection.

    At build time, example images are read from ``assets/figs`` and checkpoint
    names (``*.pt``) from ``models``. The Run button hides itself, shows Stop,
    runs ``process_image``, and then restores the buttons. Stop cancels the
    in-flight run and restores the buttons via ``stop_run``.

    Returns:
        The constructed (not yet launched) ``gr.Blocks`` demo.
    """
    # NOTE(review): `spaces` is never imported in this file, so a
    # `@spaces.GPU()` decorator here raised NameError at import time. This
    # function only builds the UI. On Hugging Face ZeroGPU hardware,
    # `import spaces` at the top of the file and decorate `process_image`
    # (the function that runs the model) instead.
    css = """
    #col-container {
        margin: 0 auto;
        max-width: 800px;
    }
    """
    figs_root = "assets/figs"
    # os.listdir order is arbitrary; sort so the default image, the example
    # gallery and the model list are stable across restarts.
    example_images = sorted(
        os.path.join(figs_root, iname) for iname in os.listdir(figs_root)
    )
    model_choices = sorted(
        ckpt for ckpt in os.listdir('models') if ckpt.endswith('.pt')
    )
    # Prefer the v1 checkpoint when it is installed. Otherwise fall back to the
    # first available one so the dropdown value is always one of its choices.
    preferred_model = 'scalelsd-vitbase-v1-train-sa1b.pt'
    if preferred_model in model_choices:
        default_model = preferred_model
    else:
        default_model = model_choices[0] if model_choices else None

    def show_stop_button():
        """Hide Run and reveal Stop while inference is in flight."""
        return gr.update(visible=False), gr.update(visible=True)

    with gr.Blocks(css=css, title=_TITLE) as demo:
        with gr.Column(elem_id="col-container"):
            gr.Markdown(f'# {_TITLE}')
            gr.Markdown("Detect wireframe structures in images using ScaleLSD model")
            with gr.Row():
                input_image = gr.Image(
                    example_images[0] if example_images else None,
                    label="Input Image",
                    type="numpy",
                )
                output_image = gr.Image(label="Detection Result")
            with gr.Row():
                run_btn = gr.Button(value="Run", variant="primary")
                stop_btn = gr.Button(value="Stop", variant="stop", visible=False)
            with gr.Row():
                json_file = gr.File(label="Download JSON Output", type="filepath")
                image_file = gr.File(label="Download Image Output", type="filepath")
            with gr.Accordion("Advanced Settings", open=True):
                with gr.Row():
                    model_name = gr.Dropdown(
                        model_choices,
                        value=default_model,
                        label="Model Selection",
                    )
                with gr.Row():
                    save_name = gr.Textbox(
                        'temp_output',
                        label="Save Name",
                        placeholder="Name for saving output files",
                    )
                with gr.Row():
                    with gr.Column():
                        threshold = gr.Number(10, label="Line Threshold")
                        junction_threshold_hm = gr.Number(0.008, label="Junction Threshold")
                        # precision=0 makes Gradio round to and return ints: these
                        # are a count and the integer dsize for cv2.resize.
                        num_junctions_inference = gr.Number(
                            1024, label="Max Number of Junctions", precision=0
                        )
                        width = gr.Number(512, label="Input Width", precision=0)
                        height = gr.Number(512, label="Input Height", precision=0)
                    with gr.Column():
                        draw_junctions_only = gr.Checkbox(False, label="Show Junctions Only")
                        use_lsd = gr.Checkbox(False, label="Use LSD-Rectifier")
                        use_nms = gr.Checkbox(True, label="Use NMS")
                        output_format = gr.Dropdown(
                            ['png', 'jpg', 'pdf'],
                            value='png',
                            label="Output Format",
                        )
                        whitebg = gr.Slider(0.0, 1.0, value=0.7, label="White Background Opacity")
                        line_width = gr.Number(2, label="Line Width")
                        juncs_size = gr.Number(8, label="Junctions Size")
                with gr.Row():
                    edge_color = gr.Dropdown(
                        ['orange', 'midnightblue', 'red', 'green'],
                        value='orange',
                        label="Edge Color",
                    )
                    vertex_color = gr.Dropdown(
                        ['Cyan', 'deeppink', 'yellow', 'purple'],
                        value='Cyan',
                        label="Vertex Color",
                    )
                with gr.Row():
                    randomize_seed = gr.Checkbox(False, label="Randomize Seed")
                    seed = gr.Slider(0, MAX_SEED, value=42, step=1, label="Seed")
            if example_images:
                gr.Examples(
                    examples=example_images,
                    inputs=input_image,
                )

        # Run: swap the buttons first, so the (initially hidden) Stop button is
        # reachable, then run inference.
        run_event = run_btn.click(
            fn=show_stop_button,
            outputs=[run_btn, stop_btn],
            queue=False,
        ).then(
            fn=process_image,
            inputs=[
                input_image,
                model_name,
                save_name,
                threshold,
                junction_threshold_hm,
                num_junctions_inference,
                width,
                height,
                line_width,
                juncs_size,
                whitebg,
                draw_junctions_only,
                use_lsd,
                use_nms,
                edge_color,
                vertex_color,
                output_format,
                seed,
                randomize_seed,
            ],
            outputs=[output_image, json_file, image_file],
        )
        # `.then` fires whether inference succeeded or raised, so the Run button
        # always comes back.
        run_event.then(fn=stop_run, outputs=[run_btn, stop_btn], queue=False)

        # Stop: cancel the in-flight inference and restore the buttons.
        stop_btn.click(
            fn=stop_run,
            outputs=[run_btn, stop_btn],
            cancels=[run_event],
            queue=False,
        )
    return demo
# Launch the server only when executed as a script (e.g. `python app.py`), so
# importing this module does not start a web server as a side effect.
if __name__ == "__main__":
    run_demo().launch()