"""Run a ScaleLSD model on one image or a directory of images and save the
predicted wireframes (junctions + line segments) as figures or JSON."""
# NOTE: import order is preserved from the original script; several of these
# names are unused here but are kept in case they have import-time side effects.
import torch
import random
import numpy as np
import os
import os.path as osp
import glob
from tqdm import tqdm
from scalelsd.base import setup_logger, MetricLogger, show, WireframeGraph
from scalelsd.ssl.datasets import dataset_util
from scalelsd.ssl.models.detector import ScaleLSD
from scalelsd.ssl.misc.train_utils import load_scalelsd_model
from torch.utils.data import DataLoader
import torch.utils.data.dataloader as torch_loader
from pathlib import Path
import argparse
import yaml
import logging
import time
import datetime
import cv2
import copy
import sys
import json
from easydict import EasyDict
import accelerate
from accelerate import load_checkpoint_and_dispatch
import matplotlib
import matplotlib.pyplot as plt

# Input file extensions accepted as images (compared case-insensitively).
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

# Output formats that _save_result() can actually write. The CLI also accepts
# 'txt', but no writer exists for it, so it is rejected up front in main().
SUPPORTED_EXTS = ('png', 'pdf', 'json')


def parse_args():
    """Build the command-line parser, let ScaleLSD add and consume its own
    options, and return the parsed namespace."""
    aparser = argparse.ArgumentParser()
    aparser.add_argument('-c', '--ckpt', default='models/scalelsd-vitbase-v1-train-sa1b.pt',
                         type=str, help='the path for loading checkpoints')
    aparser.add_argument('-t', '--threshold', default=10, type=float)
    aparser.add_argument('-i', '--img', required=True, type=str)
    aparser.add_argument('--width', default=512, type=int)
    aparser.add_argument('--height', default=512, type=int)
    aparser.add_argument('--whitebg', default=0.0, type=float)
    aparser.add_argument('--saveto', default=None, type=str)
    aparser.add_argument('-e', '--ext', default='pdf', type=str,
                         choices=['pdf', 'png', 'json', 'txt'])
    aparser.add_argument('--device', default='cuda', type=str, choices=['cuda', 'cpu', 'mps'])
    aparser.add_argument('--disable-show', default=False, action='store_true')
    aparser.add_argument('--draw-junctions-only', default=False, action='store_true')
    aparser.add_argument('--use_lsd', default=False, action='store_true')
    aparser.add_argument('--use_nms', default=False, action='store_true')
    ScaleLSD.cli(aparser)
    args = aparser.parse_args()
    ScaleLSD.configure(args)
    return args


def _collect_images(path):
    """Return the image paths designated by *path*.

    A single file is returned as a one-element list; a directory yields a
    sorted list of its image files (non-recursive).

    Raises:
        ValueError: if *path* is neither an image file nor a directory.
    """
    if osp.isfile(path):
        if path.lower().endswith(IMAGE_EXTENSIONS):
            return [path]
    elif osp.isdir(path):
        return sorted(
            osp.join(path, name)
            for name in os.listdir(path)
            # Case-insensitive match so IMG_001.JPG is not silently skipped;
            # the isfile() check ignores sub-directories named like images.
            if name.lower().endswith(IMAGE_EXTENSIONS) and osp.isfile(osp.join(path, name))
        )
    raise ValueError('Input must be a file or a directory containing images.')


def _prepare_input(image, width, height, device):
    """Resize a grayscale uint8 image to (height, width) and return it as a
    1x1xHxW float tensor scaled to [0, 1] on *device*."""
    # cv2.resize returns a new array, so the caller's image is left untouched.
    resized = cv2.resize(image, (width, height))
    tensor = torch.from_numpy(resized).float() / 255.0
    return tensor[None, None].to(device)


def _save_result(args, fname, outputs, painter, edge_color, vertex_color):
    """Write the prediction for input image *fname* into ``args.saveto``
    using the format selected by ``args.ext``."""
    pname = Path(fname)
    if args.ext in ('png', 'pdf'):
        fig_file = osp.join(args.saveto, pname.with_suffix('.' + args.ext).name)
        # With -e png and --saveto set to the input directory, the figure path
        # is the input image itself; writing it would destroy the source.
        if osp.exists(fig_file) and osp.samefile(fig_file, fname):
            tqdm.write('Warning: refusing to overwrite input image {}; '
                       'choose a different --saveto'.format(fname), file=sys.stderr)
            return
        with show.image_canvas(fname, fig_file=fig_file) as ax:
            if args.draw_junctions_only:
                painter.draw_junctions(ax, outputs)
            else:
                painter.draw_wireframe(ax, outputs, edge_color=edge_color,
                                       vertex_color=vertex_color)
    elif args.ext == 'json':
        indices = WireframeGraph.xyxy2indices(outputs['juncs_pred'], outputs['lines_pred'])
        wireframe = WireframeGraph(outputs['juncs_pred'], outputs['juncs_score'], indices,
                                   outputs['lines_score'], outputs['width'], outputs['height'])
        outpath = osp.join(args.saveto, pname.with_suffix('.json').name)
        with open(outpath, 'w') as f:
            json.dump(wireframe.jsonize(), f)
    else:
        raise ValueError('Unsupported extension: {} is not in [png, pdf, json]'.format(args.ext))


def main():
    """Detect wireframes on the image(s) given by ``--img`` and save one result per image.

    Raises:
        ValueError: for an output format with no writer, or an invalid ``--img``.
    """
    args = parse_args()

    # Fail fast, before the (potentially slow) model load, on an output format
    # the loop cannot write ('txt' is accepted by argparse but not implemented).
    if args.ext not in SUPPORTED_EXTS:
        raise ValueError('Unsupported extension: {} is not in [png, pdf, json]'.format(args.ext))

    all_images = _collect_images(args.img)
    if not all_images:
        print('No images found in {}; nothing to do.'.format(args.img), file=sys.stderr)
        return

    model = load_scalelsd_model(args.ckpt, device=args.device)

    # Set up output directory and painter
    if args.saveto is None:
        print('No output directory specified, saving outputs to folder: temp_output/ScaleLSD')
        args.saveto = 'temp_output/ScaleLSD'
    os.makedirs(args.saveto, exist_ok=True)

    show.painters.HAWPainter.confidence_threshold = args.threshold
    # Optional tuning knobs:
    # show.painters.HAWPainter.line_width = 2
    # show.painters.HAWPainter.marker_size = 4
    show.Canvas.show = not args.disable_show
    if args.whitebg > 0.0:
        show.Canvas.white_overlay = args.whitebg
    painter = show.painters.HAWPainter()
    edge_color = 'orange'  # alternative: 'midnightblue'
    vertex_color = 'Cyan'  # alternative: 'deeppink'

    # Inference
    for fname in tqdm(all_images):
        image = cv2.imread(fname, cv2.IMREAD_GRAYSCALE)
        if image is None:
            # imread signals unreadable/corrupt files by returning None rather
            # than raising; skip them instead of aborting the whole batch.
            tqdm.write('Warning: could not read image {}, skipping'.format(fname), file=sys.stderr)
            continue

        ori_height, ori_width = image.shape[:2]
        # The network input is resized to --width x --height (default 512x512).
        image_ = _prepare_input(image, args.width, args.height, args.device)
        # NOTE(review): meta carries the ORIGINAL size, presumably so predictions
        # are mapped back to original-image coordinates — confirm in ScaleLSD.
        meta = {
            'width': ori_width,
            'height': ori_height,
            'filename': '',
            'use_lsd': args.use_lsd,
            'use_nms': args.use_nms,
        }

        with torch.no_grad():
            outputs, _ = model(image_, meta)
        # presumably one output dict per batch element; the batch size here is 1
        _save_result(args, fname, outputs[0], painter, edge_color, vertex_color)


if __name__ == "__main__":
    main()