# ScaleLSD / app.py
# Nan Xue
# update
# f5bdb68
import torch
import cv2
import os
import gradio as gr
import numpy as np
import random
from pathlib import Path
import json
import spaces
# Title for the Gradio interface
_TITLE = 'Gradio Demo of ScaleLSD for Structured Representation of Images'
# Inclusive upper bound used for the seed slider and for randomly drawn seeds.
MAX_SEED = 1000
# Startup provisioning, executed every time this module is run/imported:
# create the checkpoint directory, download both ScaleLSD checkpoints into it,
# then install the current directory as an editable package.
# NOTE(review): the exit status of each shell command is ignored, so a failed
# download or install is not detected here.
# NOTE(review): the editable install presumably makes the `scalelsd` package
# importable for the lazy imports inside process_image - confirm.
os.system('mkdir -p models')
os.system('wget https://huggingface.co/cherubicxn/scalelsd/resolve/main/scalelsd-vitbase-v2-train-sa1b.pt -O models/scalelsd-vitbase-v2-train-sa1b.pt')
os.system('wget https://huggingface.co/cherubicxn/scalelsd/resolve/main/scalelsd-vitbase-v1-train-sa1b.pt -O models/scalelsd-vitbase-v1-train-sa1b.pt')
os.system('pip install -e .')
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Pick the seed to use for a run.

    Args:
        seed: The seed supplied by the caller.
        randomize_seed: When true, ignore ``seed`` and draw a fresh one.

    Returns:
        ``seed`` unchanged, or a random integer in ``[0, MAX_SEED]``
        (both ends inclusive) when ``randomize_seed`` is true.
    """
    return random.randint(0, MAX_SEED) if randomize_seed else seed
def stop_run():
    """Build the component updates that return the buttons to their idle state.

    Returns:
        A 2-tuple of Gradio updates, in the order the stop handler lists its
        outputs: the first shows a primary button labelled "Run", the second
        hides its target component.
    """
    show_run = gr.update(value="Run", variant="primary", visible=True)
    hide_stop = gr.update(visible=False)
    return show_run, hide_stop
@spaces.GPU
def process_image(
    input_image,
    model_name='scalelsd-vitbase-v2-train-sa1b.pt',
    save_name='temp_output',
    threshold=10,
    junction_threshold_hm=0.008,
    num_junctions_inference=512,
    width=512,
    height=512,
    line_width=2,
    juncs_size=4,
    whitebg=0.0,
    draw_junctions_only=False,
    use_lsd=False,
    use_nms=False,
    edge_color='orange',
    vertex_color='Cyan',
    output_format='png',
    seed=0,
    randomize_seed=False
):
    """Run ScaleLSD on one image and render the detections.

    Args:
        input_image: Either an RGB ``numpy`` array or a path readable by
            ``cv2.imread``.
        model_name: Checkpoint file name, looked up under ``models/``.
        save_name: Base name of the files written to ``temp_output/``. Only
            the final path component is used.
        threshold: Assigned to the painter's ``confidence_threshold``.
        junction_threshold_hm: Assigned to ``model.junction_threshold_hm``.
        num_junctions_inference: Assigned (as ``int``) to
            ``model.num_junctions_inference``.
        width: Width the grayscale image is resized to before inference.
        height: Height the grayscale image is resized to before inference.
        line_width: Assigned to the painter's ``line_width``.
        juncs_size: Assigned to the painter's ``marker_size``.
        whitebg: When greater than 0, stored in ``show.Canvas.white_overlay``.
        draw_junctions_only: Draw junctions only instead of the wireframe.
        use_lsd: Ignored; the value is forced to ``False`` (see below).
        use_nms: Forwarded to the model through the ``meta`` dict.
        edge_color: Edge colour passed to ``draw_wireframe``.
        vertex_color: Vertex colour passed to ``draw_wireframe``.
        output_format: Extension of the downloadable figure. A PNG is always
            rendered first because it backs the returned preview image.
        seed: Seed passed to ``fix_seeds`` unless ``randomize_seed`` is set.
        randomize_seed: Draw a random seed instead of using ``seed``.

    Returns:
        A tuple ``(result_image, json_file, fig_file)``: the rendered PNG as
        an array with its channel axis reversed (OpenCV loads BGR), the path
        of the JSON dump of the wireframe, and the path of the figure in the
        requested ``output_format``.

    Raises:
        gr.Error: If no image was supplied or the image path cannot be read.
    """
    # Imported lazily inside the function, as in the original code.
    # NOTE(review): presumably because the package is only installed by the
    # module-level `pip install -e .` - confirm.
    from scalelsd.ssl.models.detector import ScaleLSD  # noqa: F401  (kept from the original)
    from scalelsd.base import show, WireframeGraph
    from scalelsd.ssl.misc.train_utils import fix_seeds, load_scalelsd_model

    # The LSD rectifier is deliberately force-disabled: whatever the caller
    # passes is discarded.
    use_lsd = False

    # set random seed
    seed = int(randomize_seed_fn(seed, randomize_seed))
    fix_seeds(seed)

    # initialize model
    ckpt = "models/" + model_name
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = load_scalelsd_model(ckpt, device)

    # set model parameters
    model.junction_threshold_hm = junction_threshold_hm
    # UI number fields may deliver floats; a junction count must be integral.
    model.num_junctions_inference = int(num_junctions_inference)

    # transform input image
    if input_image is None:
        raise gr.Error("Please provide an input image.")
    if isinstance(input_image, np.ndarray):
        image = cv2.cvtColor(input_image, cv2.COLOR_RGB2GRAY)
    else:
        image = cv2.imread(input_image, 0)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise gr.Error(f"Could not read image: {input_image}")

    # resize (cv2.resize needs an integer (width, height) pair)
    ori_shape = image.shape[:2]
    image_resized = cv2.resize(image, (int(width), int(height)))
    image_tensor = torch.from_numpy(image_resized).float() / 255.0
    # Use the device the model was loaded on; a hard-coded 'cuda' breaks the
    # CPU fallback selected above.
    image_tensor = image_tensor[None, None].to(device)

    # meta data
    meta = {
        'width': ori_shape[1],
        'height': ori_shape[0],
        'filename': '',
        'use_lsd': use_lsd,
        'use_nms': use_nms,
    }

    # inference
    with torch.no_grad():
        outputs, _ = model(image_tensor, meta)
    outputs = outputs[0]

    # visual results
    painter = show.painters.HAWPainter()
    painter.confidence_threshold = threshold
    painter.line_width = line_width
    painter.marker_size = juncs_size

    # NOTE(review): this is a class attribute, so a value set by an earlier
    # call stays in effect when a later call passes whitebg == 0 - confirm
    # whether it should be reset.
    if whitebg > 0.0:
        show.Canvas.white_overlay = whitebg

    def _render(path):
        """Draw the detections over the input image and save the figure to *path*."""
        with show.image_canvas(input_image, fig_file=path) as ax:
            if draw_junctions_only:
                painter.draw_junctions(ax, outputs)
            else:
                painter.draw_wireframe(ax, outputs, edge_color=edge_color, vertex_color=vertex_color)

    temp_folder = "temp_output"
    os.makedirs(temp_folder, exist_ok=True)

    # save_name is free text from the UI: keep only its last path component so
    # the outputs cannot be written outside temp_folder.
    safe_name = Path(str(save_name)).name or 'temp_output'

    fig_file = f"{temp_folder}/{safe_name}.png"
    _render(fig_file)

    # read the result image
    result_image = cv2.imread(fig_file)

    if output_format != 'png':
        fig_file = f"{temp_folder}/{safe_name}.{output_format}"
        _render(fig_file)

    json_file = f"{temp_folder}/{safe_name}.json"
    indices = WireframeGraph.xyxy2indices(outputs['juncs_pred'], outputs['lines_pred'])
    wireframe = WireframeGraph(
        outputs['juncs_pred'],
        outputs['juncs_score'],
        indices,
        outputs['lines_score'],
        outputs['width'],
        outputs['height'],
    )
    with open(json_file, 'w') as f:
        json.dump(wireframe.jsonize(), f)

    return result_image[:, :, ::-1], json_file, fig_file
def run_demo():
    """Build the Gradio Blocks interface for the ScaleLSD demo.

    Example images are collected from ``assets/figs`` in sorted order,
    skipping hidden entries and sub-directories; the interface is still built
    when that folder is missing or empty. Model choices are the ``.pt`` files
    found in ``models``.

    Returns:
        The constructed ``gr.Blocks`` app. It is not launched here.
    """
    css = (
        "\n#col-container {\n"
        "margin: 0 auto;\n"
        "max-width: 800px;\n"
        "}\n"
    )

    # Example images: sorted so the default image is the same on every start,
    # and tolerant of a missing or empty folder.
    figs_root = "assets/figs"
    example_images = []
    if os.path.isdir(figs_root):
        example_images = sorted(
            os.path.join(figs_root, iname)
            for iname in os.listdir(figs_root)
            if not iname.startswith('.')
            and os.path.isfile(os.path.join(figs_root, iname))
        )
    default_image = example_images[0] if example_images else None

    # Checkpoints: only preselect a default that is actually among the choices.
    model_choices = []
    if os.path.isdir('models'):
        model_choices = sorted(
            ckpt for ckpt in os.listdir('models') if ckpt.endswith('.pt')
        )
    default_model = 'scalelsd-vitbase-v2-train-sa1b.pt'
    if default_model not in model_choices:
        default_model = model_choices[0] if model_choices else None

    with gr.Blocks(css=css, title=_TITLE) as demo:
        with gr.Column(elem_id="col-container"):
            gr.Markdown(f'# {_TITLE}')
            gr.Markdown("Detect wireframe structures in images using ScaleLSD model")

            with gr.Row():
                input_image = gr.Image(default_image, label="Input Image", type="numpy")
                output_image = gr.Image(label="Detection Result")

            with gr.Row():
                run_btn = gr.Button(value="Run", variant="primary")
                # NOTE(review): nothing in this function ever makes the stop
                # button visible, so it currently cannot be clicked from the UI.
                stop_btn = gr.Button(value="Stop", variant="stop", visible=False)

            with gr.Row():
                json_file = gr.File(label="Download JSON Output", type="filepath")
                image_file = gr.File(label="Download Image Output", type="filepath")

            with gr.Accordion("Advanced Settings", open=True):
                with gr.Row():
                    model_name = gr.Dropdown(
                        model_choices,
                        value=default_model,
                        label="Model Selection"
                    )

                with gr.Row():
                    save_name = gr.Textbox('temp_output', label="Save Name", placeholder="Name for saving output files")

                with gr.Row():
                    with gr.Column():
                        threshold = gr.Number(10, label="Line Threshold")
                        junction_threshold_hm = gr.Number(0.008, label="Junction Threshold")
                        num_junctions_inference = gr.Number(1024, label="Max Number of Junctions")
                        width = gr.Number(512, label="Input Width")
                        height = gr.Number(512, label="Input Height")

                    with gr.Column():
                        draw_junctions_only = gr.Checkbox(False, label="Show Junctions Only")
                        use_lsd = gr.Checkbox(False, label="Use LSD-Rectifier")
                        use_nms = gr.Checkbox(True, label="Use NMS")
                        output_format = gr.Dropdown(
                            ['png', 'jpg', 'pdf'],
                            value='png',
                            label="Output Format"
                        )
                        whitebg = gr.Slider(0.0, 1.0, value=0.7, label="White Background Opacity")
                        line_width = gr.Number(2, label="Line Width")
                        juncs_size = gr.Number(8, label="Junctions Size")

                with gr.Row():
                    edge_color = gr.Dropdown(
                        ['orange', 'midnightblue', 'red', 'green'],
                        value='orange',
                        label="Edge Color"
                    )
                    vertex_color = gr.Dropdown(
                        ['Cyan', 'deeppink', 'yellow', 'purple'],
                        value='Cyan',
                        label="Vertex Color"
                    )

                with gr.Row():
                    randomize_seed = gr.Checkbox(False, label="Randomize Seed")
                    seed = gr.Slider(0, MAX_SEED, value=42, step=1, label="Seed")

            # Only offer the examples gallery when there is something to show.
            if example_images:
                gr.Examples(
                    examples=example_images,
                    inputs=input_image,
                )

            # start event handlers
            # The inputs list must stay in the positional order of
            # process_image's parameters.
            run_event = run_btn.click(
                fn=process_image,
                inputs=[
                    input_image,
                    model_name,
                    save_name,
                    threshold,
                    junction_threshold_hm,
                    num_junctions_inference,
                    width,
                    height,
                    line_width,
                    juncs_size,
                    whitebg,
                    draw_junctions_only,
                    use_lsd,
                    use_nms,
                    edge_color,
                    vertex_color,
                    output_format,
                    seed,
                    randomize_seed
                ],
                outputs=[output_image, json_file, image_file],
            )

            # stop event handlers: cancel the running job and reset the buttons
            stop_btn.click(
                fn=stop_run,
                outputs=[run_btn, stop_btn],
                cancels=[run_event],
                queue=False,
            )

    return demo
# Build the interface, then start serving it.
demo = run_demo()
demo.launch()