# ScaleLSD / predictor/predict.py
# Nan Xue
# update
# 4c954ae
import torch
import random
import numpy as np
import os
import os.path as osp
import glob
from tqdm import tqdm
from scalelsd.base import setup_logger, MetricLogger, show, WireframeGraph
from scalelsd.ssl.datasets import dataset_util
from scalelsd.ssl.models.detector import ScaleLSD
from scalelsd.ssl.misc.train_utils import load_scalelsd_model
from torch.utils.data import DataLoader
import torch.utils.data.dataloader as torch_loader
from pathlib import Path
import argparse, yaml, logging, time, datetime, cv2, copy, sys, json
from easydict import EasyDict
import accelerate
from accelerate import load_checkpoint_and_dispatch
import matplotlib
import matplotlib.pyplot as plt
def parse_args(argv=None):
    """Build the prediction CLI parser and parse the command line.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is what the script entry
            point relies on.

    Returns:
        The parsed ``argparse.Namespace`` after it has been handed to
        ``ScaleLSD.configure``.
    """
    aparser = argparse.ArgumentParser()
    aparser.add_argument('-c', '--ckpt', default='models/scalelsd-vitbase-v1-train-sa1b.pt', type=str,
                         help='the path for loading checkpoints')
    aparser.add_argument('-t', '--threshold', default=10, type=float,
                         help='confidence threshold used by the painter when drawing')
    aparser.add_argument('-i', '--img', required=True, type=str,
                         help='an image file, or a directory containing images')
    aparser.add_argument('--width', default=512, type=int,
                         help='width the input image is resized to before inference')
    aparser.add_argument('--height', default=512, type=int,
                         help='height the input image is resized to before inference')
    aparser.add_argument('--whitebg', default=0.0, type=float,
                         help='white overlay value for the canvas; applied only when > 0')
    aparser.add_argument('--saveto', default=None, type=str,
                         help='output directory (defaults to temp_output/ScaleLSD)')
    # 'txt' used to be listed here, but the prediction loop only implements
    # png, pdf and json and raised ValueError for anything else -- after the
    # model had already been loaded and run.  Reject it at parse time instead.
    aparser.add_argument('-e', '--ext', default='pdf', type=str, choices=['pdf', 'png', 'json'],
                         help='output format')
    aparser.add_argument('--device', default='cuda', type=str, choices=['cuda', 'cpu', 'mps'],
                         help='device the model and input tensors are placed on')
    aparser.add_argument('--disable-show', default=False, action='store_true',
                         help='do not display the canvas')
    aparser.add_argument('--draw-junctions-only', default=False, action='store_true',
                         help='draw only the junctions instead of the full wireframe')
    aparser.add_argument('--use_lsd', default=False, action='store_true',
                         help="forwarded to the model as meta['use_lsd']")
    aparser.add_argument('--use_nms', default=False, action='store_true',
                         help="forwarded to the model as meta['use_nms']")

    # ScaleLSD receives the parser before parsing and the namespace after it
    # (presumably to register and then apply model options -- see ScaleLSD).
    ScaleLSD.cli(aparser)
    args = aparser.parse_args(argv)
    ScaleLSD.configure(args)
    return args
# Input image suffixes, matched case-insensitively.
_IMAGE_SUFFIXES = ('.jpg', '.jpeg', '.png')
# Output formats the prediction loop knows how to write.
_SUPPORTED_EXTS = ('png', 'pdf', 'json')


def _collect_images(path):
    """Return the sorted image paths designated by *path* (a file or a directory)."""
    if os.path.isfile(path) and path.lower().endswith(_IMAGE_SUFFIXES):
        return [path]
    if os.path.isdir(path):
        return sorted(
            os.path.join(path, entry)
            for entry in os.listdir(path)
            if entry.lower().endswith(_IMAGE_SUFFIXES)
        )
    raise ValueError('Input must be a file or a directory containing images.')


def main():
    """Run ScaleLSD on the requested image(s) and save the results.

    Each input image is read as grayscale, resized to ``args.width`` x
    ``args.height``, scaled to [0, 1] and passed to the model.  Depending on
    ``args.ext`` the prediction is either drawn on top of the original image
    (``png`` / ``pdf``) or dumped as a JSON wireframe (``json``) into
    ``args.saveto``.

    Raises:
        ValueError: if ``args.ext`` is not a supported output format, or if
            ``args.img`` is neither an image file nor a directory.
    """
    args = parse_args()

    # Validate cheap inputs first so a bad invocation fails before the
    # checkpoint is loaded and before any inference is run.
    if args.ext not in _SUPPORTED_EXTS:
        raise ValueError('Unsupported extension: {} is not in [png, pdf, json]'.format(args.ext))
    all_images = _collect_images(args.img)
    if not all_images:
        print('No images found in: {}'.format(args.img))
        return

    model = load_scalelsd_model(args.ckpt, device=args.device)

    # Set up output directory and painter
    if args.saveto is None:
        print('No output directory specified, saving outputs to folder: temp_output/ScaleLSD')
        args.saveto = 'temp_output/ScaleLSD'
    os.makedirs(args.saveto, exist_ok=True)

    show.painters.HAWPainter.confidence_threshold = args.threshold
    # Optional painter tuning: HAWPainter.line_width = 2, HAWPainter.marker_size = 4
    show.Canvas.show = not args.disable_show
    if args.whitebg > 0.0:
        show.Canvas.white_overlay = args.whitebg
    painter = show.painters.HAWPainter()
    edge_color = 'orange'  # alternative: 'midnightblue'
    vertex_color = 'Cyan'  # alternative: 'deeppink'

    # Inference
    for fname in tqdm(all_images):
        pname = Path(fname)
        image = cv2.imread(fname, cv2.IMREAD_GRAYSCALE)
        if image is None:
            # cv2.imread signals failure by returning None rather than raising.
            tqdm.write('Skipping unreadable image: {}'.format(fname))
            continue

        # The network input is resized (default 512x512); the original size is
        # passed along in ``meta``.  cv2.resize returns a new array, so no
        # defensive copy of ``image`` is needed.
        ori_shape = image.shape[:2]
        image_ = cv2.resize(image, (args.width, args.height))
        image_ = torch.from_numpy(image_).float() / 255.0
        image_ = image_[None, None].to(args.device)

        meta = {
            'width': ori_shape[1],
            'height': ori_shape[0],
            'filename': '',
            'use_lsd': args.use_lsd,
            'use_nms': args.use_nms,
        }

        with torch.no_grad():
            outputs, _ = model(image_, meta)
        outputs = outputs[0]

        if args.ext == 'json':
            indices = WireframeGraph.xyxy2indices(outputs['juncs_pred'], outputs['lines_pred'])
            wireframe = WireframeGraph(
                outputs['juncs_pred'], outputs['juncs_score'], indices,
                outputs['lines_score'], outputs['width'], outputs['height'],
            )
            outpath = osp.join(args.saveto, pname.with_suffix('.json').name)
            with open(outpath, 'w') as f:
                json.dump(wireframe.jsonize(), f)
        else:
            # 'png' or 'pdf' (guaranteed by the validation above): draw on the original image.
            fig_file = osp.join(args.saveto, pname.with_suffix('.' + args.ext).name)
            with show.image_canvas(fname, fig_file=fig_file) as ax:
                if args.draw_junctions_only:
                    painter.draw_junctions(ax, outputs)
                else:
                    painter.draw_wireframe(ax, outputs, edge_color=edge_color, vertex_color=vertex_color)
# Run the prediction pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()