Spaces:
Running
on
Zero
Running
on
Zero
File size: 5,084 Bytes
4c954ae |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
import torch
import random
import numpy as np
import os
import os.path as osp
import glob
from tqdm import tqdm
from scalelsd.base import setup_logger, MetricLogger, show, WireframeGraph
from scalelsd.ssl.datasets import dataset_util
from scalelsd.ssl.models.detector import ScaleLSD
from scalelsd.ssl.misc.train_utils import load_scalelsd_model
from torch.utils.data import DataLoader
import torch.utils.data.dataloader as torch_loader
from pathlib import Path
import argparse, yaml, logging, time, datetime, cv2, copy, sys, json
from easydict import EasyDict
import accelerate
from accelerate import load_checkpoint_and_dispatch
import matplotlib
import matplotlib.pyplot as plt
def parse_args():
    """Build the command-line interface and return the parsed arguments.

    After the options below are declared, ``ScaleLSD.cli(aparser)`` is given
    the same parser (presumably to register model-specific options -- confirm
    in ``scalelsd.ssl.models.detector``), and ``ScaleLSD.configure(args)`` is
    called with the parsed namespace before it is returned.

    Returns:
        argparse.Namespace: the parsed command-line options.
    """
    aparser = argparse.ArgumentParser()
    aparser.add_argument('-c', '--ckpt', default='models/scalelsd-vitbase-v1-train-sa1b.pt', type=str,
                         help='the path for loading checkpoints')
    aparser.add_argument('-t', '--threshold', default=10, type=float,
                         help='confidence threshold assigned to the HAWPainter used for drawing')
    aparser.add_argument('-i', '--img', required=True, type=str,
                         help='an image file or a directory of images')
    aparser.add_argument('--width', default=512, type=int,
                         help='width the image is resized to before inference')
    aparser.add_argument('--height', default=512, type=int,
                         help='height the image is resized to before inference')
    aparser.add_argument('--whitebg', default=0.0, type=float,
                         help='white overlay strength for rendered figures (applied only when > 0)')
    aparser.add_argument('--saveto', default=None, type=str,
                         help='output directory (defaults to temp_output/ScaleLSD)')
    # 'txt' used to be listed here, but main() has no writer for it and failed
    # with a ValueError only after inference; reject it at parse time instead.
    aparser.add_argument('-e', '--ext', default='pdf', type=str, choices=['pdf', 'png', 'json'],
                         help='output format: rendered figure (pdf/png) or wireframe json')
    aparser.add_argument('--device', default='cuda', type=str, choices=['cuda', 'cpu', 'mps'],
                         help='torch device to run the model on')
    aparser.add_argument('--disable-show', default=False, action='store_true',
                         help='do not display figures interactively')
    aparser.add_argument('--draw-junctions-only', default=False, action='store_true',
                         help='draw only junctions instead of the full wireframe')
    aparser.add_argument('--use_lsd', default=False, action='store_true',
                         help="forwarded to the model via meta['use_lsd']")
    aparser.add_argument('--use_nms', default=False, action='store_true',
                         help="forwarded to the model via meta['use_nms']")
    ScaleLSD.cli(aparser)
    args = aparser.parse_args()
    ScaleLSD.configure(args)
    return args
# Input image extensions accepted by _collect_images (matched case-insensitively).
_IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')
# Output formats main() knows how to write.
_OUTPUT_EXTENSIONS = ('png', 'pdf', 'json')


def _collect_images(path):
    """Return the sorted image paths designated by *path*.

    *path* is either a single image file or a directory whose direct children
    are scanned (not recursively). Files are selected by extension,
    case-insensitively, from ``_IMAGE_EXTENSIONS``.

    Raises:
        ValueError: if *path* is neither a matching image file nor a directory.
    """
    if os.path.isfile(path) and path.lower().endswith(_IMAGE_EXTENSIONS):
        return [path]
    if os.path.isdir(path):
        candidates = (os.path.join(path, name) for name in os.listdir(path))
        return sorted(
            p for p in candidates
            if os.path.isfile(p) and p.lower().endswith(_IMAGE_EXTENSIONS)
        )
    raise ValueError('Input must be a file or a directory containing images.')


def main():
    """Run ScaleLSD on one image or a directory of images and save the results.

    Configuration comes from the command line (``parse_args``). Each input is
    loaded as grayscale, resized to ``--width`` x ``--height``, scaled to
    [0, 1] and passed to the model together with its original size. The single
    batch entry of the prediction is then written to ``--saveto`` either as a
    figure drawn through ``show.image_canvas`` (``png``/``pdf``) or as a
    ``WireframeGraph`` JSON file (``json``). Unreadable images are skipped with
    a message.

    Raises:
        ValueError: if ``--ext`` is not a supported output format, or if
            ``--img`` is neither an image file nor a directory.
    """
    args = parse_args()
    # Fail fast: these checks previously ran only after the model was loaded
    # (and, for --ext, after the first inference had already been paid for).
    if args.ext not in _OUTPUT_EXTENSIONS:
        raise ValueError('Unsupported extension: {} is not in [png, pdf, json]'.format(args.ext))
    all_images = _collect_images(args.img)

    model = load_scalelsd_model(args.ckpt, device=args.device)

    # Set up output directory and painter
    if args.saveto is None:
        print('No output directory specified, saving outputs to folder: temp_output/ScaleLSD')
        args.saveto = 'temp_output/ScaleLSD'
    os.makedirs(args.saveto, exist_ok=True)

    show.painters.HAWPainter.confidence_threshold = args.threshold
    show.Canvas.show = not args.disable_show
    if args.whitebg > 0.0:
        show.Canvas.white_overlay = args.whitebg
    painter = show.painters.HAWPainter()
    edge_color = 'orange'  # alternative: 'midnightblue'
    vertex_color = 'Cyan'  # alternative: 'deeppink'

    # Inference
    for fname in tqdm(all_images):
        pname = Path(fname)
        image = cv2.imread(fname, 0)  # flag 0 == cv2.IMREAD_GRAYSCALE
        if image is None:
            # cv2.imread reports unreadable/corrupt files by returning None
            # rather than raising; skip them instead of aborting the batch.
            tqdm.write('Skipping unreadable image: {}'.format(fname))
            continue
        ori_height, ori_width = image.shape[:2]

        # cv2.resize returns a new array, so the source needs no defensive copy.
        resized = cv2.resize(image, (args.width, args.height))
        image_tensor = torch.from_numpy(resized).float() / 255.0
        image_tensor = image_tensor[None, None].to(args.device)  # add batch and channel dims
        meta = {
            'width': ori_width,
            'height': ori_height,
            'filename': '',
            'use_lsd': args.use_lsd,
            'use_nms': args.use_nms,
        }
        with torch.no_grad():
            outputs, _ = model(image_tensor, meta)
        outputs = outputs[0]

        # args.saveto is always set above, so every prediction is saved.
        if args.ext in ('png', 'pdf'):
            fig_file = osp.join(args.saveto, pname.with_suffix('.' + args.ext).name)
            with show.image_canvas(fname, fig_file=fig_file) as ax:
                if args.draw_junctions_only:
                    painter.draw_junctions(ax, outputs)
                else:
                    painter.draw_wireframe(ax, outputs, edge_color=edge_color, vertex_color=vertex_color)
        else:  # 'json' -- the only remaining format after the check above
            indices = WireframeGraph.xyxy2indices(outputs['juncs_pred'], outputs['lines_pred'])
            wireframe = WireframeGraph(outputs['juncs_pred'], outputs['juncs_score'], indices,
                                       outputs['lines_score'], outputs['width'], outputs['height'])
            outpath = osp.join(args.saveto, pname.with_suffix('.json').name)
            with open(outpath, 'w', encoding='utf-8') as f:
                json.dump(wireframe.jsonize(), f)
if __name__ == '__main__':
    # Entry point when executed as a script: main() parses the CLI itself.
    main()
|