"""Gradio demo app for ScaleLSD line-segment / wireframe detection."""
import json
import os
import random
import subprocess
import sys
from pathlib import Path

import cv2
import gradio as gr
import numpy as np
import spaces
import torch

# Title for the Gradio interface
_TITLE = 'Gradio Demo of ScaleLSD for Structured Representation of Images'
MAX_SEED = 1000

# Where checkpoints live and which ones are fetched at startup.
_MODELS_DIR = 'models'
_CKPT_URL = 'https://huggingface.co/cherubicxn/scalelsd/resolve/main/{}'
_CHECKPOINTS = (
    'scalelsd-vitbase-v2-train-sa1b.pt',
    'scalelsd-vitbase-v1-train-sa1b.pt',
)
# Folder receiving the rendered images and JSON files, and the fallback file stem.
_OUTPUT_DIR = 'temp_output'
_DEFAULT_SAVE_NAME = 'temp_output'


def _download_checkpoints():
    """Download the ScaleLSD checkpoints into ``models/``, skipping non-empty files already present.

    Download failures are not fatal here (matching the previous ``os.system`` calls);
    a missing checkpoint surfaces later when the model is loaded.
    """
    os.makedirs(_MODELS_DIR, exist_ok=True)
    for ckpt_name in _CHECKPOINTS:
        target = os.path.join(_MODELS_DIR, ckpt_name)
        # A non-empty file is taken to be a previous successful download.
        if os.path.isfile(target) and os.path.getsize(target) > 0:
            continue
        subprocess.run(['wget', _CKPT_URL.format(ckpt_name), '-O', target], check=False)


_download_checkpoints()
# Install the local package in editable mode, using the interpreter that runs this app.
subprocess.run([sys.executable, '-m', 'pip', 'install', '-e', '.'], check=False)


def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Return a random seed in ``[0, MAX_SEED]`` when ``randomize_seed`` is set, else ``seed``."""
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    return seed


def start_run():
    """Hide the Run button and show the Stop button while inference is running."""
    return (
        gr.update(visible=False),
        gr.update(visible=True),
    )


def stop_run():
    """Restore the Run button and hide the Stop button."""
    return (
        gr.update(value="Run", variant="primary", visible=True),
        gr.update(visible=False),
    )


def _safe_file_stem(name, default=_DEFAULT_SAVE_NAME):
    """Reduce a user-supplied name to a bare file stem so output cannot escape the output folder."""
    stem = os.path.basename(str(name or '').strip())
    return default if stem in ('', '.', '..') else stem


def _to_gray(input_image):
    """Return ``input_image`` (a numpy RGB/RGBA/gray array, or an image file path) as grayscale.

    Raises:
        gr.Error: if no image was given or the file path cannot be read.
    """
    if input_image is None:
        raise gr.Error("Please provide an input image.")
    if isinstance(input_image, np.ndarray):
        if input_image.ndim == 2:
            return input_image
        channels = input_image.shape[2]
        if channels == 1:
            return input_image[:, :, 0]
        if channels == 4:
            return cv2.cvtColor(input_image, cv2.COLOR_RGBA2GRAY)
        return cv2.cvtColor(input_image, cv2.COLOR_RGB2GRAY)
    image = cv2.imread(input_image, cv2.IMREAD_GRAYSCALE)
    if image is None:
        raise gr.Error(f"Could not read image: {input_image}")
    return image


def _render(show, painter, input_image, outputs, fig_file,
            draw_junctions_only, edge_color, vertex_color):
    """Draw the junctions (or the full wireframe) over ``input_image`` and save it to ``fig_file``."""
    with show.image_canvas(input_image, fig_file=fig_file) as ax:
        if draw_junctions_only:
            painter.draw_junctions(ax, outputs)
        else:
            painter.draw_wireframe(ax, outputs, edge_color=edge_color, vertex_color=vertex_color)


@spaces.GPU
def process_image(
    input_image,
    model_name='scalelsd-vitbase-v2-train-sa1b.pt',
    save_name='temp_output',
    threshold=10,
    junction_threshold_hm=0.008,
    num_junctions_inference=512,
    width=512,
    height=512,
    line_width=2,
    juncs_size=4,
    whitebg=0.0,
    draw_junctions_only=False,
    use_lsd=False,
    use_nms=False,
    edge_color='orange',
    vertex_color='Cyan',
    output_format='png',
    seed=0,
    randomize_seed=False
):
    """Run ScaleLSD on one image and render the result.

    Args:
        input_image: numpy RGB array from Gradio, or a path to an image file.
        model_name: checkpoint file name inside ``models/``.
        save_name: file stem for outputs in ``temp_output/``; reduced to a bare name.
        threshold: line confidence threshold used by the painter.
        junction_threshold_hm / num_junctions_inference: junction detection settings.
        width, height: size the grayscale image is resized to before inference.
        line_width, juncs_size: drawing sizes for lines and junction markers.
        whitebg: white overlay opacity; applied only for this call when > 0.
        draw_junctions_only: draw only junctions instead of the full wireframe.
        use_lsd: ignored — the LSD rectifier is always disabled in this demo.
        use_nms: forwarded to the model in ``meta``.
        edge_color, vertex_color: wireframe colours.
        output_format: file format of the downloadable figure ('png', 'jpg', 'pdf').
        seed, randomize_seed: seeding controls (see ``randomize_seed_fn``).

    Returns:
        (RGB result image array, path of the JSON wireframe, path of the figure file)

    Raises:
        gr.Error: on a missing/unreadable input image or a failed render.
    """
    # Imported lazily, presumably because the package is pip-installed at app startup.
    from scalelsd.base import show, WireframeGraph
    from scalelsd.ssl.misc.train_utils import fix_seeds, load_scalelsd_model

    # The LSD rectifier is force-disabled in this demo; the UI checkbox value is ignored.
    use_lsd = False

    # gr.Number delivers floats; OpenCV sizes and junction counts must be ints.
    width, height = int(width), int(height)
    num_junctions_inference = int(num_junctions_inference)

    seed = int(randomize_seed_fn(seed, randomize_seed))
    fix_seeds(seed)

    # Load the model; basename() keeps the checkpoint lookup inside models/.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    ckpt = os.path.join(_MODELS_DIR, os.path.basename(model_name))
    model = load_scalelsd_model(ckpt, device)
    model.junction_threshold_hm = junction_threshold_hm
    model.num_junctions_inference = num_junctions_inference

    image = _to_gray(input_image)
    ori_height, ori_width = image.shape[:2]
    image_resized = cv2.resize(image, (width, height))
    image_tensor = torch.from_numpy(image_resized).float() / 255.0
    # Add batch and channel dims; move to the model's device (previously hard-coded to 'cuda').
    image_tensor = image_tensor[None, None].to(device)

    # The original image size is passed so predictions refer to the original resolution.
    meta = {
        'width': ori_width,
        'height': ori_height,
        'filename': '',
        'use_lsd': use_lsd,
        'use_nms': use_nms,
    }

    with torch.no_grad():
        outputs, _ = model(image_tensor, meta)
    outputs = outputs[0]

    painter = show.painters.HAWPainter()
    painter.confidence_threshold = threshold
    painter.line_width = line_width
    painter.marker_size = juncs_size

    os.makedirs(_OUTPUT_DIR, exist_ok=True)
    save_name = _safe_file_stem(save_name)
    png_file = os.path.join(_OUTPUT_DIR, f"{save_name}.png")
    render_args = (draw_junctions_only, edge_color, vertex_color)

    # Canvas.white_overlay is class-level state shared by every request: override it only
    # for this call and restore it, so a later run with whitebg == 0 gets no stale overlay.
    override_overlay = whitebg > 0.0
    if override_overlay:
        previous_overlay = getattr(show.Canvas, 'white_overlay', None)
        show.Canvas.white_overlay = whitebg
    try:
        # The PNG is always rendered because it backs the on-screen preview.
        _render(show, painter, input_image, outputs, png_file, *render_args)
        result_image = cv2.imread(png_file)
        fig_file = png_file
        if output_format != 'png':
            fig_file = os.path.join(_OUTPUT_DIR, f"{save_name}.{output_format}")
            _render(show, painter, input_image, outputs, fig_file, *render_args)
    finally:
        if override_overlay:
            show.Canvas.white_overlay = previous_overlay

    if result_image is None:
        raise gr.Error("Failed to render the detection result.")

    json_file = os.path.join(_OUTPUT_DIR, f"{save_name}.json")
    indices = WireframeGraph.xyxy2indices(outputs['juncs_pred'], outputs['lines_pred'])
    wireframe = WireframeGraph(outputs['juncs_pred'], outputs['juncs_score'], indices,
                               outputs['lines_score'], outputs['width'], outputs['height'])
    with open(json_file, 'w', encoding='utf-8') as f:
        json.dump(wireframe.jsonize(), f)

    # cv2.imread returns BGR; flip channels to RGB for Gradio.
    return result_image[:, :, ::-1], json_file, fig_file


def run_demo():
    """Build and return (without launching) the Gradio Blocks UI for the demo."""
    css = """
    #col-container {
        margin: 0 auto;
        max-width: 800px;
    }
    """
    figs_root = "assets/figs"
    # Sorted so the default input image and the examples order are stable across runs.
    example_images = sorted(os.path.join(figs_root, iname) for iname in os.listdir(figs_root))
    model_choices = sorted(ckpt for ckpt in os.listdir(_MODELS_DIR) if ckpt.endswith('.pt'))

    with gr.Blocks(css=css, title=_TITLE) as demo:
        with gr.Column(elem_id="col-container"):
            gr.Markdown(f'# {_TITLE}')
            gr.Markdown("Detect wireframe structures in images using ScaleLSD model")

            with gr.Row():
                input_image = gr.Image(example_images[0] if example_images else None,
                                       label="Input Image", type="numpy")
                output_image = gr.Image(label="Detection Result")

            with gr.Row():
                run_btn = gr.Button(value="Run", variant="primary")
                stop_btn = gr.Button(value="Stop", variant="stop", visible=False)

            with gr.Row():
                json_file = gr.File(label="Download JSON Output", type="filepath")
                image_file = gr.File(label="Download Image Output", type="filepath")

            with gr.Accordion("Advanced Settings", open=True):
                with gr.Row():
                    model_name = gr.Dropdown(
                        model_choices,
                        value='scalelsd-vitbase-v2-train-sa1b.pt',
                        label="Model Selection"
                    )

                with gr.Row():
                    save_name = gr.Textbox('temp_output', label="Save Name",
                                           placeholder="Name for saving output files")

                with gr.Row():
                    with gr.Column():
                        threshold = gr.Number(10, label="Line Threshold")
                        junction_threshold_hm = gr.Number(0.008, label="Junction Threshold")
                        num_junctions_inference = gr.Number(1024, label="Max Number of Junctions")
                        width = gr.Number(512, label="Input Width")
                        height = gr.Number(512, label="Input Height")

                    with gr.Column():
                        draw_junctions_only = gr.Checkbox(False, label="Show Junctions Only")
                        use_lsd = gr.Checkbox(False, label="Use LSD-Rectifier")
                        use_nms = gr.Checkbox(True, label="Use NMS")
                        output_format = gr.Dropdown(
                            ['png', 'jpg', 'pdf'],
                            value='png',
                            label="Output Format"
                        )
                        whitebg = gr.Slider(0.0, 1.0, value=0.7, label="White Background Opacity")
                        line_width = gr.Number(2, label="Line Width")
                        juncs_size = gr.Number(8, label="Junctions Size")

                with gr.Row():
                    edge_color = gr.Dropdown(
                        ['orange', 'midnightblue', 'red', 'green'],
                        value='orange',
                        label="Edge Color"
                    )
                    vertex_color = gr.Dropdown(
                        ['Cyan', 'deeppink', 'yellow', 'purple'],
                        value='Cyan',
                        label="Vertex Color"
                    )

                with gr.Row():
                    randomize_seed = gr.Checkbox(False, label="Randomize Seed")
                    seed = gr.Slider(0, MAX_SEED, value=42, step=1, label="Seed")

            gr.Examples(
                examples=example_images,
                inputs=input_image,
            )

        # Run: show the Stop button, run inference, then restore the Run button
        # (.then fires whether or not inference succeeded).
        start_event = run_btn.click(fn=start_run, outputs=[run_btn, stop_btn], queue=False)
        run_event = start_event.then(
            fn=process_image,
            inputs=[
                input_image, model_name, save_name, threshold, junction_threshold_hm,
                num_junctions_inference, width, height, line_width, juncs_size,
                whitebg, draw_junctions_only, use_lsd, use_nms, edge_color, vertex_color,
                output_format, seed, randomize_seed
            ],
            outputs=[output_image, json_file, image_file],
        )
        run_event.then(fn=stop_run, outputs=[run_btn, stop_btn], queue=False)

        # Stop: cancel the running inference and restore the buttons.
        stop_btn.click(
            fn=stop_run,
            outputs=[run_btn, stop_btn],
            cancels=[run_event],
            queue=False,
        )

    return demo


if __name__ == "__main__":
    run_demo().launch()