Spaces:
Running
Running
File size: 4,903 Bytes
1b369eb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
"""
Extract RDD local features from images and export them to an HDF5 file.

Modified from hloc:
https://github.com/cvg/Hierarchical-Localization.git
"""
import argparse
import collections.abc as collections
import glob
import pprint
from pathlib import Path
from types import SimpleNamespace
from typing import Dict, List, Optional, Union
import cv2
import h5py
import numpy as np
import PIL.Image
import torch
from tqdm import tqdm
from hloc.extract_features import ImageDataset
from hloc import logger
from hloc.utils.base_model import dynamic_load
from hloc.utils.io import list_h5_names, read_image
from hloc.utils.parsers import parse_image_lists
from RDD.RDD import build
from RDD.utils import read_config
# Named extraction configurations; the command-line ``--conf`` option selects
# one of these keys.
confs = {
    "rdd": {
        # Stem of the HDF5 file created inside the export directory
        # (".h5" is appended by ``main``).
        "output": "feats-rdd-n4096",
        # Where to find the RDD model definition and its trained weights.
        "model": {
            "config_path": "./configs/default.yaml",
            "weights": "./weights/RDD-v2.pth",
        },
        # Options forwarded unchanged to hloc's ``ImageDataset``.
        "preprocessing": {
            "grayscale": False,
            "resize_max": 1024,
            "resize_force": True,
        },
    },
}
@torch.no_grad()
def main(
    conf: Dict,
    image_dir: Path,
    export_dir: Optional[Path] = None,
    as_half: bool = True,
    image_list: Optional[Union[Path, List[str]]] = None,
    feature_path: Optional[Path] = None,
    overwrite: bool = False,
) -> Path:
    """Extract RDD local features for a set of images and store them in HDF5.

    Each image gets its own HDF5 group (named after the image) holding the
    datasets ``keypoints``, ``keypoint_scores``, ``descriptors`` and
    ``image_size``. Keypoints are rescaled from the resized network input
    back to the original image resolution before being written.

    Args:
        conf: Extraction configuration. Must provide ``"preprocessing"``
            (forwarded to ``ImageDataset``), ``"output"`` (stem of the
            default feature file) and ``"model"`` with ``"config_path"``
            and ``"weights"``.
        image_dir: Image root directory, forwarded to ``ImageDataset``.
        export_dir: Directory in which ``<conf["output"]>.h5`` is created
            when ``feature_path`` is not given.
        as_half: If true, float32 arrays are stored as float16.
        image_list: Optional image selection, forwarded to ``ImageDataset``.
        feature_path: Explicit path of the output HDF5 file. Takes
            precedence over ``export_dir``.
        overwrite: If false, images whose name already exists in the
            feature file are skipped; if true, their groups are replaced.

    Returns:
        The path of the HDF5 feature file.

    Raises:
        ValueError: If neither ``export_dir`` nor ``feature_path`` is given.
        OSError: Re-raised when writing to the HDF5 file fails.
    """
    logger.info(
        "Extracting local features with configuration:" f"\n{pprint.pformat(conf)}"
    )

    dataset = ImageDataset(image_dir, conf["preprocessing"], image_list)
    if feature_path is None:
        if export_dir is None:
            # Fail with a clear message instead of the opaque TypeError that
            # Path(None, ...) would raise.
            raise ValueError("Either export_dir or feature_path must be provided.")
        feature_path = Path(export_dir, conf["output"] + ".h5")
    feature_path.parent.mkdir(exist_ok=True, parents=True)

    # Images already present in the feature file are not processed again
    # unless the caller asked to overwrite them.
    skip_names = set(
        list_h5_names(feature_path) if feature_path.exists() and not overwrite else ()
    )
    dataset.names = [n for n in dataset.names if n not in skip_names]
    if not dataset.names:
        logger.info("Skipping the extraction.")
        return feature_path

    device = "cuda" if torch.cuda.is_available() else "cpu"
    config = read_config(conf["model"]["config_path"])
    config["device"] = device
    model = build(config, conf["model"]["weights"])
    model.eval()

    # shuffle=False with the default batch size of 1 keeps the loader index
    # aligned with dataset.names, which the name lookup below relies on.
    loader = torch.utils.data.DataLoader(
        dataset, num_workers=1, shuffle=False, pin_memory=True
    )
    for idx, data in enumerate(tqdm(loader)):
        name = dataset.names[idx]
        features = model.extract(data["image"])
        pred = {
            "keypoints": [f["keypoints"] for f in features],
            "keypoint_scores": [f["scores"] for f in features],
            "descriptors": [f["descriptors"].t() for f in features],
        }
        # Only the first (and only) batch element is kept.
        pred = {k: v[0].cpu().numpy() for k, v in pred.items()}

        pred["image_size"] = original_size = data["original_size"][0].numpy()

        # Map keypoints from the resized input back to the original image.
        # The +0.5 / -0.5 shifts scale about pixel centers, not corners.
        size = np.array(data["image"].shape[-2:][::-1])
        scales = (original_size / size).astype(np.float32)
        pred["keypoints"] = (pred["keypoints"] + 0.5) * scales[None] - 0.5
        # Keypoint uncertainty, scaled to the original resolution.
        uncertainty = getattr(model, "detection_noise", 1) * scales.mean()

        if as_half:
            # Halve the storage of float32 arrays; other dtypes (e.g. the
            # integer image size) are left untouched.
            pred = {
                k: v.astype(np.float16) if v.dtype == np.float32 else v
                for k, v in pred.items()
            }

        # The file is reopened in append mode for every image.
        with h5py.File(str(feature_path), "a", libver="latest") as fd:
            try:
                if name in fd:
                    del fd[name]
                grp = fd.create_group(name)
                for k, v in pred.items():
                    grp.create_dataset(k, data=v)
                grp["keypoints"].attrs["uncertainty"] = uncertainty
            except OSError as error:
                # str(error) is used because args[0] may be an integer errno.
                if "No space left on device" in str(error):
                    logger.error(
                        "Out of disk space: storing features on disk can take "
                        "significant space, did you enable the as_half flag?"
                    )
                    # Remove the partially written group. The group may not
                    # exist if the failure happened before it was created.
                    if name in fd:
                        del fd[name]
                raise

        del pred

    logger.info("Finished exporting features.")
    return feature_path
if __name__ == "__main__":
    # Command-line entry point: parse the options and run the extraction with
    # the selected named configuration.
    cli = argparse.ArgumentParser()
    cli.add_argument("--image_dir", type=Path, required=True)
    cli.add_argument("--export_dir", type=Path, required=True)
    cli.add_argument("--conf", type=str, default="rdd", choices=list(confs.keys()))
    cli.add_argument("--as_half", action="store_true")
    cli.add_argument("--image_list", type=Path)
    cli.add_argument("--feature_path", type=Path)
    opts = cli.parse_args()

    main(
        conf=confs[opts.conf],
        image_dir=opts.image_dir,
        export_dir=opts.export_dir,
        as_half=opts.as_half,
        image_list=opts.image_list,
        feature_path=opts.feature_path,
    )