Spaces:
Running
on
Zero
Running
on
Zero
import torch | |
import random | |
import numpy as np | |
import os | |
import os.path as osp | |
import glob | |
from tqdm import tqdm | |
from scalelsd.base import setup_logger, MetricLogger, show, WireframeGraph | |
from scalelsd.ssl.datasets import dataset_util | |
from scalelsd.ssl.models.detector import ScaleLSD | |
from scalelsd.ssl.misc.train_utils import load_scalelsd_model | |
from torch.utils.data import DataLoader | |
import torch.utils.data.dataloader as torch_loader | |
from pathlib import Path | |
import argparse, yaml, logging, time, datetime, cv2, copy, sys, json | |
from easydict import EasyDict | |
import accelerate | |
from accelerate import load_checkpoint_and_dispatch | |
import matplotlib | |
import matplotlib.pyplot as plt | |
def parse_args():
    """Build the ScaleLSD inference command line and parse ``sys.argv``.

    The script's own options are registered first. ``ScaleLSD.cli`` is then
    given the parser so it can add its own options. After parsing, the
    namespace is passed to ``ScaleLSD.configure`` and returned.

    Returns:
        argparse.Namespace: the parsed command-line arguments.
    """
    parser = argparse.ArgumentParser()
    # (flags, add_argument keyword arguments). Order matters: it fixes the
    # order in which options are registered and listed by --help.
    option_specs = (
        (('-c', '--ckpt'), dict(default='models/scalelsd-vitbase-v1-train-sa1b.pt', type=str,
                               help='the path for loading checkpoints')),
        (('-t', '--threshold'), dict(default=10, type=float)),
        (('-i', '--img'), dict(required=True, type=str)),
        (('--width',), dict(default=512, type=int)),
        (('--height',), dict(default=512, type=int)),
        (('--whitebg',), dict(default=0.0, type=float)),
        (('--saveto',), dict(default=None, type=str)),
        (('-e', '--ext'), dict(default='pdf', type=str, choices=['pdf', 'png', 'json', 'txt'])),
        (('--device',), dict(default='cuda', type=str, choices=['cuda', 'cpu', 'mps'])),
        (('--disable-show',), dict(default=False, action='store_true')),
        (('--draw-junctions-only',), dict(default=False, action='store_true')),
        (('--use_lsd',), dict(default=False, action='store_true')),
        (('--use_nms',), dict(default=False, action='store_true')),
    )
    for flags, options in option_specs:
        parser.add_argument(*flags, **options)

    # Let the model class register its own CLI options before parsing.
    ScaleLSD.cli(parser)
    parsed = parser.parse_args()
    ScaleLSD.configure(parsed)
    return parsed
# Image file extensions accepted as input (matched case-insensitively).
_IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')
# Output formats this script knows how to write.
_SUPPORTED_OUTPUT_EXTS = ('png', 'pdf', 'json')


def _is_image_file(path):
    """Return True if *path* ends with a supported image extension (any case)."""
    return path.lower().endswith(_IMAGE_EXTENSIONS)


def _collect_images(img_arg):
    """Return the sorted list of image paths referenced by ``--img``.

    ``img_arg`` may name a single image file or a directory. Directories are
    scanned non-recursively.

    Raises:
        ValueError: if ``img_arg`` is neither an image file nor a directory.
    """
    if os.path.isfile(img_arg) and _is_image_file(img_arg):
        return [img_arg]
    if os.path.isdir(img_arg):
        return sorted(
            os.path.join(img_arg, entry)
            for entry in os.listdir(img_arg)
            if _is_image_file(entry)
        )
    raise ValueError('Input must be a file or a directory containing images.')


def _to_model_input(image, width, height, device):
    """Resize a 2-D grayscale uint8 image to (height, width) and return a
    (1, 1, height, width) float tensor scaled to [0, 1] on *device*."""
    resized = cv2.resize(image, (width, height))
    tensor = torch.from_numpy(resized).float() / 255.0
    return tensor[None, None].to(device)


def _save_outputs(args, fname, outputs, painter, edge_color, vertex_color):
    """Write the prediction for one image into ``args.saveto``.

    For 'png'/'pdf', draws the junctions or the wireframe over the image.
    For 'json', serializes a WireframeGraph built from the predictions.

    Raises:
        ValueError: if ``args.ext`` is not a supported output format.
    """
    pname = Path(fname)
    if args.ext in ('png', 'pdf'):
        fig_file = osp.join(args.saveto, pname.with_suffix('.' + args.ext).name)
        with show.image_canvas(fname, fig_file=fig_file) as ax:
            if args.draw_junctions_only:
                painter.draw_junctions(ax, outputs)
            else:
                painter.draw_wireframe(ax, outputs, edge_color=edge_color, vertex_color=vertex_color)
    elif args.ext == 'json':
        indices = WireframeGraph.xyxy2indices(outputs['juncs_pred'], outputs['lines_pred'])
        wireframe = WireframeGraph(outputs['juncs_pred'], outputs['juncs_score'], indices,
                                   outputs['lines_score'], outputs['width'], outputs['height'])
        outpath = osp.join(args.saveto, pname.with_suffix('.json').name)
        with open(outpath, 'w') as f:
            json.dump(wireframe.jsonize(), f)
    else:
        raise ValueError('Unsupported extension: {} is not in [png, pdf, json]'.format(args.ext))


def main():
    """Run ScaleLSD line detection on one image or a directory of images.

    The steps are:

    1. Parse the CLI arguments.
    2. Validate the output format and resolve the input images.
    3. Load the checkpoint.
    4. Run inference on each image.
    5. Write one visualization (png/pdf) or wireframe (json) file per image
       into ``--saveto`` (default ``temp_output/ScaleLSD``).

    Raises:
        ValueError: if ``--ext`` is not a supported output format, or if
            ``--img`` is neither an image file nor a directory.
    """
    args = parse_args()

    # Fail fast. The parser accepts '--ext txt', but no writer exists for it.
    # Reject it before the expensive model load rather than after the first
    # inference.
    if args.ext not in _SUPPORTED_OUTPUT_EXTS:
        raise ValueError('Unsupported extension: {} is not in [png, pdf, json]'.format(args.ext))
    # Resolve inputs before loading the model so bad paths also fail cheaply.
    all_images = _collect_images(args.img)
    if not all_images:
        print('No images found in {}'.format(args.img))
        return

    model = load_scalelsd_model(args.ckpt, device=args.device)

    # Set up output directory and painter
    if args.saveto is None:
        print('No output directory specified, saving outputs to folder: temp_output/ScaleLSD')
        args.saveto = 'temp_output/ScaleLSD'
    os.makedirs(args.saveto, exist_ok=True)

    show.painters.HAWPainter.confidence_threshold = args.threshold
    # Other HAWPainter knobs (e.g. line_width, marker_size) can be set here too.
    show.Canvas.show = not args.disable_show
    if args.whitebg > 0.0:
        show.Canvas.white_overlay = args.whitebg
    painter = show.painters.HAWPainter()
    edge_color = 'orange'  # alternative: 'midnightblue'
    vertex_color = 'Cyan'  # alternative: 'deeppink'

    # Inference
    for fname in tqdm(all_images):
        image = cv2.imread(fname, 0)  # 0 -> load as single-channel grayscale
        if image is None:
            # cv2.imread signals unreadable/corrupt files by returning None
            # rather than raising. Skip the file instead of crashing the batch.
            tqdm.write('Skipping unreadable image: {}'.format(fname))
            continue

        # The network input is resized to (--height, --width), but meta carries
        # the original size.
        ori_height, ori_width = image.shape[:2]
        meta = {
            'width': ori_width,
            'height': ori_height,
            'filename': '',
            'use_lsd': args.use_lsd,
            'use_nms': args.use_nms,
        }
        with torch.no_grad():
            outputs, _ = model(_to_model_input(image, args.width, args.height, args.device), meta)
        outputs = outputs[0]  # batch of one image

        _save_outputs(args, fname, outputs, painter, edge_color, vertex_color)
# Entry point: run the CLI only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()