Spaces:
No application file
No application file
| import pdb | |
| import os | |
| import sys | |
| import tqdm | |
| import numpy as np | |
| import torch | |
| from PIL import Image | |
| from matplotlib import pyplot as pl | |
# Turn on matplotlib interactive mode so figures are drawn and updated
# without a blocking pl.show() call.
pl.ion()
| from scipy.ndimage import uniform_filter | |
def smooth(arr, size=3):
    """Return *arr* smoothed by a ``size``-wide uniform (box / moving-average) filter.

    Thin wrapper around ``scipy.ndimage.uniform_filter``; ``size`` defaults to
    the original hard-coded width of 3.

    Args:
        arr: array to smooth (any dimensionality accepted by uniform_filter).
        size: width of the box filter along every axis.

    Returns:
        A new array of the same shape as *arr*.
    """
    return uniform_filter(arr, size)
def transparent(img, alpha, cmap, **kw):
    """Colour-map a 2-D array and force a constant opacity on the result.

    Args:
        img: 2-D array of values to colour.
        alpha: value written into the last (alpha) channel of every pixel.
        cmap: colormap callable mapping normalised values to RGBA.
        **kw: forwarded to ``matplotlib.colors.Normalize`` (e.g. vmin, vmax).

    Returns:
        The RGBA image produced by *cmap*, with its alpha channel overwritten.
    """
    from matplotlib.colors import Normalize

    # clip=True pins values outside [vmin, vmax] to the colormap ends.
    normalizer = Normalize(clip=True, **kw)
    rgba = cmap(normalizer(img))
    rgba[:, :, -1] = alpha
    return rgba
| from tools import common | |
| from tools.dataloader import norm_RGB | |
| from nets.patchnet import * | |
| from extract import NonMaxSuppression | |
| if __name__ == "__main__": | |
| import argparse | |
| parser = argparse.ArgumentParser("Visualize the patch detector and descriptor") | |
| parser.add_argument("--img", type=str, default="imgs/brooklyn.png") | |
| parser.add_argument("--resize", type=int, default=512) | |
| parser.add_argument("--out", type=str, default="viz.png") | |
| parser.add_argument("--checkpoint", type=str, required=True, help="network path") | |
| parser.add_argument("--net", type=str, default="", help="network command") | |
| parser.add_argument("--max-kpts", type=int, default=200) | |
| parser.add_argument("--reliability-thr", type=float, default=0.8) | |
| parser.add_argument("--repeatability-thr", type=float, default=0.7) | |
| parser.add_argument( | |
| "--border", type=int, default=20, help="rm keypoints close to border" | |
| ) | |
| parser.add_argument("--gpu", type=int, nargs="+", required=True, help="-1 for CPU") | |
| parser.add_argument("--dbg", type=str, nargs="+", default=(), help="debug options") | |
| args = parser.parse_args() | |
| args.dbg = set(args.dbg) | |
| iscuda = common.torch_set_gpu(args.gpu) | |
| device = torch.device("cuda" if iscuda else "cpu") | |
| # create network | |
| checkpoint = torch.load(args.checkpoint, lambda a, b: a) | |
| args.net = args.net or checkpoint["net"] | |
| print("\n>> Creating net = " + args.net) | |
| net = eval(args.net) | |
| net.load_state_dict( | |
| {k.replace("module.", ""): v for k, v in checkpoint["state_dict"].items()} | |
| ) | |
| if iscuda: | |
| net = net.cuda() | |
| print(f" ( Model size: {common.model_size(net)/1000:.0f}K parameters )") | |
| img = Image.open(args.img).convert("RGB") | |
| if args.resize: | |
| img.thumbnail((args.resize, args.resize)) | |
| img = np.asarray(img) | |
| detector = NonMaxSuppression( | |
| rel_thr=args.reliability_thr, rep_thr=args.repeatability_thr | |
| ) | |
| with torch.no_grad(): | |
| print(">> computing features...") | |
| res = net(imgs=[norm_RGB(img).unsqueeze(0).to(device)]) | |
| rela = res.get("reliability") | |
| repe = res.get("repeatability") | |
| kpts = detector(**res).T[:, [1, 0]] | |
| kpts = kpts[repe[0][0, 0][kpts[:, 1], kpts[:, 0]].argsort()[-args.max_kpts :]] | |
| fig = pl.figure("viz") | |
| kw = dict(cmap=pl.cm.RdYlGn, vmax=1) | |
| crop = (slice(args.border, -args.border or 1),) * 2 | |
| if "reliability" in args.dbg: | |
| ax1 = pl.subplot(131) | |
| pl.imshow(img[crop], cmap=pl.cm.gray) | |
| pl.xticks(()) | |
| pl.yticks(()) | |
| pl.subplot(132) | |
| pl.imshow(img[crop], cmap=pl.cm.gray, alpha=0) | |
| pl.xticks(()) | |
| pl.yticks(()) | |
| x, y = kpts[:, 0:2].cpu().numpy().T - args.border | |
| pl.plot(x, y, "+", c=(0, 1, 0), ms=10, scalex=0, scaley=0) | |
| ax1 = pl.subplot(133) | |
| rela = rela[0][0, 0].cpu().numpy() | |
| pl.imshow(rela[crop], cmap=pl.cm.RdYlGn, vmax=1, vmin=0.9) | |
| pl.xticks(()) | |
| pl.yticks(()) | |
| else: | |
| ax1 = pl.subplot(131) | |
| pl.imshow(img[crop], cmap=pl.cm.gray) | |
| pl.xticks(()) | |
| pl.yticks(()) | |
| x, y = kpts[:, 0:2].cpu().numpy().T - args.border | |
| pl.plot(x, y, "+", c=(0, 1, 0), ms=10, scalex=0, scaley=0) | |
| pl.subplot(132) | |
| pl.imshow(img[crop], cmap=pl.cm.gray) | |
| pl.xticks(()) | |
| pl.yticks(()) | |
| c = repe[0][0, 0].cpu().numpy() | |
| pl.imshow(transparent(smooth(c)[crop], 0.5, vmin=0, **kw)) | |
| ax1 = pl.subplot(133) | |
| pl.imshow(img[crop], cmap=pl.cm.gray) | |
| pl.xticks(()) | |
| pl.yticks(()) | |
| rela = rela[0][0, 0].cpu().numpy() | |
| pl.imshow(transparent(rela[crop], 0.5, vmin=0.9, **kw)) | |
| pl.gcf().set_size_inches(9, 2.73) | |
| pl.subplots_adjust(0.01, 0.01, 0.99, 0.99, hspace=0.1) | |
| pl.savefig(args.out) | |
| pdb.set_trace() | |