add: rdd sparse and dense match

This view is limited to 50 files because it contains too many changes.
- .gitignore +1 -1
- README.md +2 -1
- config/config.yaml +21 -0
- imcui/hloc/extract_features.py +11 -0
- imcui/hloc/extractors/liftfeat.py +3 -8
- imcui/hloc/extractors/rdd.py +56 -0
- imcui/hloc/match_dense.py +17 -0
- imcui/hloc/matchers/rdd_dense.py +52 -0
- imcui/third_party/rdd/.gitignore +8 -0
- imcui/third_party/rdd/LICENSE +201 -0
- imcui/third_party/rdd/RDD/RDD.py +260 -0
- imcui/third_party/rdd/RDD/RDD_helper.py +179 -0
- imcui/third_party/rdd/RDD/dataset/__init__.py +0 -0
- imcui/third_party/rdd/RDD/dataset/megadepth/__init__.py +2 -0
- imcui/third_party/rdd/RDD/dataset/megadepth/megadepth.py +313 -0
- imcui/third_party/rdd/RDD/dataset/megadepth/megadepth_warper.py +75 -0
- imcui/third_party/rdd/RDD/dataset/megadepth/utils.py +848 -0
- imcui/third_party/rdd/RDD/matchers/__init__.py +3 -0
- imcui/third_party/rdd/RDD/matchers/dense_matcher.py +88 -0
- imcui/third_party/rdd/RDD/matchers/dual_softmax_matcher.py +31 -0
- imcui/third_party/rdd/RDD/matchers/lightglue.py +667 -0
- imcui/third_party/rdd/RDD/models/backbone.py +147 -0
- imcui/third_party/rdd/RDD/models/deformable_transformer.py +270 -0
- imcui/third_party/rdd/RDD/models/descriptor.py +116 -0
- imcui/third_party/rdd/RDD/models/detector.py +141 -0
- imcui/third_party/rdd/RDD/models/interpolator.py +33 -0
- imcui/third_party/rdd/RDD/models/ops/functions/__init__.py +13 -0
- imcui/third_party/rdd/RDD/models/ops/functions/ms_deform_attn_func.py +72 -0
- imcui/third_party/rdd/RDD/models/ops/make.sh +13 -0
- imcui/third_party/rdd/RDD/models/ops/modules/__init__.py +12 -0
- imcui/third_party/rdd/RDD/models/ops/modules/ms_deform_attn.py +125 -0
- imcui/third_party/rdd/RDD/models/ops/setup.py +78 -0
- imcui/third_party/rdd/RDD/models/ops/src/cpu/ms_deform_attn_cpu.cpp +46 -0
- imcui/third_party/rdd/RDD/models/ops/src/cpu/ms_deform_attn_cpu.h +38 -0
- imcui/third_party/rdd/RDD/models/ops/src/cuda/ms_deform_attn_cuda.cu +158 -0
- imcui/third_party/rdd/RDD/models/ops/src/cuda/ms_deform_attn_cuda.h +35 -0
- imcui/third_party/rdd/RDD/models/ops/src/cuda/ms_deform_im2col_cuda.cuh +1332 -0
- imcui/third_party/rdd/RDD/models/ops/src/ms_deform_attn.h +67 -0
- imcui/third_party/rdd/RDD/models/ops/src/vision.cpp +21 -0
- imcui/third_party/rdd/RDD/models/ops/test.py +92 -0
- imcui/third_party/rdd/RDD/models/position_encoding.py +48 -0
- imcui/third_party/rdd/RDD/models/soft_detect.py +250 -0
- imcui/third_party/rdd/RDD/utils/__init__.py +1 -0
- imcui/third_party/rdd/RDD/utils/misc.py +531 -0
- imcui/third_party/rdd/README.md +197 -0
- imcui/third_party/rdd/benchmarks/air_ground.py +247 -0
- imcui/third_party/rdd/benchmarks/mega_1500.py +255 -0
- imcui/third_party/rdd/benchmarks/mega_view.py +250 -0
- imcui/third_party/rdd/benchmarks/utils.py +112 -0
- imcui/third_party/rdd/configs/default.yaml +19 -0
.gitignore
CHANGED
@@ -7,7 +7,7 @@ cmake-build-debug/
 *.pyc
 flagged
 .ipynb_checkpoints
-__pycache__
+**__pycache__**
 Untitled*
 experiments
 third_party/REKD
README.md
CHANGED
@@ -44,8 +44,9 @@ The tool currently supports various popular image matching algorithms, namely:
 
 | Algorithm        | Supported | Conference/Journal | Year | GitHub Link |
 |------------------|-----------|--------------------|------|-------------|
-| DaD | ✅ | ARXIV | 2025 | [Link](https://github.com/Parskatt/dad) |
 | LiftFeat | ✅ | ICRA | 2025 | [Link](https://github.com/lyp-deeplearning/LiftFeat) |
+| RDD | ✅ | CVPR | 2025 | [Link](https://github.com/xtcpete/rdd) |
+| DaD | ✅ | ARXIV | 2025 | [Link](https://github.com/Parskatt/dad) |
 | MINIMA | ✅ | ARXIV | 2024 | [Link](https://github.com/LSXI7/MINIMA) |
 | XoFTR | ✅ | CVPR | 2024 | [Link](https://github.com/OnderT/XoFTR) |
 | EfficientLoFTR | ✅ | CVPR | 2024 | [Link](https://github.com/zju3dv/EfficientLoFTR) |
config/config.yaml
CHANGED
@@ -267,6 +267,27 @@ matcher_zoo:
       paper: https://arxiv.org/abs/2505.0342
       project: null
       display: true
+  rdd(sparse):
+    matcher: NN-mutual
+    feature: rdd
+    dense: false
+    info:
+      name: RDD(sparse) # display name
+      source: "CVPR 2025"
+      github: https://github.com/xtcpete/rdd
+      paper: https://arxiv.org/abs/2505.08013
+      project: https://xtcpete.github.io/rdd
+      display: true
+  rdd(dense):
+    matcher: rdd_dense
+    dense: true
+    info:
+      name: RDD(dense) # display name
+      source: "CVPR 2025"
+      github: https://github.com/xtcpete/rdd
+      paper: https://arxiv.org/abs/2505.08013
+      project: https://xtcpete.github.io/rdd
+      display: true
   dedode:
     matcher: Dual-Softmax
     feature: dedode
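For a quick sanity check of the two new `matcher_zoo` entries, a minimal sketch in Python (assuming `matcher_zoo` is a top-level mapping in `config/config.yaml` and the repository root is the working directory; both are assumptions, not shown in this diff):

```python
# Minimal sketch: confirm the new RDD entries are readable from config.yaml.
# Assumes matcher_zoo is a top-level key and the repo root is the cwd.
import yaml

with open("config/config.yaml", "r") as f:
    cfg = yaml.safe_load(f)

sparse = cfg["matcher_zoo"]["rdd(sparse)"]   # NN-mutual matcher on "rdd" features
dense = cfg["matcher_zoo"]["rdd(dense)"]     # dense matcher "rdd_dense"
print(sparse["matcher"], sparse["feature"], dense["matcher"])
```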
imcui/hloc/extract_features.py
CHANGED
@@ -225,6 +225,17 @@ confs = {
             "resize_max": 1600,
         },
     },
+    "rdd": {
+        "output": "feats-rdd-n5000-r1600",
+        "model": {
+            "name": "rdd",
+            "max_keypoints": 5000,
+        },
+        "preprocessing": {
+            "grayscale": False,
+            "resize_max": 1600,
+        },
+    },
     "aliked-n16-rot": {
         "output": "feats-aliked-n16-rot",
         "model": {
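As a usage note, the new conf should be reachable like any other feature conf; a hedged sketch, assuming this fork keeps the standard hloc `extract_features.main(conf, image_dir, export_dir)` entry point (paths below are illustrative):

```python
# Hedged sketch: run the new "rdd" feature conf through hloc's usual entry point.
# The main(conf, image_dir, export_dir) signature is assumed from upstream hloc.
from pathlib import Path
from imcui.hloc import extract_features

conf = extract_features.confs["rdd"]          # added by this diff
images = Path("datasets/my_scene/images")     # hypothetical image directory
outputs = Path("outputs/my_scene")            # hypothetical export directory
feature_path = extract_features.main(conf, images, outputs)
```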
imcui/hloc/extractors/liftfeat.py
CHANGED
@@ -1,13 +1,10 @@
-import logging
 import sys
 from pathlib import Path
-import torch
-import random
 from ..utils.base_model import BaseModel
 from .. import logger, MODEL_REPO_ID
 
-
-sys.path.append(str(
+liftfeat_path = Path(__file__).parent / "../../third_party/LiftFeat"
+sys.path.append(str(liftfeat_path))
 
 from models.liftfeat_wrapper import LiftFeat
 
@@ -25,9 +22,7 @@ class Liftfeat(BaseModel):
         logger.info("Loading LiftFeat model...")
         model_path = self._download_model(
             repo_id=MODEL_REPO_ID,
-            filename="{}/{}".format(
-                Path(__file__).stem, self.conf["model_name"]
-            ),
+            filename="{}/{}".format(Path(__file__).stem, self.conf["model_name"]),
         )
         self.net = LiftFeat(
             weight=model_path,
imcui/hloc/extractors/rdd.py
ADDED
@@ -0,0 +1,56 @@
import sys
import yaml
from pathlib import Path
from ..utils.base_model import BaseModel
from .. import logger, MODEL_REPO_ID, DEVICE

rdd_path = Path(__file__).parent / "../../third_party/rdd"
sys.path.append(str(rdd_path))

from RDD.RDD import build as build_rdd

class Rdd(BaseModel):
    default_conf = {
        "keypoint_threshold": 0.1,
        "max_keypoints": 4096,
        "model_name": "RDD-v2.pth",
    }

    required_inputs = ["image"]

    def _init(self, conf):
        logger.info("Loading RDD model...")
        model_path = self._download_model(
            repo_id=MODEL_REPO_ID,
            filename="{}/{}".format(
                Path(__file__).stem, self.conf["model_name"]
            ),
        )
        config_path = rdd_path / "configs/default.yaml"
        with open(config_path, "r") as file:
            config = yaml.safe_load(file)
        config["top_k"] = conf["max_keypoints"]
        config["detection_threshold"] = conf["keypoint_threshold"]
        config["device"] = DEVICE
        self.net = build_rdd(config=config, weights=model_path)
        self.net.eval()
        logger.info("Loading RDD model done!")

    def _forward(self, data):
        image = data["image"]
        pred = self.net.extract(image)[0]
        keypoints = pred["keypoints"]
        descriptors = pred["descriptors"]
        scores = pred["scores"]
        if self.conf["max_keypoints"] < len(keypoints):
            idxs = scores.argsort()[-self.conf["max_keypoints"] or None :]
            keypoints = keypoints[idxs, :2]
            descriptors = descriptors[idxs]
            scores = scores[idxs]

        pred = {
            "keypoints": keypoints[None],
            "descriptors": descriptors[None].permute(0, 2, 1),
            "scores": scores[None],
        }
        return pred
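For reference, hloc-style extractors are normally driven through the `BaseModel` call interface; a minimal sketch under that assumption (conf merging and `__call__` dispatching to `_forward` are assumed, not shown in this diff):

```python
# Hedged sketch: exercising the Rdd extractor via the assumed BaseModel convention.
import torch
from imcui.hloc.extractors.rdd import Rdd

model = Rdd({"max_keypoints": 2048}).eval()
image = torch.rand(1, 3, 480, 640)            # RGB tensor in [0, 1], batch of one
pred = model({"image": image})
print(pred["keypoints"].shape, pred["descriptors"].shape, pred["scores"].shape)
```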
imcui/hloc/match_dense.py
CHANGED
@@ -337,6 +337,23 @@ confs = {
             "dfactor": 8,
         },
     },
+    "rdd_dense": {
+        "output": "matches-rdd_dense",
+        "model": {
+            "name": "rdd_dense",
+            "model_name": "RDD-v2.pth",
+            "max_keypoints": 2000,
+            "match_threshold": 0.2,
+        },
+        "preprocessing": {
+            "grayscale": False,
+            "force_resize": True,
+            "resize_max": 1024,
+            "width": 320,
+            "height": 240,
+            "dfactor": 8,
+        },
+    },
     "minima_roma": {
         "output": "matches-minima_roma",
         "model": {
imcui/hloc/matchers/rdd_dense.py
ADDED
@@ -0,0 +1,52 @@
import sys
import yaml
import torch
from pathlib import Path
from ..utils.base_model import BaseModel
from .. import logger, MODEL_REPO_ID, DEVICE

rdd_path = Path(__file__).parent / "../../third_party/rdd"
sys.path.append(str(rdd_path))

from RDD.RDD import build as build_rdd
from RDD.RDD_helper import RDD_helper

class RddDense(BaseModel):
    default_conf = {
        "keypoint_threshold": 0.1,
        "max_keypoints": 4096,
        "model_name": "RDD-v2.pth",
        "match_threshold": 0.1,
    }

    required_inputs = ["image0", "image1"]

    def _init(self, conf):
        logger.info("Loading RDD model...")
        model_path = self._download_model(
            repo_id=MODEL_REPO_ID,
            filename="{}/{}".format(
                "rdd", self.conf["model_name"]
            ),
        )
        config_path = rdd_path / "configs/default.yaml"
        with open(config_path, "r") as file:
            config = yaml.safe_load(file)
        config["top_k"] = conf["max_keypoints"]
        config["detection_threshold"] = conf["keypoint_threshold"]
        config["device"] = DEVICE
        rdd_net = build_rdd(config=config, weights=model_path)
        rdd_net.eval()
        self.net = RDD_helper(rdd_net)
        logger.info("Loading RDD model done!")

    def _forward(self, data):
        img0 = data["image0"]
        img1 = data["image1"]
        mkpts_0, mkpts_1, conf = self.net.match_dense(img0, img1, thr=self.conf["match_threshold"])
        pred = {
            "keypoints0": torch.from_numpy(mkpts_0),
            "keypoints1": torch.from_numpy(mkpts_1),
            "mconf": torch.from_numpy(conf),
        }
        return pred
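The dense matcher follows the same calling convention, consuming two image tensors and returning matched keypoints plus confidences (see `_forward` above); a hedged sketch under the same `BaseModel` assumption:

```python
# Hedged sketch: the dense matcher returns keypoints0/keypoints1/mconf as torch tensors.
import torch
from imcui.hloc.matchers.rdd_dense import RddDense

matcher = RddDense({"match_threshold": 0.2}).eval()
data = {
    "image0": torch.rand(1, 3, 480, 640),
    "image1": torch.rand(1, 3, 480, 640),
}
out = matcher(data)
print(out["keypoints0"].shape, out["keypoints1"].shape, out["mconf"].shape)
```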
imcui/third_party/rdd/.gitignore
ADDED
@@ -0,0 +1,8 @@
.venv
/build/
**.egg-info
**.pyc
/.idea/
**/__pycache__/
weights/
outputs
imcui/third_party/rdd/LICENSE
ADDED
@@ -0,0 +1,201 @@
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

(standard Apache License, Version 2.0 text, 201 lines in total)
imcui/third_party/rdd/RDD/RDD.py
ADDED
@@ -0,0 +1,260 @@
# Description: RDD model
import torch
import torch.nn.functional as F
from torch import nn
import numpy as np
from .utils import NestedTensor, nested_tensor_from_tensor_list, to_pixel_coords, read_config
from .models.detector import build_detector
from .models.descriptor import build_descriptor
from .models.soft_detect import SoftDetect
from .models.interpolator import InterpolateSparse2d

class RDD(nn.Module):

    def __init__(self, detector, descriptor, detection_threshold=0.5, top_k=4096, train_detector=False, device='cuda'):
        super().__init__()
        self.detector = detector
        self.descriptor = descriptor
        self.interpolator = InterpolateSparse2d('bicubic')
        self.detection_threshold = detection_threshold
        self.top_k = top_k
        self.device = device
        if train_detector:
            for p in self.detector.parameters():
                p.requires_grad = True
            for p in self.descriptor.parameters():
                p.requires_grad = False
        else:
            for p in self.detector.parameters():
                p.requires_grad = False
            for p in self.descriptor.parameters():
                p.requires_grad = True

        self.softdetect = None
        self.stride = descriptor.stride

    def train(self, mode=True):
        super().train(mode)
        self.set_softdetect(top_k=500, scores_th=0.2)

    def eval(self):
        super().eval()
        self.set_softdetect(top_k=self.top_k, scores_th=0.01)

    def forward(self, samples: NestedTensor):

        if not isinstance(samples, NestedTensor):
            samples = nested_tensor_from_tensor_list(samples)

        scoremap = self.detector(samples)

        feats, matchibility = self.descriptor(samples)

        return feats, scoremap, matchibility

    def set_softdetect(self, top_k=4096, scores_th=0.01):
        self.softdetect = SoftDetect(radius=2, top_k=top_k, scores_th=scores_th)

    @torch.inference_mode()
    def filter(self, matchibility):
        # Filter out keypoints on the border
        B, _, H, W = matchibility.shape
        frame = torch.zeros(B, H, W, device=matchibility.device)
        frame[:, self.stride:-self.stride, self.stride:-self.stride] = 1
        matchibility = matchibility * frame
        return matchibility

    @torch.inference_mode()
    def extract(self, x):
        if self.softdetect is None:
            self.eval()

        x, rh1, rw1 = self.preprocess_tensor(x)
        x = x.to(self.device).float()
        B, _, _H1, _W1 = x.shape
        M1, K1, H1 = self.forward(x)
        M1 = F.normalize(M1, dim=1)

        keypoints, kptscores, scoredispersitys = self.softdetect(K1)

        keypoints = torch.vstack([keypoints[b].unsqueeze(0) for b in range(B)])
        kptscores = torch.vstack([kptscores[b].unsqueeze(0) for b in range(B)])

        keypoints = to_pixel_coords(keypoints, _H1, _W1)

        feats = self.interpolator(M1, keypoints, H = _H1, W = _W1)

        feats = F.normalize(feats, dim=-1)

        # Correct kpt scale
        keypoints = keypoints * torch.tensor([rw1,rh1], device=keypoints.device).view(1, -1)
        valid = kptscores > self.detection_threshold

        return [
            {'keypoints': keypoints[b][valid[b]],
             'scores': kptscores[b][valid[b]],
             'descriptors': feats[b][valid[b]]} for b in range(B)
        ]

    @torch.inference_mode()
    def extract_3rd_party(self, x, model='aliked'):
        """
        one image per batch
        """
        x, rh1, rw1 = self.preprocess_tensor(x)
        B, _, _H1, _W1 = x.shape
        if model == 'aliked':
            from third_party import extract_aliked_kpts
            img = x
            mkpts, scores = extract_aliked_kpts(img, self.device)
        else:
            raise ValueError('Unknown model')

        M1, _ = self.descriptor(x)
        M1 = F.normalize(M1, dim=1)

        if mkpts.shape[1] > self.top_k:
            idx = torch.argsort(scores, descending=True)[0][:self.top_k]
            mkpts = mkpts[:,idx]
            scores = scores[:,idx]

        feats = self.interpolator(M1, mkpts, H = _H1, W = _W1)
        feats = F.normalize(feats, dim=-1)
        mkpts = mkpts * torch.tensor([rw1,rh1], device=mkpts.device).view(1, 1, -1)

        return [
            {'keypoints': mkpts[b],
             'scores': scores[b],
             'descriptors': feats[b]} for b in range(B)
        ]

    @torch.inference_mode()
    def extract_dense(self, x, n_limit=30000, thr=0.01):
        self.set_softdetect(top_k=n_limit, scores_th=-1)

        x, rh1, rw1 = self.preprocess_tensor(x)

        B, _, _H1, _W1 = x.shape

        M1, K1, H1 = self.forward(x)
        M1 = F.normalize(M1, dim=1)

        keypoints, kptscores, scoredispersitys = self.softdetect(K1)

        keypoints = torch.vstack([keypoints[b].unsqueeze(0) for b in range(B)])
        kptscores = torch.vstack([kptscores[b].unsqueeze(0) for b in range(B)])

        keypoints = to_pixel_coords(keypoints, _H1, _W1)

        feats = self.interpolator(M1, keypoints, H = _H1, W = _W1)

        feats = F.normalize(feats, dim=-1)

        H1 = self.filter(H1)

        dense_kpts, dense_scores, inds = self.sample_dense_kpts(H1, n_limit=n_limit)

        dense_keypoints = to_pixel_coords(dense_kpts, _H1, _W1)

        dense_feats = self.interpolator(M1, dense_keypoints, H = _H1, W = _W1)

        dense_feats = F.normalize(dense_feats, dim=-1)

        keypoints = keypoints * torch.tensor([rw1,rh1], device=keypoints.device).view(1, -1)
        dense_keypoints = dense_keypoints * torch.tensor([rw1,rh1], device=dense_keypoints.device).view(1, -1)

        valid = kptscores > self.detection_threshold
        valid_dense = dense_scores > thr

        return [
            {'keypoints': keypoints[b][valid[b]],
             'scores': kptscores[b][valid[b]],
             'descriptors': feats[b][valid[b]],
             'keypoints_dense': dense_keypoints[b][valid_dense[b]],
             'scores_dense': dense_scores[b][valid_dense[b]],
             'descriptors_dense': dense_feats[b][valid_dense[b]]} for b in range(B)
        ]

    @torch.inference_mode()
    def sample_dense_kpts(self, keypoint_logits, threshold=0.01, n_limit=30000, force_kpts = True):

        B, K, H, W = keypoint_logits.shape

        if n_limit < 0 or n_limit > H*W:
            n_limit = min(H*W - 1, n_limit)

        scoremap = keypoint_logits.permute(0,2,3,1)

        scoremap = scoremap.reshape(B, H, W)

        frame = torch.zeros(B, H, W, device=keypoint_logits.device)

        frame[:, 1:-1, 1:-1] = 1

        scoremap = scoremap * frame

        scoremap = scoremap.reshape(B, H*W)

        grid = self.get_grid(B, H, W, device = keypoint_logits.device)

        inds = torch.topk(scoremap, n_limit, dim=1).indices

        # inds = torch.multinomial(scoremap, top_k, replacement=False)
        kpts = torch.gather(grid, 1, inds[..., None].expand(B, n_limit, 2))
        scoremap = torch.gather(scoremap, 1, inds)
        if force_kpts:
            valid = scoremap > threshold
            kpts = kpts[valid][None]
            scoremap = scoremap[valid][None]

        return kpts, scoremap, inds

    def preprocess_tensor(self, x):
        """ Guarantee that image is divisible by 32 to avoid aliasing artifacts. """
        if isinstance(x, np.ndarray) and len(x.shape) == 3:
            x = torch.tensor(x).permute(2,0,1)[None]
        x = x.to(self.device).float()

        H, W = x.shape[-2:]

        _H, _W = (H//32) * 32, (W//32) * 32

        rh, rw = H/_H, W/_W

        x = F.interpolate(x, (_H, _W), mode='bilinear', align_corners=False)
        return x, rh, rw

    @torch.inference_mode()
    def get_grid(self, B, H, W, device = None):
        x1_n = torch.meshgrid(
            *[
                torch.linspace(
                    -1 + 1 / n, 1 - 1 / n, n, device=device
                )
                for n in (B, H, W)
            ]
        )
        x1_n = torch.stack((x1_n[2], x1_n[1]), dim=-1).reshape(B, H * W, 2)
        return x1_n

def build(config=None, weights=None):
    if config is None:
        config = read_config('./configs/default.yaml')
    if weights is not None:
        config['weights'] = weights
    device = torch.device(config['device'])
    print('config', config)
    detector = build_detector(config)
    descriptor = build_descriptor(config)
    model = RDD(
        detector,
        descriptor,
        detection_threshold=config['detection_threshold'],
        top_k=config['top_k'],
        train_detector=config['train_detector'],
        device=device
    )
    if 'weights' in config and config['weights'] is not None:
        model.load_state_dict(torch.load(config['weights'], map_location='cpu'))
    model.to(device)
    return model
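Taken together, `build()` plus `extract()` / `extract_dense()` are the inference entry points of this file; a minimal sketch (weights path, working directory, and the image array are placeholders, and `configs/default.yaml` is resolved relative to the cwd):

```python
# Hedged sketch: build the model and extract sparse / dense keypoints directly.
import numpy as np
from RDD.RDD import build

model = build(weights="weights/RDD-v2.pth")   # reads configs/default.yaml when config is None
model.eval()

image = np.random.rand(480, 640, 3).astype(np.float32)   # H x W x 3, values in [0, 1]
sparse = model.extract(image)[0]                          # keypoints / scores / descriptors
dense = model.extract_dense(image, n_limit=10000)[0]      # adds *_dense entries
print(sparse["keypoints"].shape, dense["keypoints_dense"].shape)
```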
imcui/third_party/rdd/RDD/RDD_helper.py
ADDED
@@ -0,0 +1,179 @@
from .matchers import DualSoftmaxMatcher, DenseMatcher, LightGlue
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
import kornia

class RDD_helper(nn.Module):
    def __init__(self, RDD):
        super().__init__()
        self.matcher = DualSoftmaxMatcher(inv_temperature = 20, thr = 0.01)
        self.dense_matcher = DenseMatcher(inv_temperature=20, thr=0.01)
        self.RDD = RDD
        self.lg_matcher = None

    @torch.inference_mode()
    def match(self, img0, img1, thr=0.01, resize=None, top_k=4096):
        if top_k is not None and top_k != self.RDD.top_k:
            self.RDD.top_k = top_k
            self.RDD.set_softdetect(top_k=top_k)

        img0, scale0 = self.parse_input(img0, resize)
        img1, scale1 = self.parse_input(img1, resize)

        out0 = self.RDD.extract(img0)[0]
        out1 = self.RDD.extract(img1)[0]

        # get top_k confident matches
        mkpts0, mkpts1, conf = self.matcher(out0, out1, thr)

        scale0 = 1.0 / scale0
        scale1 = 1.0 / scale1

        mkpts0 = mkpts0 * scale0
        mkpts1 = mkpts1 * scale1

        return mkpts0.cpu().numpy(), mkpts1.cpu().numpy(), conf.cpu().numpy()

    @torch.inference_mode()
    def match_lg(self, img0, img1, thr=0.01, resize=None, top_k=4096):
        if self.lg_matcher is None:
            lg_conf = {
                "name": "lightglue",  # just for interfacing
                "input_dim": 256,  # input descriptor dimension (autoselected from weights)
                "descriptor_dim": 256,
                "add_scale_ori": False,
                "n_layers": 9,
                "num_heads": 4,
                "flash": True,  # enable FlashAttention if available.
                "mp": False,  # enable mixed precision
                "filter_threshold": 0.01,  # match threshold
                "depth_confidence": -1,  # depth confidence threshold
                "width_confidence": -1,  # width confidence threshold
                "weights": './weights/RDD_lg-v2.pth',  # path to the weights
            }
            self.lg_matcher = LightGlue(features='rdd', conf=lg_conf).to(self.RDD.device)

        if top_k is not None and top_k != self.RDD.top_k:
            self.RDD.top_k = top_k
            self.RDD.set_softdetect(top_k=top_k)

        img0, scale0 = self.parse_input(img0, resize=resize)
        img1, scale1 = self.parse_input(img1, resize=resize)

        size0 = torch.tensor(img0.shape[-2:])[None]
        size1 = torch.tensor(img1.shape[-2:])[None]

        out0 = self.RDD.extract(img0)[0]
        out1 = self.RDD.extract(img1)[0]

        # get top_k confident matches
        image0_data = {
            'keypoints': out0['keypoints'][None],
            'descriptors': out0['descriptors'][None],
            'image_size': size0,
        }

        image1_data = {
            'keypoints': out1['keypoints'][None],
            'descriptors': out1['descriptors'][None],
            'image_size': size1,
        }

        pred = {}

        with torch.no_grad():
            pred.update({'image0': image0_data, 'image1': image1_data})
            pred.update(self.lg_matcher({**pred}))

        kpts0 = pred['image0']['keypoints'][0]
        kpts1 = pred['image1']['keypoints'][0]

        matches = pred['matches'][0]

        mkpts0 = kpts0[matches[... , 0]]
        mkpts1 = kpts1[matches[... , 1]]
        conf = pred['scores'][0]

        valid_mask = conf > thr
        mkpts0 = mkpts0[valid_mask]
        mkpts1 = mkpts1[valid_mask]
        conf = conf[valid_mask]

        scale0 = 1.0 / scale0
        scale1 = 1.0 / scale1
        mkpts0 = mkpts0 * scale0
        mkpts1 = mkpts1 * scale1

        return mkpts0.cpu().numpy(), mkpts1.cpu().numpy(), conf.cpu().numpy()

    @torch.inference_mode()
    def match_dense(self, img0, img1, thr=0.01, resize=None):

        img0, scale0 = self.parse_input(img0, resize=resize)
        img1, scale1 = self.parse_input(img1, resize=resize)

        out0 = self.RDD.extract_dense(img0)[0]
        out1 = self.RDD.extract_dense(img1)[0]

        # get top_k confident matches
        mkpts0, mkpts1, conf = self.dense_matcher(out0, out1, thr, err_thr=self.RDD.stride)

        scale0 = 1.0 / scale0
        scale1 = 1.0 / scale1

        mkpts0 = mkpts0 * scale0
        mkpts1 = mkpts1 * scale1

        return mkpts0.cpu().numpy(), mkpts1.cpu().numpy(), conf.cpu().numpy()

    @torch.inference_mode()
    def match_3rd_party(self, img0, img1, model='aliked', resize=None, thr=0.01):
        img0, scale0 = self.parse_input(img0, resize=resize)
        img1, scale1 = self.parse_input(img1, resize=resize)

        out0 = self.RDD.extract_3rd_party(img0, model=model)[0]
        out1 = self.RDD.extract_3rd_party(img1, model=model)[0]

        mkpts0, mkpts1, conf = self.matcher(out0, out1, thr)

        scale0 = 1.0 / scale0
        scale1 = 1.0 / scale1

        mkpts0 = mkpts0 * scale0
        mkpts1 = mkpts1 * scale1

        return mkpts0.cpu().numpy(), mkpts1.cpu().numpy(), conf.cpu().numpy()

    def parse_input(self, x, resize=None):
        if len(x.shape) == 3:
            x = x[None, ...]

        if isinstance(x, np.ndarray):
            x = torch.tensor(x).permute(0,3,1,2)/255

        h, w = x.shape[-2:]
        size = h, w

        if resize is not None:
            size = self.get_new_image_size(h, w, resize)
            x = kornia.geometry.transform.resize(
                x,
                size,
                side='long',
                antialias=True,
                align_corners=None,
                interpolation='bilinear',
            )
        scale = torch.Tensor([x.shape[-1] / w, x.shape[-2] / h]).to(self.RDD.device)

        return x, scale

    def get_new_image_size(self, h, w, resize=1600):
        aspect_ratio = w / h
        size = int(resize / aspect_ratio), resize

        size = list(map(lambda x: int(x // 32 * 32), size))  # make sure size is divisible by 32

        return size
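The helper wraps the matchers behind a NumPy-in / NumPy-out interface (`parse_input` expects H x W x 3 arrays in 0-255); a short hedged usage sketch with placeholder inputs:

```python
# Hedged sketch: sparse and dense matching through RDD_helper; weights path and images are placeholders.
import numpy as np
from RDD.RDD import build
from RDD.RDD_helper import RDD_helper

model = build(weights="weights/RDD-v2.pth")
model.eval()
helper = RDD_helper(model)

img0 = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
img1 = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)

mkpts0, mkpts1, conf = helper.match(img0, img1, thr=0.01, resize=1024)         # sparse, dual-softmax
dkpts0, dkpts1, dconf = helper.match_dense(img0, img1, thr=0.01, resize=1024)  # semi-dense
print(mkpts0.shape, dkpts0.shape)
```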
imcui/third_party/rdd/RDD/dataset/__init__.py
ADDED
(empty file)
imcui/third_party/rdd/RDD/dataset/megadepth/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .megadepth import *
from .megadepth_warper import *
imcui/third_party/rdd/RDD/dataset/megadepth/megadepth.py
ADDED
@@ -0,0 +1,313 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import copy
|
3 |
+
import h5py
|
4 |
+
import torch
|
5 |
+
import pickle
|
6 |
+
import numpy as np
|
7 |
+
from PIL import Image
|
8 |
+
from tqdm import tqdm
|
9 |
+
from pathlib import Path
|
10 |
+
from torchvision import transforms
|
11 |
+
from torch.utils.data import Dataset
|
12 |
+
|
13 |
+
import cv2
|
14 |
+
from .utils import scale_intrinsics, warp_depth, warp_points2d
|
15 |
+
|
16 |
+
class MegaDepthDataset(Dataset):
|
17 |
+
def __init__(
|
18 |
+
self,
|
19 |
+
root,
|
20 |
+
npz_path,
|
21 |
+
num_per_scene=100,
|
22 |
+
image_size=256,
|
23 |
+
min_overlap_score=0.1,
|
24 |
+
max_overlap_score=0.9,
|
25 |
+
gray=False,
|
26 |
+
crop_or_scale='scale', # crop, scale, crop_scale
|
27 |
+
train=True,
|
28 |
+
):
|
29 |
+
self.data_path = Path(root)
|
30 |
+
self.num_per_scene = num_per_scene
|
31 |
+
self.train = train
|
32 |
+
self.image_size = image_size
|
33 |
+
self.gray = gray
|
34 |
+
self.crop_or_scale = crop_or_scale
|
35 |
+
|
36 |
+
self.scene_info = dict(np.load(npz_path, allow_pickle=True))
|
37 |
+
self.pair_infos = self.scene_info['pair_infos'].copy()
|
38 |
+
del self.scene_info['pair_infos']
|
39 |
+
self.pair_infos = [pair_info for pair_info in self.pair_infos if pair_info[1] > min_overlap_score and pair_info[1] < max_overlap_score]
|
40 |
+
if len(self.pair_infos) > num_per_scene:
|
41 |
+
indices = np.random.choice(len(self.pair_infos), num_per_scene, replace=False)
|
42 |
+
self.pair_infos = [self.pair_infos[idx] for idx in indices]
|
43 |
+
self.transforms = transforms.Compose([transforms.ToPILImage(),
|
44 |
+
transforms.ToTensor()])
|
45 |
+
|
46 |
+
def __len__(self):
|
47 |
+
return len(self.pair_infos)
|
48 |
+
|
49 |
+
def recover_pair(self, idx):
|
50 |
+
(idx0, idx1), overlap_score, central_matches = self.pair_infos[idx % len(self)]
|
51 |
+
|
52 |
+
img_name1 = self.scene_info['image_paths'][idx0]
|
53 |
+
img_name2 = self.scene_info['image_paths'][idx1]
|
54 |
+
|
55 |
+
depth1 = '/'.join([self.scene_info['depth_paths'][idx0].replace('phoenix/S6/zl548/MegaDepth_v1', 'depth_undistorted').split('/')[i] for i in [0, 1, -1]])
|
56 |
+
depth2 = '/'.join([self.scene_info['depth_paths'][idx1].replace('phoenix/S6/zl548/MegaDepth_v1', 'depth_undistorted').split('/')[i] for i in [0, 1, -1]])
|
57 |
+
|
58 |
+
depth_path1 = self.data_path / depth1
|
59 |
+
with h5py.File(depth_path1, 'r') as hdf5_file:
|
60 |
+
depth1 = np.array(hdf5_file['/depth'])
|
61 |
+
assert (np.min(depth1) >= 0)
|
62 |
+
image_path1 = self.data_path / img_name1
|
63 |
+
image1 = Image.open(image_path1)
|
64 |
+
if image1.mode != 'RGB':
|
65 |
+
image1 = image1.convert('RGB')
|
66 |
+
image1 = np.array(image1)
|
67 |
+
assert (image1.shape[0] == depth1.shape[0] and image1.shape[1] == depth1.shape[1])
|
68 |
+
intrinsics1 = self.scene_info['intrinsics'][idx0].copy()
|
69 |
+
pose1 = self.scene_info['poses'][idx0]
|
70 |
+
|
71 |
+
depth_path2 = self.data_path / depth2
|
72 |
+
with h5py.File(depth_path2, 'r') as hdf5_file:
|
73 |
+
depth2 = np.array(hdf5_file['/depth'])
|
74 |
+
assert (np.min(depth2) >= 0)
|
75 |
+
image_path2 = self.data_path / img_name2
|
76 |
+
image2 = Image.open(image_path2)
|
77 |
+
if image2.mode != 'RGB':
|
78 |
+
image2 = image2.convert('RGB')
|
79 |
+
image2 = np.array(image2)
|
80 |
+
assert (image2.shape[0] == depth2.shape[0] and image2.shape[1] == depth2.shape[1])
|
81 |
+
intrinsics2 = self.scene_info['intrinsics'][idx1].copy()
|
82 |
+
pose2 = self.scene_info['poses'][idx1]
|
83 |
+
|
84 |
+
pose12 = pose2 @ np.linalg.inv(pose1)
|
85 |
+
pose21 = np.linalg.inv(pose12)
|
86 |
+
|
87 |
+
if self.train:
|
88 |
+
if "crop" in self.crop_or_scale:
|
89 |
+
# ================================================= compute central_match
|
90 |
+
DOWNSAMPLE = 10
|
91 |
+
# resize to speed up
|
92 |
+
depth1s = cv2.resize(depth1, (depth1.shape[1] // DOWNSAMPLE, depth1.shape[0] // DOWNSAMPLE))
|
93 |
+
depth2s = cv2.resize(depth2, (depth2.shape[1] // DOWNSAMPLE, depth2.shape[0] // DOWNSAMPLE))
|
94 |
+
intrinsic1s = scale_intrinsics(intrinsics1, (DOWNSAMPLE, DOWNSAMPLE))
|
95 |
+
intrinsic2s = scale_intrinsics(intrinsics2, (DOWNSAMPLE, DOWNSAMPLE))
|
96 |
+
|
97 |
+
# warp
|
98 |
+
depth12s = warp_depth(depth1s, intrinsic1s, intrinsic2s, pose12, depth2s.shape)
|
99 |
+
depth21s = warp_depth(depth2s, intrinsic2s, intrinsic1s, pose21, depth1s.shape)
|
100 |
+
|
101 |
+
depth12s[depth12s < 0] = 0
|
102 |
+
depth21s[depth21s < 0] = 0
|
103 |
+
|
104 |
+
valid12s = np.logical_and(depth12s > 0, depth2s > 0)
|
105 |
+
valid21s = np.logical_and(depth21s > 0, depth1s > 0)
|
106 |
+
|
107 |
+
pos1 = np.array(valid21s.nonzero())
|
108 |
+
try:
|
109 |
+
idx1_random = np.random.choice(np.arange(pos1.shape[1]), 1)
|
110 |
+
uv1s = pos1[:, idx1_random][[1, 0]].reshape(1, 2)
|
111 |
+
d1s = np.array(depth1s[uv1s[0, 1], uv1s[0, 0]]).reshape(1, 1)
|
112 |
+
|
113 |
+
uv12s, z12s = warp_points2d(uv1s, d1s, intrinsic1s, intrinsic2s, pose12)
|
114 |
+
|
115 |
+
uv1 = uv1s[0] * DOWNSAMPLE
|
116 |
+
uv2 = uv12s[0] * DOWNSAMPLE
|
117 |
+
except ValueError:
|
118 |
+
uv1 = [depth1.shape[1] / 2, depth1.shape[0] / 2]
|
119 |
+
uv2 = [depth2.shape[1] / 2, depth2.shape[0] / 2]
|
120 |
+
|
121 |
+
central_match = [uv1[1], uv1[0], uv2[1], uv2[0]]
|
122 |
+
# ================================================= compute central_match
|
123 |
+
|
124 |
+
if self.crop_or_scale == 'crop':
|
125 |
+
# =============== padding
|
126 |
+
h1, w1, _ = image1.shape
|
127 |
+
h2, w2, _ = image2.shape
|
128 |
+
if h1 < self.image_size:
|
129 |
+
padding = np.zeros((self.image_size - h1, w1, 3))
|
130 |
+
image1 = np.concatenate([image1, padding], axis=0).astype(np.uint8)
|
131 |
+
depth1 = np.concatenate([depth1, padding[:, :, 0]], axis=0).astype(np.float32)
|
132 |
+
h1, w1, _ = image1.shape
|
133 |
+
if w1 < self.image_size:
|
134 |
+
padding = np.zeros((h1, self.image_size - w1, 3))
|
135 |
+
image1 = np.concatenate([image1, padding], axis=1).astype(np.uint8)
|
136 |
+
depth1 = np.concatenate([depth1, padding[:, :, 0]], axis=1).astype(np.float32)
|
137 |
+
if h2 < self.image_size:
|
138 |
+
padding = np.zeros((self.image_size - h2, w2, 3))
|
139 |
+
image2 = np.concatenate([image2, padding], axis=0).astype(np.uint8)
|
140 |
+
depth2 = np.concatenate([depth2, padding[:, :, 0]], axis=0).astype(np.float32)
|
141 |
+
h2, w2, _ = image2.shape
|
142 |
+
if w2 < self.image_size:
|
143 |
+
padding = np.zeros((h2, self.image_size - w2, 3))
|
144 |
+
image2 = np.concatenate([image2, padding], axis=1).astype(np.uint8)
|
145 |
+
depth2 = np.concatenate([depth2, padding[:, :, 0]], axis=1).astype(np.float32)
|
146 |
+
# =============== padding
|
147 |
+
image1, bbox1, image2, bbox2 = self.crop(image1, image2, central_match)
|
148 |
+
|
149 |
+
depth1 = depth1[bbox1[0]: bbox1[0] + self.image_size, bbox1[1]: bbox1[1] + self.image_size]
|
150 |
+
depth2 = depth2[bbox2[0]: bbox2[0] + self.image_size, bbox2[1]: bbox2[1] + self.image_size]
|
151 |
+
elif self.crop_or_scale == 'scale':
|
152 |
+
image1, depth1, intrinsics1 = self.scale(image1, depth1, intrinsics1)
|
153 |
+
image2, depth2, intrinsics2 = self.scale(image2, depth2, intrinsics2)
|
154 |
+
bbox1 = bbox2 = np.array([0., 0.])
|
155 |
+
elif self.crop_or_scale == 'crop_scale':
|
156 |
+
bbox1 = bbox2 = np.array([0., 0.])
|
157 |
+
image1, depth1, intrinsics1 = self.crop_scale(image1, depth1, intrinsics1, central_match[:2])
|
158 |
+
image2, depth2, intrinsics2 = self.crop_scale(image2, depth2, intrinsics2, central_match[2:])
|
159 |
+
else:
|
160 |
+
raise RuntimeError(f"Unkown type {self.crop_or_scale}")
|
161 |
+
else:
|
162 |
+
bbox1 = bbox2 = np.array([0., 0.])
|
163 |
        return (image1, depth1, intrinsics1, pose12, bbox1,
                image2, depth2, intrinsics2, pose21, bbox2)

    def scale(self, image, depth, intrinsic):
        img_size_org = image.shape
        image = cv2.resize(image, (self.image_size, self.image_size))
        depth = cv2.resize(depth, (self.image_size, self.image_size))
        intrinsic = scale_intrinsics(intrinsic, (img_size_org[1] / self.image_size, img_size_org[0] / self.image_size))
        return image, depth, intrinsic

    def crop_scale(self, image, depth, intrinsic, centeral):
        h_org, w_org, three = image.shape
        image_size = min(h_org, w_org)
        if h_org > w_org:
            if centeral[1] - image_size // 2 < 0:
                h_start = 0
            elif centeral[1] + image_size // 2 > h_org:
                h_start = h_org - image_size
            else:
                h_start = int(centeral[1]) - image_size // 2
            w_start = 0
        else:
            if centeral[0] - image_size // 2 < 0:
                w_start = 0
            elif centeral[0] + image_size // 2 > w_org:
                w_start = w_org - image_size
            else:
                w_start = int(centeral[0]) - image_size // 2
            h_start = 0

        croped_image = image[h_start: h_start + image_size, w_start: w_start + image_size]
        croped_depth = depth[h_start: h_start + image_size, w_start: w_start + image_size]
        intrinsic[0, 2] = intrinsic[0, 2] - w_start
        intrinsic[1, 2] = intrinsic[1, 2] - h_start

        image = cv2.resize(croped_image, (self.image_size, self.image_size))
        depth = cv2.resize(croped_depth, (self.image_size, self.image_size))
        intrinsic = scale_intrinsics(intrinsic, (image_size / self.image_size, image_size / self.image_size))

        return image, depth, intrinsic

    def crop(self, image1, image2, central_match):
        bbox1_i = max(int(central_match[0]) - self.image_size // 2, 0)
        if bbox1_i + self.image_size >= image1.shape[0]:
            bbox1_i = image1.shape[0] - self.image_size
        bbox1_j = max(int(central_match[1]) - self.image_size // 2, 0)
        if bbox1_j + self.image_size >= image1.shape[1]:
            bbox1_j = image1.shape[1] - self.image_size

        bbox2_i = max(int(central_match[2]) - self.image_size // 2, 0)
        if bbox2_i + self.image_size >= image2.shape[0]:
            bbox2_i = image2.shape[0] - self.image_size
        bbox2_j = max(int(central_match[3]) - self.image_size // 2, 0)
        if bbox2_j + self.image_size >= image2.shape[1]:
            bbox2_j = image2.shape[1] - self.image_size

        return (image1[bbox1_i: bbox1_i + self.image_size, bbox1_j: bbox1_j + self.image_size],
                np.array([bbox1_i, bbox1_j]),
                image2[bbox2_i: bbox2_i + self.image_size, bbox2_j: bbox2_j + self.image_size],
                np.array([bbox2_i, bbox2_j])
                )

    def __getitem__(self, idx):
        (image1, depth1, intrinsics1, pose12, bbox1,
         image2, depth2, intrinsics2, pose21, bbox2) \
            = self.recover_pair(idx)

        if self.gray:
            gray1 = cv2.cvtColor(image1, cv2.COLOR_RGB2GRAY)
            gray2 = cv2.cvtColor(image2, cv2.COLOR_RGB2GRAY)
            gray1 = transforms.ToTensor()(gray1)
            gray2 = transforms.ToTensor()(gray2)
        if self.transforms is not None:
            image1, image2 = self.transforms(image1), self.transforms(image2)  # [C,H,W]
        ret = {'image0': image1,
               'image1': image2,
               'angle': 0,
               'overlap': self.pair_infos[idx][1],
               'warp01_params': {'mode': 'se3',
                                 'width': self.image_size if self.train else image1.shape[2],
                                 'height': self.image_size if self.train else image1.shape[1],
                                 'pose01': torch.from_numpy(pose12.astype(np.float32)),
                                 'bbox0': torch.from_numpy(bbox1.astype(np.float32)),
                                 'bbox1': torch.from_numpy(bbox2.astype(np.float32)),
                                 'depth0': torch.from_numpy(depth1.astype(np.float32)),
                                 'depth1': torch.from_numpy(depth2.astype(np.float32)),
                                 'intrinsics0': torch.from_numpy(intrinsics1.astype(np.float32)),
                                 'intrinsics1': torch.from_numpy(intrinsics2.astype(np.float32))},
               'warp10_params': {'mode': 'se3',
                                 'width': self.image_size if self.train else image2.shape[2],
                                 'height': self.image_size if self.train else image2.shape[2],
                                 'pose01': torch.from_numpy(pose21.astype(np.float32)),
                                 'bbox0': torch.from_numpy(bbox2.astype(np.float32)),
                                 'bbox1': torch.from_numpy(bbox1.astype(np.float32)),
                                 'depth0': torch.from_numpy(depth2.astype(np.float32)),
                                 'depth1': torch.from_numpy(depth1.astype(np.float32)),
                                 'intrinsics0': torch.from_numpy(intrinsics2.astype(np.float32)),
                                 'intrinsics1': torch.from_numpy(intrinsics1.astype(np.float32))},
               }
        if self.gray:
            ret['gray0'] = gray1
            ret['gray1'] = gray2
        return ret


if __name__ == '__main__':
    from torch.utils.data import DataLoader
    import matplotlib.pyplot as plt


    def visualize(image0, image1, depth0, depth1):
        # visualize image and depth
        plt.figure(figsize=(9, 9))
        plt.subplot(2, 2, 1)
        plt.imshow(image0, cmap='gray')
        plt.subplot(2, 2, 2)
        plt.imshow(depth0)
        plt.subplot(2, 2, 3)
        plt.imshow(image1, cmap='gray')
        plt.subplot(2, 2, 4)
        plt.imshow(depth1)
        plt.show()


    dataset = MegaDepthDataset(  # root='../data/megadepth',
        root='../data/imw2020val',
        train=False,
        using_cache=True,
        pairs_per_scene=100,
        image_size=256,
        colorjit=True,
        gray=False,
        crop_or_scale='scale',
    )
    dataset.build_dataset()

    batch_size = 2

    loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=0)

    for idx, batch in enumerate(tqdm(loader)):
        image0, image1 = batch['image0'], batch['image1']  # [B,3,H,W]
        depth0, depth1 = batch['warp01_params']['depth0'], batch['warp01_params']['depth1']  # [B,H,W]
        intrinsics0, intrinsics1 = batch['warp01_params']['intrinsics0'], batch['warp01_params']['intrinsics1']  # [B,3,3]

        batch_size, channels, h, w = image0.shape

        for b_idx in range(batch_size):
            visualize(image0[b_idx].permute(1, 2, 0), image1[b_idx].permute(1, 2, 0), depth0[b_idx], depth1[b_idx])
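Note (sketch, not part of the diff): the crop() method above recenters a fixed-size window on the central match and clamps it to the image border. A minimal standalone illustration of that clamping rule, with hypothetical sizes:

def clamp_crop_start(center, crop_size, dim_size):
    # start at the centered position, but keep the crop fully inside the image
    start = max(int(center) - crop_size // 2, 0)
    if start + crop_size >= dim_size:
        start = dim_size - crop_size
    return start

print(clamp_crop_start(10, 256, 1000))    # 0   (clamped at the top border)
print(clamp_crop_start(500, 256, 1000))   # 372 (centered crop)
print(clamp_crop_start(990, 256, 1000))   # 744 (clamped at the bottom border)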
imcui/third_party/rdd/RDD/dataset/megadepth/megadepth_warper.py
ADDED
@@ -0,0 +1,75 @@
import torch
from kornia.utils import create_meshgrid
import matplotlib.pyplot as plt
import pdb
from .utils import warp

@torch.no_grad()
def spvs_coarse(data, scale = 8):
    N, _, H0, W0 = data['image0'].shape
    _, _, H1, W1 = data['image1'].shape
    device = data['image0'].device
    corrs = []
    for idx in range(N):
        warp01_params = {}
        for k, v in data['warp01_params'].items():
            if isinstance(v[idx], torch.Tensor):
                warp01_params[k] = v[idx].to(device)
            else:
                warp01_params[k] = v[idx]
        warp10_params = {}
        for k, v in data['warp10_params'].items():
            if isinstance(v[idx], torch.Tensor):
                warp10_params[k] = v[idx].to(device)
            else:
                warp10_params[k] = v[idx]

        # create kpts
        h0, w0, h1, w1 = map(lambda x: x // scale, [H0, W0, H1, W1])
        grid_pt1_c = create_meshgrid(h1, w1, False, device).reshape(h1*w1, 2)  # [N, hw, 2]

        # normalize kpts
        grid_pt1_c = grid_pt1_c * scale

        try:
            grid_pt1_c_valid, grid_pt10_c, ids1, ids1_out = warp(grid_pt1_c, warp10_params)
            grid_pt10_c_valid, grid_pt01_c, ids0, ids0_out = warp(grid_pt10_c, warp01_params)

            # check reproj error
            grid_pt1_c_valid = grid_pt1_c_valid[ids0]
            dist = torch.linalg.norm(grid_pt1_c_valid - grid_pt01_c, dim=-1)

            mask_mutual = (dist < 1.5)

            # get correspondences
            pts = torch.cat([grid_pt10_c_valid[mask_mutual] / scale,
                             grid_pt01_c[mask_mutual] / scale], dim=-1)
            # remove repeated correspondences
            lut_mat12 = torch.ones((h1, w1, 4), device = device, dtype = torch.float32) * -1
            lut_mat21 = torch.clone(lut_mat12)
            src_pts = pts[:, :2]
            tgt_pts = pts[:, 2:]

            lut_mat12[src_pts[:,1].long(), src_pts[:,0].long()] = torch.cat([src_pts, tgt_pts], dim=1)
            mask_valid12 = torch.all(lut_mat12 >= 0, dim=-1)
            points = lut_mat12[mask_valid12]

            # Target-src check
            src_pts, tgt_pts = points[:, :2], points[:, 2:]
            lut_mat21[tgt_pts[:,1].long(), tgt_pts[:,0].long()] = torch.cat([src_pts, tgt_pts], dim=1)
            mask_valid21 = torch.all(lut_mat21 >= 0, dim=-1)
            points = lut_mat21[mask_valid21]

            corrs.append(points)
        except:
            corrs.append(torch.zeros((0, 4), device = device))
            #pdb.set_trace()
            #print('..')

    # Plot for debug purposes
    # for i in range(len(corrs)):
    #     plot_corrs(data['image0'][i], data['image1'][i], corrs[i][:, :2]*8, corrs[i][:, 2:]*8)

    return corrs
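Note (sketch, not part of the diff): spvs_coarse keeps a coarse correspondence only if warping it from image 1 to image 0 and back lands within 1.5 px of where it started. A standalone illustration of that round-trip check, using an identity-plus-noise stand-in for the real depth-based warp:

import torch

torch.manual_seed(0)
pts1 = torch.rand(100, 2) * 256               # grid points in image 1 (stand-in)
pts1_roundtrip = pts1 + torch.randn(100, 2)   # pretend result of warping 1 -> 0 -> 1

dist = torch.linalg.norm(pts1 - pts1_roundtrip, dim=-1)
mask_mutual = dist < 1.5                      # same threshold as spvs_coarse
print(mask_mutual.float().mean())             # fraction of correspondences kept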
imcui/third_party/rdd/RDD/dataset/megadepth/utils.py
ADDED
@@ -0,0 +1,848 @@
"""
    "XFeat: Accelerated Features for Lightweight Image Matching, CVPR 2024."
    https://www.verlab.dcc.ufmg.br/descriptors/xfeat_cvpr24/

    MegaDepth data handling was adapted from
    LoFTR official code: https://github.com/zju3dv/LoFTR/blob/master/src/datasets/megadepth.py
"""

import io
import cv2
import numpy as np
import h5py
import torch
from numpy.linalg import inv
from kornia.geometry.epipolar import essential_from_Rt
from kornia.geometry.epipolar import fundamental_from_essential

import cv2
import torch
import numpy as np
from numba import jit
from copy import deepcopy

try:
    from utils.project_depth_nn_cython_pkg import project_depth_nn_cython

    nn_cython = True
except:
    print('\033[1;41;37mWarning: using python to project depth!!!\033[0m')

    nn_cython = False


class EmptyTensorError(Exception):
    pass


def mutual_NN(cross_matrix, mode: str = 'min'):
    """
    compute mutual nearest neighbor from a cross_matrix, non-differentiable function
    :param cross_matrix: N0xN1
    :param mode: 'min': mutual minimum; 'max': mutual maximum
    :return: index0, index1, Mx2
    """
    if mode == 'min':
        nn0 = cross_matrix == cross_matrix.min(dim=1, keepdim=True)[0]
        nn1 = cross_matrix == cross_matrix.min(dim=0, keepdim=True)[0]
    elif mode == 'max':
        nn0 = cross_matrix == cross_matrix.max(dim=1, keepdim=True)[0]
        nn1 = cross_matrix == cross_matrix.max(dim=0, keepdim=True)[0]
    else:
        raise TypeError("error mode, must be 'min' or 'max'.")

    mutual_nn = nn0 * nn1

    return torch.nonzero(mutual_nn, as_tuple=False)


def mutual_argmax(value, mask=None, as_tuple=True):
    """
    Args:
        value: MxN
        mask: MxN

    Returns:

    """
    value = value - value.min()  # convert to non-negative tensor
    if mask is not None:
        value = value * mask

    max0 = value.max(dim=1, keepdim=True)  # the col index the max value in each row
    max1 = value.max(dim=0, keepdim=True)

    valid_max0 = value == max0[0]
    valid_max1 = value == max1[0]

    mutual = valid_max0 * valid_max1
    if mask is not None:
        mutual = mutual * mask

    return mutual.nonzero(as_tuple=as_tuple)


def mutual_argmin(value, mask=None):
    return mutual_argmax(-value, mask)


def compute_keypoints_distance(kpts0, kpts1, p=2):
    """
    Args:
        kpts0: torch.tensor [M,2]
        kpts1: torch.tensor [N,2]
        p: (int, float, inf, -inf, 'fro', 'nuc', optional): the order of norm

    Returns:
        dist, torch.tensor [N,M]
    """
    dist = kpts0[:, None, :] - kpts1[None, :, :]  # [M,N,2]
    dist = torch.norm(dist, p=p, dim=2)  # [M,N]
    return dist


def keypoints_normal2pixel(kpts_normal, w, h):
    wh = kpts_normal[0].new_tensor([[w - 1, h - 1]])
    kpts_pixel = [(kpts + 1) / 2 * wh for kpts in kpts_normal]
    return kpts_pixel


def plot_keypoints(image, kpts, radius=2, color=(255, 0, 0)):
    image = image.cpu().detach().numpy() if isinstance(image, torch.Tensor) else image
    kpts = kpts.cpu().detach().numpy() if isinstance(kpts, torch.Tensor) else kpts

    if image.dtype is not np.dtype('uint8'):
        image = image * 255
        image = image.astype(np.uint8)

    if len(image.shape) == 2 or image.shape[2] == 1:
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)

    out = np.ascontiguousarray(deepcopy(image))
    kpts = np.round(kpts).astype(int)

    for kpt in kpts:
        y0, x0 = kpt
        cv2.drawMarker(out, (x0, y0), color, cv2.MARKER_CROSS, radius)

    # cv2.circle(out, (x0, y0), radius, color, -1, lineType=cv2.LINE_4)
    return out


def plot_matches(image0, image1, kpts0, kpts1, radius=2, color=(255, 0, 0), mcolor=(0, 255, 0), layout='lr'):
    image0 = image0.cpu().detach().numpy() if isinstance(image0, torch.Tensor) else image0
    image1 = image1.cpu().detach().numpy() if isinstance(image1, torch.Tensor) else image1
    kpts0 = kpts0.cpu().detach().numpy() if isinstance(kpts0, torch.Tensor) else kpts0
    kpts1 = kpts1.cpu().detach().numpy() if isinstance(kpts1, torch.Tensor) else kpts1

    out0 = plot_keypoints(image0, kpts0, radius, color)
    out1 = plot_keypoints(image1, kpts1, radius, color)

    H0, W0 = image0.shape[0], image0.shape[1]
    H1, W1 = image1.shape[0], image1.shape[1]

    if layout == "lr":
        H, W = max(H0, H1), W0 + W1
        out = 255 * np.ones((H, W, 3), np.uint8)
        out[:H0, :W0, :] = out0
        out[:H1, W0:, :] = out1
    elif layout == "ud":
        H, W = H0 + H1, max(W0, W1)
        out = 255 * np.ones((H, W, 3), np.uint8)
        out[:H0, :W0, :] = out0
        out[H0:, :W1, :] = out1
    else:
        raise ValueError("The layout must be 'lr' or 'ud'!")

    kpts0 = np.round(kpts0).astype(int)
    kpts1 = np.round(kpts1).astype(int)

    for kpt0, kpt1 in zip(kpts0, kpts1):
        (y0, x0), (y1, x1) = kpt0, kpt1

        if layout == "lr":
            cv2.line(out, (x0, y0), (x1 + W0, y1), color=mcolor, thickness=1, lineType=cv2.LINE_AA)
        elif layout == "ud":
            cv2.line(out, (x0, y0), (x1, y1 + H0), color=mcolor, thickness=1, lineType=cv2.LINE_AA)

    return out


def interpolate_depth(pos, depth):
    pos = pos.t()[[1, 0]]  # Nx2 -> 2xN; w,h -> h,w(i,j)

    # =============================================== from d2-net
    device = pos.device

    ids = torch.arange(0, pos.size(1), device=device)

    h, w = depth.size()

    i = pos[0, :].detach()  # TODO: changed here
    j = pos[1, :].detach()  # TODO: changed here

    # Valid corners
    i_top_left = torch.floor(i).long()
    j_top_left = torch.floor(j).long()
    valid_top_left = torch.min(i_top_left >= 0, j_top_left >= 0)

    i_top_right = torch.floor(i).long()
    j_top_right = torch.ceil(j).long()
    valid_top_right = torch.min(i_top_right >= 0, j_top_right < w)

    i_bottom_left = torch.ceil(i).long()
    j_bottom_left = torch.floor(j).long()
    valid_bottom_left = torch.min(i_bottom_left < h, j_bottom_left >= 0)

    i_bottom_right = torch.ceil(i).long()
    j_bottom_right = torch.ceil(j).long()
    valid_bottom_right = torch.min(i_bottom_right < h, j_bottom_right < w)

    valid_corners = torch.min(torch.min(valid_top_left, valid_top_right),
                              torch.min(valid_bottom_left, valid_bottom_right))

    i_top_left = i_top_left[valid_corners]
    j_top_left = j_top_left[valid_corners]

    i_top_right = i_top_right[valid_corners]
    j_top_right = j_top_right[valid_corners]

    i_bottom_left = i_bottom_left[valid_corners]
    j_bottom_left = j_bottom_left[valid_corners]

    i_bottom_right = i_bottom_right[valid_corners]
    j_bottom_right = j_bottom_right[valid_corners]

    ids = ids[valid_corners]
    ids_valid_corners = deepcopy(ids)
    if ids.size(0) == 0:
        # raise ValueError('empty tensor: ids')
        raise EmptyTensorError

    # Valid depth
    valid_depth = torch.min(torch.min(depth[i_top_left, j_top_left] > 0,
                                      depth[i_top_right, j_top_right] > 0),
                            torch.min(depth[i_bottom_left, j_bottom_left] > 0,
                                      depth[i_bottom_right, j_bottom_right] > 0))

    i_top_left = i_top_left[valid_depth]
    j_top_left = j_top_left[valid_depth]

    i_top_right = i_top_right[valid_depth]
    j_top_right = j_top_right[valid_depth]

    i_bottom_left = i_bottom_left[valid_depth]
    j_bottom_left = j_bottom_left[valid_depth]

    i_bottom_right = i_bottom_right[valid_depth]
    j_bottom_right = j_bottom_right[valid_depth]

    ids = ids[valid_depth]
    ids_valid_depth = deepcopy(ids)
    if ids.size(0) == 0:
        # raise ValueError('empty tensor: ids')
        raise EmptyTensorError

    # Interpolation
    i = i[ids]
    j = j[ids]
    dist_i_top_left = i - i_top_left.float()
    dist_j_top_left = j - j_top_left.float()
    w_top_left = (1 - dist_i_top_left) * (1 - dist_j_top_left)
    w_top_right = (1 - dist_i_top_left) * dist_j_top_left
    w_bottom_left = dist_i_top_left * (1 - dist_j_top_left)
    w_bottom_right = dist_i_top_left * dist_j_top_left

    interpolated_depth = (w_top_left * depth[i_top_left, j_top_left] +
                          w_top_right * depth[i_top_right, j_top_right] +
                          w_bottom_left * depth[i_bottom_left, j_bottom_left] +
                          w_bottom_right * depth[i_bottom_right, j_bottom_right])

    # pos = torch.cat([i.view(1, -1), j.view(1, -1)], dim=0)

    pos = pos[:, ids]

    # =============================================== from d2-net
    pos = pos[[1, 0]].t()  # 2xN -> Nx2; h,w(i,j) -> w,h

    # interpolated_depth: valid interpolated depth
    # pos: valid position (keypoint)
    # ids: indices of valid position (keypoint)

    return [interpolated_depth, pos, ids, ids_valid_corners, ids_valid_depth]


def to_homogeneous(kpts):
    '''
    :param kpts: Nx2
    :return: Nx3
    '''
    ones = kpts.new_ones([kpts.shape[0], 1])
    return torch.cat((kpts, ones), dim=1)


def warp_homography(kpts0, params):
    '''
    :param kpts: Nx2
    :param homography_matrix: 3x3
    :return:
    '''
    homography_matrix = params['homography_matrix']
    w, h = params['width'], params['height']
    kpts0_homogeneous = to_homogeneous(kpts0)
    kpts01_homogeneous = torch.einsum('ij,kj->ki', homography_matrix, kpts0_homogeneous)
    kpts01 = kpts01_homogeneous[:, :2] / kpts01_homogeneous[:, 2:]

    kpts01_ = kpts01.detach()
    # due to float coordinates, the upper boundary should be (w-1) and (h-1).
    # For example, if the image size is 480, then the coordinates should in [0~470].
    # 470.5 is not acceptable.
    valid01 = (kpts01_[:, 0] >= 0) * (kpts01_[:, 0] <= w - 1) * (kpts01_[:, 1] >= 0) * (kpts01_[:, 1] <= h - 1)
    kpts0_valid = kpts0[valid01]
    kpts01_valid = kpts01[valid01]
    ids = torch.nonzero(valid01, as_tuple=False)[:, 0]
    ids_out = torch.nonzero(~valid01, as_tuple=False)[:, 0]

    # kpts0_valid: valid keypoints0, the invalid and inconsistance keypoints are removed
    # kpts01_valid: the warped valid keypoints0
    # ids: the valid indices
    return kpts0_valid, kpts01_valid, ids, ids_out


def project(points3d, K):
    """
    project 3D points to image plane

    Args:
        points3d: [N,3]
        K: [3,3]

    Returns:
        uv, (u,v), [N,2]
    """
    if type(K) == torch.Tensor:
        zuv1 = torch.einsum('jk,nk->nj', K, points3d)  # z*(u,v,1) = K*points3d -> [N,3]
    elif type(K) == np.ndarray:
        zuv1 = np.einsum('jk,nk->nj', K, points3d)
    else:
        raise TypeError("Input type should be 'torch.tensor' or 'numpy.ndarray'")
    uv1 = zuv1 / zuv1[:, -1][:, None]  # (u,v,1) -> [N,3]
    uv = uv1[:, 0:2]  # (u,v) -> [N,2]
    return uv, zuv1[:, -1]


def unproject(uv, d, K):
    """
    unproject pixels uv to 3D points

    Args:
        uv: [N,2]
        d: depth, [N,1]
        K: [3,3]

    Returns:
        3D points, [N,3]
    """
    duv = uv * d  # (u,v) [N,2]
    if type(K) == torch.Tensor:
        duv1 = torch.cat([duv, d], dim=1)  # z*(u,v,1) [N,3]
        K_inv = torch.inverse(K)  # [3,3]
        points3d = torch.einsum('jk,nk->nj', K_inv, duv1)  # [N,3]
    elif type(K) == np.ndarray:
        duv1 = np.concatenate((duv, d), axis=1)  # z*(u,v,1) [N,3]
        K_inv = np.linalg.inv(K)  # [3,3]
        points3d = np.einsum('jk,nk->nj', K_inv, duv1)  # [N,3]
    else:
        raise TypeError("Input type should be 'torch.tensor' or 'numpy.ndarray'")
    return points3d


def warp_se3(kpts0, params):
    pose01 = params['pose01']  # relative motion
    bbox0 = params['bbox0']  # row, col
    bbox1 = params['bbox1']
    depth0 = params['depth0']
    depth1 = params['depth1']
    intrinsics0 = params['intrinsics0']
    intrinsics1 = params['intrinsics1']

    # kpts0_valid: valid kpts0
    # z0_valid: depth of valid kpts0
    # ids0: the indices of valid kpts0 ( valid corners and valid depth)
    # ids0_valid_corners: the valid indices of kpts0 in image ( 0<=x<w, 0<=y<h )
    # ids0_valid_depth: the valid indices of kpts0 with valid depth ( depth > 0 )
    z0_valid, kpts0_valid, ids0, ids0_valid_corners, ids0_valid_depth = interpolate_depth(kpts0, depth0)

    # COLMAP convention
    bkpts0_valid = kpts0_valid + bbox0[[1, 0]][None, :] + 0.5

    # unproject pixel coordinate to 3D points (camera coordinate system)
    bpoints3d0 = unproject(bkpts0_valid, z0_valid.unsqueeze(1), intrinsics0)  # [:,3]
    bpoints3d0_homo = to_homogeneous(bpoints3d0)  # [:,4]

    # warp 3D point (camera 0 coordinate system) to 3D point (camera 1 coordinate system)
    bpoints3d01_homo = torch.einsum('jk,nk->nj', pose01, bpoints3d0_homo)  # [:,4]
    bpoints3d01 = bpoints3d01_homo[:, 0:3]  # [:,3]

    # project 3D point (camera coordinate system) to pixel coordinate
    buv01, z01 = project(bpoints3d01, intrinsics1)  # uv: [:,2], (h,w); z1: [N]

    uv01 = buv01 - bbox1[None, [1, 0]] - .5

    # kpts01_valid: valid kpts01
    # z01_valid: depth of valid kpts01
    # ids01: the indices of valid kpts01 ( valid corners and valid depth)
    # ids01_valid_corners: the valid indices of kpts01 in image ( 0<=x<w, 0<=y<h )
    # ids01_valid_depth: the valid indices of kpts01 with valid depth ( depth > 0 )
    z01_interpolate, kpts01_valid, ids01, ids01_valid_corners, ids01_valid_depth = interpolate_depth(uv01, depth1)

    outimage_mask = torch.ones(ids0.shape[0], device=ids0.device).bool()
    outimage_mask[ids01_valid_corners] = 0
    ids01_invalid_corners = torch.arange(0, ids0.shape[0], device=ids0.device)[outimage_mask]
    ids_outside = ids0[ids01_invalid_corners]

    # ids_valid: matched kpts01 without occlusion
    ids_valid = ids0[ids01]
    kpts0_valid = kpts0_valid[ids01]
    z01_proj = z01[ids01]

    inlier_mask = torch.abs(z01_proj - z01_interpolate) < 0.05

    # indices of kpts01 with occlusion
    ids_occlude = ids_valid[~inlier_mask]

    ids_valid = ids_valid[inlier_mask]
    if ids_valid.size(0) == 0:
        # raise ValueError('empty tensor: ids')
        raise EmptyTensorError

    kpts01_valid = kpts01_valid[inlier_mask]
    kpts0_valid = kpts0_valid[inlier_mask]

    # indices of kpts01 which are no matches in image1 for sure,
    # other projected kpts01 are not sure because of no depth in image0 or imgae1
    ids_out = torch.cat([ids_outside, ids_occlude])

    # kpts0_valid: valid keypoints0, the invalid and inconsistance keypoints are removed
    # kpts01_valid: the warped valid keypoints0
    # ids: the valid indices
    return kpts0_valid, kpts01_valid, ids_valid, ids_out


def warp(kpts0, params: dict):
    mode = params['mode']
    if mode == 'homo':
        return warp_homography(kpts0, params)
    elif mode == 'se3':
        return warp_se3(kpts0, params)
    else:
        raise ValueError('unknown mode!')


def warp_xy(kpts0_xy, params: dict):
    w, h = params['width'], params['height']
    kpts0 = (kpts0_xy / 2 + 0.5) * kpts0_xy.new_tensor([[w - 1, h - 1]])
    kpts0, kpts01, ids = warp(kpts0, params)
    kpts01_xy = 2 * kpts01 / kpts01.new_tensor([[w - 1, h - 1]]) - 1
    kpts0_xy = 2 * kpts0 / kpts0.new_tensor([[w - 1, h - 1]]) - 1
    return kpts0_xy, kpts01_xy, ids


def scale_intrinsics(K, scales):
    scales = np.diag([1. / scales[0], 1. / scales[1], 1.])
    return np.dot(scales, K)


def warp_points3d(points3d0, pose01):
    points3d0_homo = np.concatenate((points3d0, np.ones(points3d0.shape[0])[:, np.newaxis]), axis=1)  # [:,4]

    points3d01_homo = np.einsum('jk,nk->nj', pose01, points3d0_homo)  # [N,4]
    points3d01 = points3d01_homo[:, 0:3]  # [N,3]

    return points3d01


def unproject_depth(depth, K):
    h, w = depth.shape

    wh_range = np.mgrid[0:w, 0:h].transpose(2, 1, 0)  # [H,W,2]

    uv = wh_range.reshape(-1, 2)
    d = depth.reshape(-1, 1)
    points3d = unproject(uv, d, K)

    valid = np.logical_and((d[:, 0] > 0), (points3d[:, 2] > 0))

    return points3d, valid


@jit(nopython=True)
def project_depth_nn_python(uv, z, depth):
    h, w = depth.shape
    # TODO: speed up the for loop
    for idx in range(len(uv)):
        uvi = uv[idx]
        x = int(round(uvi[0]))
        y = int(round(uvi[1]))

        if x < 0 or y < 0 or x >= w or y >= h:
            continue

        if depth[y, x] == 0. or depth[y, x] > z[idx]:
            depth[y, x] = z[idx]
    return depth


def project_nn(uv, z, depth):
    """
    uv: pixel coordinates [N,2]
    z: projected depth (xyz -> z) [N]
    depth: output depth array: [h,w]
    """
    if nn_cython:
        return project_depth_nn_cython(uv.astype(np.float64),
                                       z.astype(np.float64),
                                       depth.astype(np.float64))
    else:
        return project_depth_nn_python(uv, z, depth)


def warp_depth(depth0, intrinsics0, intrinsics1, pose01, shape1):
    points3d0, valid0 = unproject_depth(depth0, intrinsics0)  # [:,3]
    points3d0 = points3d0[valid0]

    points3d01 = warp_points3d(points3d0, pose01)

    uv01, z01 = project(points3d01, intrinsics1)  # uv: [N,2], (h,w); z1: [N]

    depth01 = project_nn(uv01, z01, depth=np.zeros(shape=shape1))

    return depth01


def warp_points2d(uv0, d0, intrinsics0, intrinsics1, pose01):
    points3d0 = unproject(uv0, d0, intrinsics0)
    points3d01 = warp_points3d(points3d0, pose01)
    uv01, z01 = project(points3d01, intrinsics1)
    return uv01, z01


def display_image_in_actual_size(image):
    import matplotlib.pyplot as plt

    dpi = 100
    height, width = image.shape[:2]

    # What size does the figure need to be in inches to fit the image?
    figsize = width / float(dpi), height / float(dpi)

    # Create a figure of the right size with one axes that takes up the full figure
    fig = plt.figure(figsize=figsize)
    ax = fig.add_axes([0, 0, 1, 1])

    # Hide spines, ticks, etc.
    ax.axis('off')

    # Display the image.
    if len(image.shape) == 3:
        ax.imshow(image, cmap='gray')
    elif len(image.shape) == 2:
        if image.dtype == np.uint8:
            ax.imshow(image, cmap='gray')
        else:
            ax.imshow(image)
            ax.text(20, 20, f"Range: {image.min():g}~{image.max():g}", color='red')
    plt.show()


# ====================================== copied from ASLFeat
from datetime import datetime


class ClassProperty(property):
    """For dynamically obtaining system time"""

    def __get__(self, cls, owner):
        return classmethod(self.fget).__get__(None, owner)()


class Notify(object):
    """Colorful printing prefix.
    A quick example:
    print(Notify.INFO, YOUR TEXT, Notify.ENDC)
    """

    def __init__(self):
        pass

    @ClassProperty
    def HEADER(cls):
        return str(datetime.now()) + ': \033[95m'

    @ClassProperty
    def INFO(cls):
        return str(datetime.now()) + ': \033[92mI'

    @ClassProperty
    def OKBLUE(cls):
        return str(datetime.now()) + ': \033[94m'

    @ClassProperty
    def WARNING(cls):
        return str(datetime.now()) + ': \033[93mW'

    @ClassProperty
    def FAIL(cls):
        return str(datetime.now()) + ': \033[91mF'

    @ClassProperty
    def BOLD(cls):
        return str(datetime.now()) + ': \033[1mB'

    @ClassProperty
    def UNDERLINE(cls):
        return str(datetime.now()) + ': \033[4mU'

    ENDC = '\033[0m'


def get_essential(T0, T1):
    R0 = T0[:3, :3]
    R1 = T1[:3, :3]

    t0 = T0[:3, 3].reshape(3, 1)
    t1 = T1[:3, 3].reshape(3, 1)

    R0 = torch.tensor(R0, dtype=torch.float32)
    R1 = torch.tensor(R1, dtype=torch.float32)
    t0 = torch.tensor(t0, dtype=torch.float32)
    t1 = torch.tensor(t1, dtype=torch.float32)

    E = essential_from_Rt(R0, t0, R1, t1)

    return E


def get_fundamental(E, K0, K1):
    F = fundamental_from_essential(E, K0, K1)

    return F


try:
    # for internel use only
    from .client import MEGADEPTH_CLIENT, SCANNET_CLIENT
except Exception:
    MEGADEPTH_CLIENT = SCANNET_CLIENT = None

# --- DATA IO ---

def load_array_from_s3(
    path, client, cv_type,
    use_h5py=False,
):
    byte_str = client.Get(path)
    try:
        if not use_h5py:
            raw_array = np.fromstring(byte_str, np.uint8)
            data = cv2.imdecode(raw_array, cv_type)
        else:
            f = io.BytesIO(byte_str)
            data = np.array(h5py.File(f, 'r')['/depth'])
    except Exception as ex:
        print(f"==> Data loading failure: {path}")
        raise ex

    assert data is not None
    return data


def imread_gray(path, augment_fn=None, client=SCANNET_CLIENT):
    cv_type = cv2.IMREAD_GRAYSCALE if augment_fn is None \
        else cv2.IMREAD_COLOR
    if str(path).startswith('s3://'):
        image = load_array_from_s3(str(path), client, cv_type)
    else:
        image = cv2.imread(str(path), 1)

    if augment_fn is not None:
        image = cv2.imread(str(path), cv2.IMREAD_COLOR)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = augment_fn(image)
        image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    return image  # (h, w)


def get_resized_wh(w, h, resize=None):
    if resize is not None:  # resize the longer edge
        scale = resize / max(h, w)
        w_new, h_new = int(round(w*scale)), int(round(h*scale))
    else:
        w_new, h_new = w, h
    return w_new, h_new


def get_divisible_wh(w, h, df=None):
    if df is not None:
        w_new, h_new = map(lambda x: int(x // df * df), [w, h])
    else:
        w_new, h_new = w, h
    return w_new, h_new


def pad_bottom_right(inp, pad_size, ret_mask=False):
    assert isinstance(pad_size, int) and pad_size >= max(inp.shape[-2:]), f"{pad_size} < {max(inp.shape[-2:])}"
    mask = None
    if inp.ndim == 2:
        padded = np.zeros((pad_size, pad_size), dtype=inp.dtype)
        padded[:inp.shape[0], :inp.shape[1]] = inp
        if ret_mask:
            mask = np.zeros((pad_size, pad_size), dtype=bool)
            mask[:inp.shape[0], :inp.shape[1]] = True
    elif inp.ndim == 3:
        padded = np.zeros((inp.shape[0], pad_size, pad_size), dtype=inp.dtype)
        padded[:, :inp.shape[1], :inp.shape[2]] = inp
        if ret_mask:
            mask = np.zeros((inp.shape[0], pad_size, pad_size), dtype=bool)
            mask[:, :inp.shape[1], :inp.shape[2]] = True
    else:
        raise NotImplementedError()
    return padded, mask


# --- MEGADEPTH ---

def fix_path_from_d2net(path):
    if not path:
        return None

    path = path.replace('Undistorted_SfM/', '')
    path = path.replace('images', 'dense0/imgs')
    path = path.replace('phoenix/S6/zl548/MegaDepth_v1/', '')

    return path


def read_megadepth_gray(path, resize=None, df=None, padding=False, augment_fn=None):
    """
    Args:
        resize (int, optional): the longer edge of resized images. None for no resize.
        padding (bool): If set to 'True', zero-pad resized images to squared size.
        augment_fn (callable, optional): augments images with pre-defined visual effects
    Returns:
        image (torch.tensor): (1, h, w)
        mask (torch.tensor): (h, w)
        scale (torch.tensor): [w/w_new, h/h_new]
    """
    # read image
    image = imread_gray(path, augment_fn, client=MEGADEPTH_CLIENT)

    # resize image
    w, h = image.shape[1], image.shape[0]

    if len(resize) == 2:
        w_new, h_new = resize
    else:
        resize = resize[0]
        w_new, h_new = get_resized_wh(w, h, resize)
        w_new, h_new = get_divisible_wh(w_new, h_new, df)

    image = cv2.resize(image, (w_new, h_new))
    scale = torch.tensor([w/w_new, h/h_new], dtype=torch.float)

    if padding:  # padding
        pad_to = max(h_new, w_new)
        image, mask = pad_bottom_right(image, pad_to, ret_mask=True)
    else:
        mask = None

    #image = torch.from_numpy(image).float()[None] / 255  # (h, w) -> (1, h, w) and normalized
    image = torch.from_numpy(image).float().permute(2,0,1) / 255  # (h, w) -> (1, h, w) and normalized
    mask = torch.from_numpy(mask) if mask is not None else None

    return image, mask, scale


def imread_color(path, augment_fn=None, client=SCANNET_CLIENT):
    cv_type = cv2.IMREAD_COLOR
    # if str(path).startswith('s3://'):
    #     image = load_array_from_s3(str(path), client, cv_type)
    # else:
    #     image = cv2.imread(str(path), cv_type)

    image = cv2.imread(str(path), cv2.IMREAD_COLOR)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    if augment_fn is not None:
        image = augment_fn(image)
        # image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    return image  # (3, h, w)


def read_megadepth_color(path,
                         resize=None,
                         df=None,
                         padding=False,
                         augment_fn=None,
                         rotation=0):
    """
    Args:
        resize (int, optional): the longer edge of resized images. None for no resize.
        padding (bool): If set to 'True', zero-pad resized images to squared size.
        augment_fn (callable, optional): augments images with pre-defined visual effects
    Returns:
        image (torch.tensor): (3, h, w)
        mask (torch.tensor): (h, w)
        scale (torch.tensor): [w/w_new, h/h_new]
    """
    # read image
    image = imread_color(path, augment_fn, client=MEGADEPTH_CLIENT)

    if rotation != 0:
        image = np.rot90(image, k=rotation).copy()

    # resize image
    if resize is not None:
        w, h = image.shape[1], image.shape[0]
        if len(resize) == 2:
            w_new, h_new = resize
        else:
            resize = resize[0]
            w_new, h_new = get_resized_wh(w, h, resize)
            w_new, h_new = get_divisible_wh(w_new, h_new, df)

        image = cv2.resize(image, (w_new, h_new))
        scale = torch.tensor([w/w_new, h/h_new], dtype=torch.float)
        scale_wh = torch.tensor([w_new, h_new], dtype=torch.float)
    else:
        scale = torch.tensor([1., 1.], dtype=torch.float)
        scale_wh = torch.tensor([image.shape[1], image.shape[0]], dtype=torch.float)

    image = image.transpose(2, 0, 1)

    if padding:  # padding
        if resize is not None:
            pad_to = max(h_new, w_new)
        else:
            pad_to = 2000
        image, mask = pad_bottom_right(image, pad_to, ret_mask=True)
    else:
        mask = None

    image = torch.from_numpy(image).float() / 255  # (h, w) -> (1, h, w) and normalized
    mask = torch.from_numpy(mask) if mask is not None else None

    return image, mask, scale


def read_megadepth_depth(path, pad_to=None):

    if str(path).startswith('s3://'):
        depth = load_array_from_s3(path, MEGADEPTH_CLIENT, None, use_h5py=True)
    else:
        depth = np.array(h5py.File(path, 'r')['depth'])
    if pad_to is not None:
        depth, _ = pad_bottom_right(depth, pad_to, ret_mask=False)
    depth = torch.from_numpy(depth).float()  # (h, w)
    return depth


def get_image_name(path):
    return path.split('/')[-1]


def scale_intrinsics(K, scales):
    scales = np.diag([1. / scales[0], 1. / scales[1], 1.])
    return np.dot(scales, K)
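Note (sketch, not part of the diff): scale_intrinsics left-multiplies K by diag(1/sx, 1/sy, 1), where scales = (w_org/w_new, h_org/h_new), so focal lengths and principal point shrink with the image. A quick numeric check with hypothetical values:

import numpy as np

K = np.array([[1000., 0., 320.],
              [0., 1000., 240.],
              [0., 0., 1.]])
# image resized from 640x480 to 320x240 -> scales = (w_org/w_new, h_org/h_new) = (2, 2)
K_scaled = np.diag([1. / 2., 1. / 2., 1.]) @ K
print(K_scaled)  # fx=500, fy=500, cx=160, cy=120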
imcui/third_party/rdd/RDD/matchers/__init__.py
ADDED
@@ -0,0 +1,3 @@
from .dual_softmax_matcher import DualSoftmaxMatcher
from .dense_matcher import DenseMatcher
from .lightglue import LightGlue
imcui/third_party/rdd/RDD/matchers/dense_matcher.py
ADDED
@@ -0,0 +1,88 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import poselib

class DenseMatcher(nn.Module):
    def __init__(self, inv_temperature = 20, thr = 0.01):
        super().__init__()
        self.inv_temperature = inv_temperature
        self.thr = thr

    def forward(self, info0, info1, thr = None, err_thr=4, min_num_inliers=30):

        desc0 = info0['descriptors']
        desc1 = info1['descriptors']

        inds, P = self.dual_softmax(desc0, desc1, thr=thr)

        mkpts_0 = info0['keypoints'][inds[:,0]]
        mkpts_1 = info1['keypoints'][inds[:,1]]
        mconf = P[inds[:,0], inds[:,1]]
        Fm, inliers = self.get_fundamental_matrix(mkpts_0, mkpts_1)

        if inliers.sum() >= min_num_inliers:
            desc1_dense = info0['descriptors_dense']
            desc2_dense = info1['descriptors_dense']

            inds_dense, P_dense = self.dual_softmax(desc1_dense, desc2_dense, thr=thr)

            mkpts_0_dense = info0['keypoints_dense'][inds_dense[:,0]]
            mkpts_1_dense = info1['keypoints_dense'][inds_dense[:,1]]
            mconf_dense = P_dense[inds_dense[:,0], inds_dense[:,1]]

            mkpts_0_dense, mkpts_1_dense, mconf_dense = self.refine_matches(mkpts_0_dense, mkpts_1_dense, mconf_dense, Fm, err_thr=err_thr)
            mkpts_0 = mkpts_0[inliers]
            mkpts_1 = mkpts_1[inliers]
            mconf = mconf[inliers]
            # concatenate the matches
            mkpts_0 = torch.cat([mkpts_0, mkpts_0_dense], dim=0)
            mkpts_1 = torch.cat([mkpts_1, mkpts_1_dense], dim=0)
            mconf = torch.cat([mconf, mconf_dense], dim=0)

        return mkpts_0, mkpts_1, mconf

    def get_fundamental_matrix(self, kpts_0, kpts_1):
        Fm, info = poselib.estimate_fundamental(kpts_0.cpu().numpy(), kpts_1.cpu().numpy(), {'max_epipolar_error': 1, 'progressive_sampling': True}, {})
        inliers = info['inliers']
        Fm = torch.tensor(Fm, device=kpts_0.device, dtype=kpts_0.dtype)
        inliers = torch.tensor(inliers, device=kpts_0.device, dtype=torch.bool)
        return Fm, inliers

    def dual_softmax(self, desc0, desc1, thr = None):
        if thr is None:
            thr = self.thr
        dist_mat = (desc0 @ desc1.t()) * self.inv_temperature
        P = dist_mat.softmax(dim = -2) * dist_mat.softmax(dim= -1)
        inds = torch.nonzero((P == P.max(dim=-1, keepdim = True).values)
                             * (P == P.max(dim=-2, keepdim = True).values) * (P >= thr))

        return inds, P

    @torch.inference_mode()
    def refine_matches(self, mkpts_0, mkpts_1, mconf, Fm, err_thr=4):
        mkpts_0_h = torch.cat([mkpts_0, torch.ones(mkpts_0.shape[0], 1, device=mkpts_0.device)], dim=1)  # (N, 3)
        mkpts_1_h = torch.cat([mkpts_1, torch.ones(mkpts_1.shape[0], 1, device=mkpts_1.device)], dim=1)  # (N, 3)

        lines_1 = torch.matmul(Fm, mkpts_0_h.T).T

        a, b, c = lines_1[:, 0], lines_1[:, 1], lines_1[:, 2]

        x1, y1 = mkpts_1[:, 0], mkpts_1[:, 1]

        denom = a**2 + b**2 + 1e-8

        x_offset = (b * (b * x1 - a * y1) - a * c) / denom - x1
        y_offset = (a * (a * y1 - b * x1) - b * c) / denom - y1

        inds = (x_offset.abs() < err_thr) | (y_offset.abs() < err_thr)

        x_offset = x_offset[inds]
        y_offset = y_offset[inds]

        mkpts_0 = mkpts_0[inds]
        mkpts_1 = mkpts_1[inds]

        refined_mkpts_1 = mkpts_1 + torch.stack([x_offset, y_offset], dim=1)

        return mkpts_0, refined_mkpts_1, mconf[inds]
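Note (sketch, not part of the diff): refine_matches snaps each dense match in image 1 onto its epipolar line l = F @ x0; the offsets above are the foot-of-perpendicular from (x1, y1) to the line ax + by + c = 0. A standalone sanity check of that formula with a hypothetical line and point:

a, b, c = 1.0, 2.0, -3.0          # line x + 2y - 3 = 0
x1, y1 = 4.0, 5.0                 # a point off the line
denom = a**2 + b**2
x_off = (b * (b * x1 - a * y1) - a * c) / denom - x1
y_off = (a * (a * y1 - b * x1) - b * c) / denom - y1
x_new, y_new = x1 + x_off, y1 + y_off
print(a * x_new + b * y_new + c)  # ~0: the refined point lies on the line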
imcui/third_party/rdd/RDD/matchers/dual_softmax_matcher.py
ADDED
@@ -0,0 +1,31 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

class DualSoftmaxMatcher(nn.Module):
    def __init__(self, inv_temperature = 20, thr = 0.01):
        super().__init__()
        self.inv_temperature = inv_temperature
        self.thr = thr

    def forward(self, info0, info1, thr = None):
        desc0 = info0['descriptors']
        desc1 = info1['descriptors']

        inds, P = self.dual_softmax(desc0, desc1, thr)
        mkpts0 = info0['keypoints'][inds[:,0]]
        mkpts1 = info1['keypoints'][inds[:,1]]
        mconf = P[inds[:,0], inds[:,1]]

        return mkpts0, mkpts1, mconf

    def dual_softmax(self, desc0, desc1, thr = None):
        if thr is None:
            thr = self.thr
        dist_mat = (desc0 @ desc1.t()) * self.inv_temperature
        P = dist_mat.softmax(dim = -2) * dist_mat.softmax(dim= -1)

        inds = torch.nonzero((P == P.max(dim=-1, keepdim = True).values)
                             * (P == P.max(dim=-2, keepdim = True).values) * (P >= thr))

        return inds, P
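Note (sketch, not part of the diff): dual_softmax scores every descriptor pair with a row softmax times a column softmax and keeps mutual maxima above a threshold. A standalone illustration of the same selection on random descriptors (hypothetical sizes):

import torch

torch.manual_seed(0)
desc0 = torch.nn.functional.normalize(torch.randn(128, 256), dim=-1)
desc1 = torch.nn.functional.normalize(torch.randn(150, 256), dim=-1)

sim = (desc0 @ desc1.t()) * 20                   # inv_temperature = 20
P = sim.softmax(dim=-2) * sim.softmax(dim=-1)    # dual-softmax score
mutual = (P == P.max(dim=-1, keepdim=True).values) & (P == P.max(dim=-2, keepdim=True).values)
inds = torch.nonzero(mutual & (P >= 0.01))       # thr = 0.01
print(inds.shape)                                # (num_matches, 2) index pairs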
imcui/third_party/rdd/RDD/matchers/lightglue.py
ADDED
@@ -0,0 +1,667 @@
"""
Modified from
https://github.com/cvg/LightGlue
"""

import warnings
from pathlib import Path
from types import SimpleNamespace
from typing import Callable, List, Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn

try:
    from flash_attn.modules.mha import FlashCrossAttention
except ModuleNotFoundError:
    FlashCrossAttention = None

if FlashCrossAttention or hasattr(F, "scaled_dot_product_attention"):
    FLASH_AVAILABLE = True
else:
    FLASH_AVAILABLE = False

torch.backends.cudnn.deterministic = True


@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
def normalize_keypoints(
    kpts: torch.Tensor, size: Optional[torch.Tensor] = None
) -> torch.Tensor:
    if size is None:
        size = 1 + kpts.max(-2).values - kpts.min(-2).values
    elif not isinstance(size, torch.Tensor):
        size = torch.tensor(size, device=kpts.device, dtype=kpts.dtype)
    size = size.to(kpts)
    shift = size / 2
    scale = size.max(-1).values / 2
    kpts = (kpts - shift[..., None, :]) / scale[..., None, None]
    return kpts


def pad_to_length(x: torch.Tensor, length: int) -> Tuple[torch.Tensor]:
    if length <= x.shape[-2]:
        return x, torch.ones_like(x[..., :1], dtype=torch.bool)
    pad = torch.ones(
        *x.shape[:-2], length - x.shape[-2], x.shape[-1], device=x.device, dtype=x.dtype
    )
    y = torch.cat([x, pad], dim=-2)
    mask = torch.zeros(*y.shape[:-1], 1, dtype=torch.bool, device=x.device)
    mask[..., : x.shape[-2], :] = True
    return y, mask


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x = x.unflatten(-1, (-1, 2))
    x1, x2 = x.unbind(dim=-1)
    return torch.stack((-x2, x1), dim=-1).flatten(start_dim=-2)


def apply_cached_rotary_emb(freqs: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    return (t * freqs[0]) + (rotate_half(t) * freqs[1])


class LearnableFourierPositionalEncoding(nn.Module):
    def __init__(self, M: int, dim: int, F_dim: int = None, gamma: float = 1.0) -> None:
        super().__init__()
        F_dim = F_dim if F_dim is not None else dim
        self.gamma = gamma
        self.Wr = nn.Linear(M, F_dim // 2, bias=False)
        nn.init.normal_(self.Wr.weight.data, mean=0, std=self.gamma**-2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """encode position vector"""
        projected = self.Wr(x)
        cosines, sines = torch.cos(projected), torch.sin(projected)
        emb = torch.stack([cosines, sines], 0).unsqueeze(-3)
        return emb.repeat_interleave(2, dim=-1)


class TokenConfidence(nn.Module):
    def __init__(self, dim: int) -> None:
        super().__init__()
        self.token = nn.Sequential(nn.Linear(dim, 1), nn.Sigmoid())

    def forward(self, desc0: torch.Tensor, desc1: torch.Tensor):
        """get confidence tokens"""
        return (
            self.token(desc0.detach()).squeeze(-1),
            self.token(desc1.detach()).squeeze(-1),
        )


class Attention(nn.Module):
    def __init__(self, allow_flash: bool) -> None:
        super().__init__()
        if allow_flash and not FLASH_AVAILABLE:
            warnings.warn(
                "FlashAttention is not available. For optimal speed, "
                "consider installing torch >= 2.0 or flash-attn.",
                stacklevel=2,
            )
        self.enable_flash = allow_flash and FLASH_AVAILABLE
        self.has_sdp = hasattr(F, "scaled_dot_product_attention")
        if allow_flash and FlashCrossAttention:
            self.flash_ = FlashCrossAttention()
        if self.has_sdp:
            torch.backends.cuda.enable_flash_sdp(allow_flash)

    def forward(self, q, k, v, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        if q.shape[-2] == 0 or k.shape[-2] == 0:
            return q.new_zeros((*q.shape[:-1], v.shape[-1]))
        if self.enable_flash and q.device.type == "cuda":
            # use torch 2.0 scaled_dot_product_attention with flash
            if self.has_sdp:
                args = [x.half().contiguous() for x in [q, k, v]]
                v = F.scaled_dot_product_attention(*args, attn_mask=mask).to(q.dtype)
                return v if mask is None else v.nan_to_num()
            else:
                assert mask is None
                q, k, v = [x.transpose(-2, -3).contiguous() for x in [q, k, v]]
                m = self.flash_(q.half(), torch.stack([k, v], 2).half())
                return m.transpose(-2, -3).to(q.dtype).clone()
        elif self.has_sdp:
            args = [x.contiguous() for x in [q, k, v]]
            v = F.scaled_dot_product_attention(*args, attn_mask=mask)
            return v if mask is None else v.nan_to_num()
        else:
            s = q.shape[-1] ** -0.5
            sim = torch.einsum("...id,...jd->...ij", q, k) * s
            if mask is not None:
                sim.masked_fill(~mask, -float("inf"))
            attn = F.softmax(sim, -1)
            return torch.einsum("...ij,...jd->...id", attn, v)


class SelfBlock(nn.Module):
    def __init__(
        self, embed_dim: int, num_heads: int, flash: bool = False, bias: bool = True
    ) -> None:
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        assert self.embed_dim % num_heads == 0
        self.head_dim = self.embed_dim // num_heads
        self.Wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias)
        self.inner_attn = Attention(flash)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.ffn = nn.Sequential(
            nn.Linear(2 * embed_dim, 2 * embed_dim),
            nn.LayerNorm(2 * embed_dim, elementwise_affine=True),
            nn.GELU(),
            nn.Linear(2 * embed_dim, embed_dim),
        )

    def forward(
        self,
        x: torch.Tensor,
        encoding: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        qkv = self.Wqkv(x)
        qkv = qkv.unflatten(-1, (self.num_heads, -1, 3)).transpose(1, 2)
        q, k, v = qkv[..., 0], qkv[..., 1], qkv[..., 2]
        q = apply_cached_rotary_emb(encoding, q)
        k = apply_cached_rotary_emb(encoding, k)
        context = self.inner_attn(q, k, v, mask=mask)
        message = self.out_proj(context.transpose(1, 2).flatten(start_dim=-2))
        return x + self.ffn(torch.cat([x, message], -1))


class CrossBlock(nn.Module):
    def __init__(
        self, embed_dim: int, num_heads: int, flash: bool = False, bias: bool = True
    ) -> None:
        super().__init__()
        self.heads = num_heads
        dim_head = embed_dim // num_heads
        self.scale = dim_head**-0.5
        inner_dim = dim_head * num_heads
        self.to_qk = nn.Linear(embed_dim, inner_dim, bias=bias)
        self.to_v = nn.Linear(embed_dim, inner_dim, bias=bias)
        self.to_out = nn.Linear(inner_dim, embed_dim, bias=bias)
        self.ffn = nn.Sequential(
            nn.Linear(2 * embed_dim, 2 * embed_dim),
            nn.LayerNorm(2 * embed_dim, elementwise_affine=True),
            nn.GELU(),
            nn.Linear(2 * embed_dim, embed_dim),
        )
        if flash and FLASH_AVAILABLE:
            self.flash = Attention(True)
        else:
            self.flash = None

    def map_(self, func: Callable, x0: torch.Tensor, x1: torch.Tensor):
        return func(x0), func(x1)

    def forward(
        self, x0: torch.Tensor, x1: torch.Tensor, mask: Optional[torch.Tensor] = None
    ) -> List[torch.Tensor]:
        qk0, qk1 = self.map_(self.to_qk, x0, x1)
        v0, v1 = self.map_(self.to_v, x0, x1)
        qk0, qk1, v0, v1 = map(
            lambda t: t.unflatten(-1, (self.heads, -1)).transpose(1, 2),
            (qk0, qk1, v0, v1),
        )
        if self.flash is not None and qk0.device.type == "cuda":
            m0 = self.flash(qk0, qk1, v1, mask)
            m1 = self.flash(
                qk1, qk0, v0, mask.transpose(-1, -2) if mask is not None else None
            )
        else:
            qk0, qk1 = qk0 * self.scale**0.5, qk1 * self.scale**0.5
            sim = torch.einsum("bhid, bhjd -> bhij", qk0, qk1)
            if mask is not None:
                sim = sim.masked_fill(~mask, -float("inf"))
            attn01 = F.softmax(sim, dim=-1)
            attn10 = F.softmax(sim.transpose(-2, -1).contiguous(), dim=-1)
            m0 = torch.einsum("bhij, bhjd -> bhid", attn01, v1)
            m1 = torch.einsum("bhji, bhjd -> bhid", attn10.transpose(-2, -1), v0)
            if mask is not None:
                m0, m1 = m0.nan_to_num(), m1.nan_to_num()
        m0, m1 = self.map_(lambda t: t.transpose(1, 2).flatten(start_dim=-2), m0, m1)
        m0, m1 = self.map_(self.to_out, m0, m1)
        x0 = x0 + self.ffn(torch.cat([x0, m0], -1))
        x1 = x1 + self.ffn(torch.cat([x1, m1], -1))
        return x0, x1


class TransformerLayer(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()
        self.self_attn = SelfBlock(*args, **kwargs)
        self.cross_attn = CrossBlock(*args, **kwargs)

    def forward(
        self,
        desc0,
        desc1,
        encoding0,
        encoding1,
        mask0: Optional[torch.Tensor] = None,
        mask1: Optional[torch.Tensor] = None,
    ):
        if mask0 is not None and mask1 is not None:
            return self.masked_forward(desc0, desc1, encoding0, encoding1, mask0, mask1)
        else:
            desc0 = self.self_attn(desc0, encoding0)
            desc1 = self.self_attn(desc1, encoding1)
            return self.cross_attn(desc0, desc1)

    # This part is compiled and allows padding inputs
    def masked_forward(self, desc0, desc1, encoding0, encoding1, mask0, mask1):
        mask = mask0 & mask1.transpose(-1, -2)
|
256 |
+
mask0 = mask0 & mask0.transpose(-1, -2)
|
257 |
+
mask1 = mask1 & mask1.transpose(-1, -2)
|
258 |
+
desc0 = self.self_attn(desc0, encoding0, mask0)
|
259 |
+
desc1 = self.self_attn(desc1, encoding1, mask1)
|
260 |
+
return self.cross_attn(desc0, desc1, mask)
|
261 |
+
|
262 |
+
|
263 |
+
def sigmoid_log_double_softmax(
|
264 |
+
sim: torch.Tensor, z0: torch.Tensor, z1: torch.Tensor
|
265 |
+
) -> torch.Tensor:
|
266 |
+
"""create the log assignment matrix from logits and similarity"""
|
267 |
+
b, m, n = sim.shape
|
268 |
+
certainties = F.logsigmoid(z0) + F.logsigmoid(z1).transpose(1, 2)
|
269 |
+
scores0 = F.log_softmax(sim, 2)
|
270 |
+
scores1 = F.log_softmax(sim.transpose(-1, -2).contiguous(), 2).transpose(-1, -2)
|
271 |
+
scores = sim.new_full((b, m + 1, n + 1), 0)
|
272 |
+
scores[:, :m, :n] = scores0 + scores1 + certainties
|
273 |
+
scores[:, :-1, -1] = F.logsigmoid(-z0.squeeze(-1))
|
274 |
+
scores[:, -1, :-1] = F.logsigmoid(-z1.squeeze(-1))
|
275 |
+
return scores
|
276 |
+
|
277 |
+
|
278 |
+
class MatchAssignment(nn.Module):
|
279 |
+
def __init__(self, dim: int) -> None:
|
280 |
+
super().__init__()
|
281 |
+
self.dim = dim
|
282 |
+
self.matchability = nn.Linear(dim, 1, bias=True)
|
283 |
+
self.final_proj = nn.Linear(dim, dim, bias=True)
|
284 |
+
|
285 |
+
def forward(self, desc0: torch.Tensor, desc1: torch.Tensor):
|
286 |
+
"""build assignment matrix from descriptors"""
|
287 |
+
mdesc0, mdesc1 = self.final_proj(desc0), self.final_proj(desc1)
|
288 |
+
_, _, d = mdesc0.shape
|
289 |
+
mdesc0, mdesc1 = mdesc0 / d**0.25, mdesc1 / d**0.25
|
290 |
+
sim = torch.einsum("bmd,bnd->bmn", mdesc0, mdesc1)
|
291 |
+
z0 = self.matchability(desc0)
|
292 |
+
z1 = self.matchability(desc1)
|
293 |
+
scores = sigmoid_log_double_softmax(sim, z0, z1)
|
294 |
+
return scores, sim
|
295 |
+
|
296 |
+
def get_matchability(self, desc: torch.Tensor):
|
297 |
+
return torch.sigmoid(self.matchability(desc)).squeeze(-1)
|
298 |
+
|
299 |
+
|
300 |
+
def filter_matches(scores: torch.Tensor, th: float):
|
301 |
+
"""obtain matches from a log assignment matrix [Bx M+1 x N+1]"""
|
302 |
+
max0, max1 = scores[:, :-1, :-1].max(2), scores[:, :-1, :-1].max(1)
|
303 |
+
m0, m1 = max0.indices, max1.indices
|
304 |
+
indices0 = torch.arange(m0.shape[1], device=m0.device)[None]
|
305 |
+
indices1 = torch.arange(m1.shape[1], device=m1.device)[None]
|
306 |
+
mutual0 = indices0 == m1.gather(1, m0)
|
307 |
+
mutual1 = indices1 == m0.gather(1, m1)
|
308 |
+
max0_exp = max0.values.exp()
|
309 |
+
zero = max0_exp.new_tensor(0)
|
310 |
+
mscores0 = torch.where(mutual0, max0_exp, zero)
|
311 |
+
mscores1 = torch.where(mutual1, mscores0.gather(1, m1), zero)
|
312 |
+
valid0 = mutual0 & (mscores0 > th)
|
313 |
+
valid1 = mutual1 & valid0.gather(1, m1)
|
314 |
+
m0 = torch.where(valid0, m0, -1)
|
315 |
+
m1 = torch.where(valid1, m1, -1)
|
316 |
+
return m0, m1, mscores0, mscores1
|
317 |
+
|
318 |
+
|
319 |
+
class LightGlue(nn.Module):
|
320 |
+
default_conf = {
|
321 |
+
"name": "lightglue", # just for interfacing
|
322 |
+
"input_dim": 256, # input descriptor dimension (autoselected from weights)
|
323 |
+
"descriptor_dim": 256,
|
324 |
+
"add_scale_ori": False,
|
325 |
+
"n_layers": 9,
|
326 |
+
"num_heads": 4,
|
327 |
+
"flash": True, # enable FlashAttention if available.
|
328 |
+
"mp": False, # enable mixed precision
|
329 |
+
"depth_confidence": -1, # early stopping, disable with -1
|
330 |
+
"width_confidence": -1, # point pruning, disable with -1
|
331 |
+
"filter_threshold": 0.01, # match threshold
|
332 |
+
"weights": None,
|
333 |
+
}
|
334 |
+
|
335 |
+
# Point pruning involves an overhead (gather).
|
336 |
+
# Therefore, we only activate it if there are enough keypoints.
|
337 |
+
pruning_keypoint_thresholds = {
|
338 |
+
"cpu": -1,
|
339 |
+
"mps": -1,
|
340 |
+
"cuda": 1024,
|
341 |
+
"flash": 1536,
|
342 |
+
}
|
343 |
+
|
344 |
+
required_data_keys = ["image0", "image1"]
|
345 |
+
|
346 |
+
version = "v0.1_arxiv"
|
347 |
+
url = "https://github.com/cvg/LightGlue/releases/download/{}/{}_lightglue.pth"
|
348 |
+
|
349 |
+
features = {
|
350 |
+
"superpoint": {
|
351 |
+
"weights": "superpoint_lightglue",
|
352 |
+
"input_dim": 256,
|
353 |
+
},
|
354 |
+
"disk": {
|
355 |
+
"weights": "disk_lightglue",
|
356 |
+
"input_dim": 128,
|
357 |
+
},
|
358 |
+
"aliked": {
|
359 |
+
"weights": "aliked_lightglue",
|
360 |
+
"input_dim": 128,
|
361 |
+
},
|
362 |
+
"sift": {
|
363 |
+
"weights": "sift_lightglue",
|
364 |
+
"input_dim": 128,
|
365 |
+
"add_scale_ori": True,
|
366 |
+
},
|
367 |
+
"doghardnet": {
|
368 |
+
"weights": "doghardnet_lightglue",
|
369 |
+
"input_dim": 128,
|
370 |
+
"add_scale_ori": True,
|
371 |
+
},
|
372 |
+
"rdd": {
|
373 |
+
"weights": './weights/RDD_lg-v2.pth',
|
374 |
+
"input_dim": 256,
|
375 |
+
},
|
376 |
+
}
|
377 |
+
|
378 |
+
def __init__(self, features="rdd", **conf) -> None:
|
379 |
+
super().__init__()
|
380 |
+
self.conf = conf = SimpleNamespace(**{**self.default_conf, **conf})
|
381 |
+
if features is not None:
|
382 |
+
if features not in self.features:
|
383 |
+
raise ValueError(
|
384 |
+
f"Unsupported features: {features} not in "
|
385 |
+
f"{{{','.join(self.features)}}}"
|
386 |
+
)
|
387 |
+
for k, v in self.features[features].items():
|
388 |
+
setattr(conf, k, v)
|
389 |
+
|
390 |
+
if conf.input_dim != conf.descriptor_dim:
|
391 |
+
self.input_proj = nn.Linear(conf.input_dim, conf.descriptor_dim, bias=True)
|
392 |
+
else:
|
393 |
+
self.input_proj = nn.Identity()
|
394 |
+
|
395 |
+
head_dim = conf.descriptor_dim // conf.num_heads
|
396 |
+
self.posenc = LearnableFourierPositionalEncoding(
|
397 |
+
2 + 2 * self.conf.add_scale_ori, head_dim, head_dim
|
398 |
+
)
|
399 |
+
|
400 |
+
h, n, d = conf.num_heads, conf.n_layers, conf.descriptor_dim
|
401 |
+
|
402 |
+
self.transformers = nn.ModuleList(
|
403 |
+
[TransformerLayer(d, h, conf.flash) for _ in range(n)]
|
404 |
+
)
|
405 |
+
|
406 |
+
self.log_assignment = nn.ModuleList([MatchAssignment(d) for _ in range(n)])
|
407 |
+
self.token_confidence = nn.ModuleList(
|
408 |
+
[TokenConfidence(d) for _ in range(n - 1)]
|
409 |
+
)
|
410 |
+
self.register_buffer(
|
411 |
+
"confidence_thresholds",
|
412 |
+
torch.Tensor(
|
413 |
+
[self.confidence_threshold(i) for i in range(self.conf.n_layers)]
|
414 |
+
),
|
415 |
+
)
|
416 |
+
|
417 |
+
state_dict = None
|
418 |
+
if features is not None and features != 'rdd':
|
419 |
+
fname = f"{conf.weights}_{self.version.replace('.', '-')}.pth"
|
420 |
+
state_dict = torch.hub.load_state_dict_from_url(
|
421 |
+
self.url.format(self.version, features), file_name=fname
|
422 |
+
)
|
423 |
+
self.load_state_dict(state_dict, strict=False)
|
424 |
+
elif conf.weights is not None:
|
425 |
+
if features == 'rdd':
|
426 |
+
path = Path(conf.weights)
|
427 |
+
else:
|
428 |
+
path = Path(__file__).parent
|
429 |
+
path = path / "weights/{}.pth".format(self.conf.weights)
|
430 |
+
state_dict = torch.load(str(path), map_location="cpu")
|
431 |
+
|
432 |
+
if state_dict:
|
433 |
+
# rename old state dict entries
|
434 |
+
for i in range(self.conf.n_layers):
|
435 |
+
pattern = f"self_attn.{i}", f"transformers.{i}.self_attn"
|
436 |
+
state_dict = {k.replace(*pattern): v for k, v in state_dict.items()}
|
437 |
+
pattern = f"cross_attn.{i}", f"transformers.{i}.cross_attn"
|
438 |
+
state_dict = {k.replace(*pattern): v for k, v in state_dict.items()}
|
439 |
+
self.load_state_dict(state_dict, strict=False)
|
440 |
+
|
441 |
+
# static lengths LightGlue is compiled for (only used with torch.compile)
|
442 |
+
self.static_lengths = None
|
443 |
+
|
444 |
+
def compile(
|
445 |
+
self, mode="reduce-overhead", static_lengths=[256, 512, 768, 1024, 1280, 1536]
|
446 |
+
):
|
447 |
+
if self.conf.width_confidence != -1:
|
448 |
+
warnings.warn(
|
449 |
+
"Point pruning is partially disabled for compiled forward.",
|
450 |
+
stacklevel=2,
|
451 |
+
)
|
452 |
+
|
453 |
+
torch._inductor.cudagraph_mark_step_begin()
|
454 |
+
for i in range(self.conf.n_layers):
|
455 |
+
self.transformers[i].masked_forward = torch.compile(
|
456 |
+
self.transformers[i].masked_forward, mode=mode, fullgraph=True
|
457 |
+
)
|
458 |
+
|
459 |
+
self.static_lengths = static_lengths
|
460 |
+
|
461 |
+
def forward(self, data: dict) -> dict:
|
462 |
+
"""
|
463 |
+
Match keypoints and descriptors between two images
|
464 |
+
|
465 |
+
Input (dict):
|
466 |
+
image0: dict
|
467 |
+
keypoints: [B x M x 2]
|
468 |
+
descriptors: [B x M x D]
|
469 |
+
image: [B x C x H x W] or image_size: [B x 2]
|
470 |
+
image1: dict
|
471 |
+
keypoints: [B x N x 2]
|
472 |
+
descriptors: [B x N x D]
|
473 |
+
image: [B x C x H x W] or image_size: [B x 2]
|
474 |
+
Output (dict):
|
475 |
+
matches0: [B x M]
|
476 |
+
matching_scores0: [B x M]
|
477 |
+
matches1: [B x N]
|
478 |
+
matching_scores1: [B x N]
|
479 |
+
matches: List[[Si x 2]]
|
480 |
+
scores: List[[Si]]
|
481 |
+
stop: int
|
482 |
+
prune0: [B x M]
|
483 |
+
prune1: [B x N]
|
484 |
+
"""
|
485 |
+
with torch.autocast(enabled=self.conf.mp, device_type="cuda"):
|
486 |
+
return self._forward(data)
|
487 |
+
|
488 |
+
def _forward(self, data: dict) -> dict:
|
489 |
+
for key in self.required_data_keys:
|
490 |
+
assert key in data, f"Missing key {key} in data"
|
491 |
+
data0, data1 = data["image0"], data["image1"]
|
492 |
+
kpts0, kpts1 = data0["keypoints"], data1["keypoints"]
|
493 |
+
b, m, _ = kpts0.shape
|
494 |
+
b, n, _ = kpts1.shape
|
495 |
+
device = kpts0.device
|
496 |
+
size0, size1 = data0.get("image_size"), data1.get("image_size")
|
497 |
+
kpts0 = normalize_keypoints(kpts0, size0).clone()
|
498 |
+
kpts1 = normalize_keypoints(kpts1, size1).clone()
|
499 |
+
|
500 |
+
if self.conf.add_scale_ori:
|
501 |
+
kpts0 = torch.cat(
|
502 |
+
[kpts0] + [data0[k].unsqueeze(-1) for k in ("scales", "oris")], -1
|
503 |
+
)
|
504 |
+
kpts1 = torch.cat(
|
505 |
+
[kpts1] + [data1[k].unsqueeze(-1) for k in ("scales", "oris")], -1
|
506 |
+
)
|
507 |
+
desc0 = data0["descriptors"].detach().contiguous()
|
508 |
+
desc1 = data1["descriptors"].detach().contiguous()
|
509 |
+
|
510 |
+
assert desc0.shape[-1] == self.conf.input_dim
|
511 |
+
assert desc1.shape[-1] == self.conf.input_dim
|
512 |
+
|
513 |
+
if torch.is_autocast_enabled():
|
514 |
+
desc0 = desc0.half()
|
515 |
+
desc1 = desc1.half()
|
516 |
+
|
517 |
+
mask0, mask1 = None, None
|
518 |
+
c = max(m, n)
|
519 |
+
do_compile = self.static_lengths and c <= max(self.static_lengths)
|
520 |
+
if do_compile:
|
521 |
+
kn = min([k for k in self.static_lengths if k >= c])
|
522 |
+
desc0, mask0 = pad_to_length(desc0, kn)
|
523 |
+
desc1, mask1 = pad_to_length(desc1, kn)
|
524 |
+
kpts0, _ = pad_to_length(kpts0, kn)
|
525 |
+
kpts1, _ = pad_to_length(kpts1, kn)
|
526 |
+
desc0 = self.input_proj(desc0)
|
527 |
+
desc1 = self.input_proj(desc1)
|
528 |
+
# cache positional embeddings
|
529 |
+
encoding0 = self.posenc(kpts0)
|
530 |
+
encoding1 = self.posenc(kpts1)
|
531 |
+
|
532 |
+
# GNN + final_proj + assignment
|
533 |
+
do_early_stop = self.conf.depth_confidence > 0
|
534 |
+
do_point_pruning = self.conf.width_confidence > 0 and not do_compile
|
535 |
+
pruning_th = self.pruning_min_kpts(device)
|
536 |
+
if do_point_pruning:
|
537 |
+
ind0 = torch.arange(0, m, device=device)[None]
|
538 |
+
ind1 = torch.arange(0, n, device=device)[None]
|
539 |
+
# We store the index of the layer at which pruning is detected.
|
540 |
+
prune0 = torch.ones_like(ind0)
|
541 |
+
prune1 = torch.ones_like(ind1)
|
542 |
+
token0, token1 = None, None
|
543 |
+
for i in range(self.conf.n_layers):
|
544 |
+
if desc0.shape[1] == 0 or desc1.shape[1] == 0: # no keypoints
|
545 |
+
break
|
546 |
+
desc0, desc1 = self.transformers[i](
|
547 |
+
desc0, desc1, encoding0, encoding1, mask0=mask0, mask1=mask1
|
548 |
+
)
|
549 |
+
if i == self.conf.n_layers - 1:
|
550 |
+
continue # no early stopping or adaptive width at last layer
|
551 |
+
|
552 |
+
if do_early_stop:
|
553 |
+
token0, token1 = self.token_confidence[i](desc0, desc1)
|
554 |
+
if self.check_if_stop(token0[..., :m], token1[..., :n], i, m + n):
|
555 |
+
break
|
556 |
+
if do_point_pruning and desc0.shape[-2] > pruning_th:
|
557 |
+
scores0 = self.log_assignment[i].get_matchability(desc0)
|
558 |
+
prunemask0 = self.get_pruning_mask(token0, scores0, i)
|
559 |
+
keep0 = torch.where(prunemask0)[1]
|
560 |
+
ind0 = ind0.index_select(1, keep0)
|
561 |
+
desc0 = desc0.index_select(1, keep0)
|
562 |
+
encoding0 = encoding0.index_select(-2, keep0)
|
563 |
+
prune0[:, ind0] += 1
|
564 |
+
if do_point_pruning and desc1.shape[-2] > pruning_th:
|
565 |
+
scores1 = self.log_assignment[i].get_matchability(desc1)
|
566 |
+
prunemask1 = self.get_pruning_mask(token1, scores1, i)
|
567 |
+
keep1 = torch.where(prunemask1)[1]
|
568 |
+
ind1 = ind1.index_select(1, keep1)
|
569 |
+
desc1 = desc1.index_select(1, keep1)
|
570 |
+
encoding1 = encoding1.index_select(-2, keep1)
|
571 |
+
prune1[:, ind1] += 1
|
572 |
+
|
573 |
+
if desc0.shape[1] == 0 or desc1.shape[1] == 0: # no keypoints
|
574 |
+
m0 = desc0.new_full((b, m), -1, dtype=torch.long)
|
575 |
+
m1 = desc1.new_full((b, n), -1, dtype=torch.long)
|
576 |
+
mscores0 = desc0.new_zeros((b, m))
|
577 |
+
mscores1 = desc1.new_zeros((b, n))
|
578 |
+
matches = desc0.new_empty((b, 0, 2), dtype=torch.long)
|
579 |
+
mscores = desc0.new_empty((b, 0))
|
580 |
+
if not do_point_pruning:
|
581 |
+
prune0 = torch.ones_like(mscores0) * self.conf.n_layers
|
582 |
+
prune1 = torch.ones_like(mscores1) * self.conf.n_layers
|
583 |
+
return {
|
584 |
+
"matches0": m0,
|
585 |
+
"matches1": m1,
|
586 |
+
"matching_scores0": mscores0,
|
587 |
+
"matching_scores1": mscores1,
|
588 |
+
"stop": i + 1,
|
589 |
+
"matches": matches,
|
590 |
+
"scores": mscores,
|
591 |
+
"prune0": prune0,
|
592 |
+
"prune1": prune1,
|
593 |
+
}
|
594 |
+
|
595 |
+
desc0, desc1 = desc0[..., :m, :], desc1[..., :n, :] # remove padding
|
596 |
+
scores, _ = self.log_assignment[i](desc0, desc1)
|
597 |
+
m0, m1, mscores0, mscores1 = filter_matches(scores, self.conf.filter_threshold)
|
598 |
+
matches, mscores = [], []
|
599 |
+
for k in range(b):
|
600 |
+
valid = m0[k] > -1
|
601 |
+
m_indices_0 = torch.where(valid)[0]
|
602 |
+
m_indices_1 = m0[k][valid]
|
603 |
+
if do_point_pruning:
|
604 |
+
m_indices_0 = ind0[k, m_indices_0]
|
605 |
+
m_indices_1 = ind1[k, m_indices_1]
|
606 |
+
matches.append(torch.stack([m_indices_0, m_indices_1], -1))
|
607 |
+
mscores.append(mscores0[k][valid])
|
608 |
+
|
609 |
+
# TODO: Remove when hloc switches to the compact format.
|
610 |
+
if do_point_pruning:
|
611 |
+
m0_ = torch.full((b, m), -1, device=m0.device, dtype=m0.dtype)
|
612 |
+
m1_ = torch.full((b, n), -1, device=m1.device, dtype=m1.dtype)
|
613 |
+
m0_[:, ind0] = torch.where(m0 == -1, -1, ind1.gather(1, m0.clamp(min=0)))
|
614 |
+
m1_[:, ind1] = torch.where(m1 == -1, -1, ind0.gather(1, m1.clamp(min=0)))
|
615 |
+
mscores0_ = torch.zeros((b, m), device=mscores0.device)
|
616 |
+
mscores1_ = torch.zeros((b, n), device=mscores1.device)
|
617 |
+
mscores0_[:, ind0] = mscores0
|
618 |
+
mscores1_[:, ind1] = mscores1
|
619 |
+
m0, m1, mscores0, mscores1 = m0_, m1_, mscores0_, mscores1_
|
620 |
+
else:
|
621 |
+
prune0 = torch.ones_like(mscores0) * self.conf.n_layers
|
622 |
+
prune1 = torch.ones_like(mscores1) * self.conf.n_layers
|
623 |
+
|
624 |
+
return {
|
625 |
+
"matches0": m0,
|
626 |
+
"matches1": m1,
|
627 |
+
"matching_scores0": mscores0,
|
628 |
+
"matching_scores1": mscores1,
|
629 |
+
"stop": i + 1,
|
630 |
+
"matches": matches,
|
631 |
+
"scores": mscores,
|
632 |
+
"prune0": prune0,
|
633 |
+
"prune1": prune1,
|
634 |
+
}
|
635 |
+
|
636 |
+
def confidence_threshold(self, layer_index: int) -> float:
|
637 |
+
"""scaled confidence threshold"""
|
638 |
+
threshold = 0.8 + 0.1 * np.exp(-4.0 * layer_index / self.conf.n_layers)
|
639 |
+
return np.clip(threshold, 0, 1)
|
640 |
+
|
641 |
+
def get_pruning_mask(
|
642 |
+
self, confidences: torch.Tensor, scores: torch.Tensor, layer_index: int
|
643 |
+
) -> torch.Tensor:
|
644 |
+
"""mask points which should be removed"""
|
645 |
+
keep = scores > (1 - self.conf.width_confidence)
|
646 |
+
if confidences is not None: # Low-confidence points are never pruned.
|
647 |
+
keep |= confidences <= self.confidence_thresholds[layer_index]
|
648 |
+
return keep
|
649 |
+
|
650 |
+
def check_if_stop(
|
651 |
+
self,
|
652 |
+
confidences0: torch.Tensor,
|
653 |
+
confidences1: torch.Tensor,
|
654 |
+
layer_index: int,
|
655 |
+
num_points: int,
|
656 |
+
) -> torch.Tensor:
|
657 |
+
"""evaluate stopping condition"""
|
658 |
+
confidences = torch.cat([confidences0, confidences1], -1)
|
659 |
+
threshold = self.confidence_thresholds[layer_index]
|
660 |
+
ratio_confident = 1.0 - (confidences < threshold).float().sum() / num_points
|
661 |
+
return ratio_confident > self.conf.depth_confidence
|
662 |
+
|
663 |
+
def pruning_min_kpts(self, device: torch.device):
|
664 |
+
if self.conf.flash and FLASH_AVAILABLE and device.type == "cuda":
|
665 |
+
return self.pruning_keypoint_thresholds["flash"]
|
666 |
+
else:
|
667 |
+
return self.pruning_keypoint_thresholds[device.type]
|
imcui/third_party/rdd/RDD/models/backbone.py
ADDED
@@ -0,0 +1,147 @@
# Modified from Deformable DETR
# https://github.com/fundamentalvision/Deformable-DETR
# ------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------

"""
Backbone modules.
"""
from collections import OrderedDict

import torch
import torch.nn.functional as F
import torchvision
from torch import nn
from torchvision.models._utils import IntermediateLayerGetter
from typing import Dict, List
import torch.distributed as dist
from .position_encoding import build_position_encoding

from ..utils.misc import NestedTensor, is_main_process

class FrozenBatchNorm2d(torch.nn.Module):
    """
    BatchNorm2d where the batch statistics and the affine parameters are fixed.

    Copy-paste from torchvision.misc.ops with added eps before rqsrt,
    without which any other models than torchvision.models.resnet[18,34,50,101]
    produce nans.
    """

    def __init__(self, n, eps=1e-5):
        super(FrozenBatchNorm2d, self).__init__()
        self.register_buffer("weight", torch.ones(n))
        self.register_buffer("bias", torch.zeros(n))
        self.register_buffer("running_mean", torch.zeros(n))
        self.register_buffer("running_var", torch.ones(n))
        self.eps = eps

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        num_batches_tracked_key = prefix + 'num_batches_tracked'
        if num_batches_tracked_key in state_dict:
            del state_dict[num_batches_tracked_key]

        super(FrozenBatchNorm2d, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs)

    def forward(self, x):
        # move reshapes to the beginning
        # to make it fuser-friendly
        w = self.weight.reshape(1, -1, 1, 1)
        b = self.bias.reshape(1, -1, 1, 1)
        rv = self.running_var.reshape(1, -1, 1, 1)
        rm = self.running_mean.reshape(1, -1, 1, 1)
        eps = self.eps
        scale = w * (rv + eps).rsqrt()
        bias = b - rm * scale
        return x * scale + bias


class BackboneBase(nn.Module):

    def __init__(self, backbone: nn.Module, train_backbone: bool, return_interm_layers: bool, n_layers = 4):
        super().__init__()
        for name, parameter in backbone.named_parameters():
            if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
                parameter.requires_grad_(False)
        if return_interm_layers:
            if n_layers == 4:
                return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
                self.strides = [4, 8, 16, 32]
                self.num_channels = [256, 512, 1024, 2048]
            elif n_layers == 3:
                return_layers = {"layer2": "0", "layer3": "1", "layer4": "2"}
                self.strides = [8, 16, 32]
                self.num_channels = [512, 1024, 2048]
            else:
                raise ValueError("n_layers should be 3 or 4")

        else:
            return_layers = {'layer4': "0"}
            self.strides = [32]
            self.num_channels = [2048]
        self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)

    def forward(self, tensor_list: NestedTensor):
        xs = self.body(tensor_list.tensors)
        out: Dict[str, NestedTensor] = {}
        for name, x in xs.items():
            m = tensor_list.mask
            assert m is not None
            mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
            out[name] = NestedTensor(x, mask)
        return out


class Backbone(BackboneBase):
    """ResNet backbone with frozen BatchNorm."""
    def __init__(self, name: str,
                 train_backbone: bool,
                 return_interm_layers: bool,
                 dilation: bool,
                 n_layers = 4):
        norm_layer = FrozenBatchNorm2d
        backbone = getattr(torchvision.models, name)(
            replace_stride_with_dilation=[False, False, dilation],
            weights='ResNet50_Weights.IMAGENET1K_V1', norm_layer=norm_layer)
        assert name not in ('resnet18', 'resnet34'), "number of channels are hard coded"
        super().__init__(backbone, train_backbone, return_interm_layers, n_layers)
        if dilation:
            self.strides[-1] = self.strides[-1] // 2

class Joiner(nn.Sequential):
    def __init__(self, backbone, position_embedding):
        super().__init__(backbone, position_embedding)
        self.strides = backbone.strides
        self.num_channels = backbone.num_channels

    def forward(self, tensor_list: NestedTensor):
        xs = self[0](tensor_list)
        out: List[NestedTensor] = []
        pos = []
        for name, x in sorted(xs.items()):
            out.append(x)

        # position encoding
        for x in out:
            pos.append(self[1](x).to(x.tensors.dtype))

        return out, pos


def build_backbone(config):
    position_embedding = build_position_encoding(config)
    train_backbone = config['lr_backbone'] > 0
    return_interm_layers = True
    n_layers = config['num_feature_levels'] - 1
    backbone = Backbone('resnet50', train_backbone, return_interm_layers, False, n_layers=n_layers)
    model = Joiner(backbone, position_embedding)
    return model
imcui/third_party/rdd/RDD/models/deformable_transformer.py
ADDED
@@ -0,0 +1,270 @@
# Modified from Deformable DETR
# https://github.com/fundamentalvision/Deformable-DETR
# ------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------

import copy
from typing import Optional, List
import math

import torch
import torch.nn.functional as F
from torch import nn, Tensor
from torch.nn.init import xavier_uniform_, constant_, uniform_, normal_

from ..utils.misc import inverse_sigmoid
from .ops.modules import MSDeformAttn

class MLP(nn.Module):
    """ Very simple multi-layer perceptron (also called FFN)"""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x

def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])

class DeformableTransformerEncoderLayer(nn.Module):
    def __init__(self,
                 d_model=256, d_ffn=1024,
                 dropout=0.1, activation="relu",
                 n_levels=4, n_heads=8, n_points=4):
        super().__init__()

        # self attention
        self.self_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model)

        # ffn
        self.linear1 = nn.Linear(d_model, d_ffn)
        self.activation = _get_activation_fn(activation)
        self.dropout2 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ffn, d_model)
        self.dropout3 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(d_model)

    @staticmethod
    def with_pos_embed(tensor, pos):
        return tensor if pos is None else tensor + pos

    def forward_ffn(self, src):
        src2 = self.linear2(self.dropout2(self.activation(self.linear1(src))))
        src = src + self.dropout3(src2)
        src = self.norm2(src)
        return src

    def forward(self, src, pos, reference_points, spatial_shapes, level_start_index, padding_mask=None, debug=False):
        # self attention
        if debug:
            src2, sampled_points = self.self_attn(self.with_pos_embed(src, pos), reference_points, src, spatial_shapes, level_start_index, padding_mask)
        else:
            src2 = self.self_attn(self.with_pos_embed(src, pos), reference_points, src, spatial_shapes, level_start_index, padding_mask)
        src = src + self.dropout1(src2)
        src = self.norm1(src)

        # ffn
        src = self.forward_ffn(src)
        if debug:
            return src, sampled_points
        return src

class DeformableTransformerEncoder(nn.Module):
    def __init__(self, encoder_layer, num_layers):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers

    @staticmethod
    def get_reference_points(spatial_shapes, valid_ratios, device):
        reference_points_list = []
        for lvl, (H_, W_) in enumerate(spatial_shapes):

            ref_y, ref_x = torch.meshgrid(torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
                                          torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device))
            ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_)
            ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_)
            ref = torch.stack((ref_x, ref_y), -1)
            reference_points_list.append(ref)
        reference_points = torch.cat(reference_points_list, 1)
        reference_points = reference_points[:, :, None] * valid_ratios[:, None]
        return reference_points

    def forward(self, src, spatial_shapes, level_start_index, valid_ratios, pos=None, padding_mask=None, debug=False):
        output = src
        reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=src.device)
        for _, layer in enumerate(self.layers):
            if debug:
                output, sampled_points = layer(output, pos, reference_points, spatial_shapes, level_start_index, padding_mask, debug=debug)
            else:
                output = layer(output, pos, reference_points, spatial_shapes, level_start_index, padding_mask)
        if debug:
            return output, reference_points, sampled_points
        return output

class DecoderLayer(nn.Module):
    def __init__(self, d_model=256, n_head=8, dropout=0.1):
        super().__init__()
        self.nhead = n_head
        self.dim = d_model // n_head
        self.attention = LinearAttention()
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)

        self.mlp = nn.Sequential(
            nn.Linear(d_model*2, d_model*2, bias=False),
            nn.ReLU(True),
            nn.Linear(d_model*2, d_model, bias=False),
        )

        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, tgt, src, tgt_mask=None, src_mask=None):

        bs = tgt.size(0)
        query, key, value = tgt, src, src

        query = self.q_proj(query).view(bs, -1, self.nhead, self.dim)  # [N, L, (H, D)]
        key = self.k_proj(key).view(bs, -1, self.nhead, self.dim)  # [N, S, (H, D)]
        value = self.v_proj(value).view(bs, -1, self.nhead, self.dim)

        tgt2 = self.attention(query, key, value, q_mask=tgt_mask, kv_mask=src_mask)
        tgt2 = tgt2.view(bs, -1, self.nhead*self.dim)
        tgt2 = self.norm1(self.dropout1(tgt2))
        tgt2 = self.mlp(torch.cat([tgt, tgt2], dim=2))

        tgt2 = self.norm2(tgt2)

        return tgt + tgt2

class MLP(nn.Module):
    """ Very simple multi-layer perceptron (also called FFN)"""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x


class Decoder(nn.Module):
    def __init__(self, decoder_layer, num_layers):
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers

    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None):
        for layer in self.layers:
            tgt = layer(tgt, memory, tgt_mask=tgt_mask, src_mask=memory_mask)

        return tgt

import math

class DeformableTransformer(nn.Module):
    def __init__(self, d_model=256, nhead=8,
                 num_encoder_layers=4, dim_feedforward=1024, dropout=0.1,
                 activation="relu",
                 num_feature_levels=5, enc_n_points=8):
        super().__init__()

        self.d_model = d_model
        self.nhead = nhead

        # Encoder and Decoder
        encoder_layer = DeformableTransformerEncoderLayer(d_model, dim_feedforward,
                                                          dropout, activation,
                                                          num_feature_levels, nhead, enc_n_points)
        self.encoder = DeformableTransformerEncoder(encoder_layer, num_encoder_layers)

        # Embedding for feature levels (multi-scale)
        self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model))

        self._reset_parameters()

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        normal_(self.level_embed)

    def get_valid_ratio(self, mask):
        _, H, W = mask.shape
        valid_H = torch.sum(~mask[:, :, 0], 1)
        valid_W = torch.sum(~mask[:, 0, :], 1)
        valid_ratio_h = valid_H.float() / H
        valid_ratio_w = valid_W.float() / W
        valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
        return valid_ratio

    def forward(self, srcs, masks, pos_embeds):

        # Prepare inputs for encoder
        src_flatten = []
        mask_flatten = []
        lvl_pos_embed_flatten = []
        spatial_shapes = []
        for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)):
            bs, c, h, w = src.shape
            spatial_shape = (h, w)
            spatial_shapes.append(spatial_shape)
            src = src.flatten(2).transpose(1, 2)
            mask = mask.flatten(1)
            pos_embed = pos_embed.flatten(2).transpose(1, 2).to(src.device)
            lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1)
            lvl_pos_embed_flatten.append(lvl_pos_embed)
            src_flatten.append(src)
            mask_flatten.append(mask)
        src_flatten = torch.cat(src_flatten, 1)
        mask_flatten = torch.cat(mask_flatten, 1)
        lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
        spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device)
        level_start_index = torch.cat((spatial_shapes.new_zeros((1, )), spatial_shapes.prod(1).cumsum(0)[:-1]))
        valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)

        memory = self.encoder(src_flatten, spatial_shapes, level_start_index, valid_ratios, lvl_pos_embed_flatten, mask_flatten, debug=False)

        return memory, spatial_shapes, level_start_index

def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])


def _get_activation_fn(activation):
    """Return an activation function given a string"""
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    raise RuntimeError(F"activation should be relu/gelu, not {activation}.")


def build_deforamble_transformer(config):
    return DeformableTransformer(d_model=config['d_model'], nhead=config['nhead'],
                                 num_encoder_layers=config['num_encoder_layers'], dim_feedforward=config['dim_feedforward'], dropout=config['dropout'],
                                 activation=config['activation'],
                                 num_feature_levels=config['num_feature_levels'], enc_n_points=config['enc_n_points'])
imcui/third_party/rdd/RDD/models/descriptor.py
ADDED
@@ -0,0 +1,116 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..utils.misc import NestedTensor, nested_tensor_from_tensor_list
import torchvision.transforms as transforms
from .backbone import build_backbone
from .deformable_transformer import build_deforamble_transformer

class BasicLayer(nn.Module):
    """
    Basic Convolutional Layer: Conv2d -> BatchNorm -> ReLU
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, bias=False):
        super().__init__()
        self.layer = nn.Sequential(
            nn.Conv2d( in_channels, out_channels, kernel_size, padding = padding, stride=stride, dilation=dilation, bias = bias),
            nn.BatchNorm2d(out_channels, affine=False),
            nn.ReLU(inplace = False),
        )

    def forward(self, x):
        return self.layer(x)

class RDD_Descriptor(nn.Module):
    def __init__(self, backbone, transformer, num_feature_levels):
        super().__init__()
        self.transformer = transformer
        self.hidden_dim = transformer.d_model
        self.num_feature_levels = num_feature_levels

        self.matchibility_head = nn.Sequential(
            BasicLayer(256, 128, 1, padding=0),
            BasicLayer(128, 64, 1, padding=0),
            nn.Conv2d (64, 1, 1),
            nn.Sigmoid()
        )

        if num_feature_levels > 1:
            num_backbone_outs = len(backbone.strides)
            input_proj_list = []
            for _ in range(num_backbone_outs):
                in_channels = backbone.num_channels[_]
                input_proj_list.append(nn.Sequential(
                    nn.Conv2d(in_channels, self.hidden_dim, kernel_size=1),
                    nn.GroupNorm(32, self.hidden_dim),
                ))
            for _ in range(num_feature_levels - num_backbone_outs):
                input_proj_list.append(nn.Sequential(
                    nn.Conv2d(in_channels, self.hidden_dim, kernel_size=3, stride=2, padding=1),
                    nn.GroupNorm(32, self.hidden_dim),
                ))
                in_channels = self.hidden_dim
            self.input_proj = nn.ModuleList(input_proj_list)
        else:
            self.input_proj = nn.ModuleList([
                nn.Sequential(
                    nn.Conv2d(backbone.num_channels[0], self.hidden_dim, kernel_size=1),
                    nn.GroupNorm(32, self.hidden_dim),
                )])
        self.backbone = backbone
        self.stride = backbone.strides[0]
        for proj in self.input_proj:
            nn.init.xavier_uniform_(proj[0].weight, gain=1)
            nn.init.constant_(proj[0].bias, 0)

    def forward(self, samples: NestedTensor):

        if not isinstance(samples, NestedTensor):
            samples = nested_tensor_from_tensor_list(samples)

        features, pos = self.backbone(samples)

        srcs = []
        masks = []
        for l, feat in enumerate(features):
            src, mask = feat.decompose()
            srcs.append(self.input_proj[l](src))
            masks.append(mask)
            assert mask is not None
        if self.num_feature_levels > len(srcs):
            _len_srcs = len(srcs)
            for l in range(_len_srcs, self.num_feature_levels):
                if l == _len_srcs:
                    src = self.input_proj[l](features[-1].tensors)
                else:
                    src = self.input_proj[l](srcs[-1])
                m = samples.mask
                mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to(torch.bool)[0]
                pos_l = self.backbone[1](NestedTensor(src, mask)).to(src.dtype)
                srcs.append(src)
                masks.append(mask)
                pos.append(pos_l)

        flatten_feats, spatial_shapes, level_start_index = self.transformer(srcs, masks, pos)
        # Reshape the flattened features back to the original spatial shapes
        feats = []
        level_start_index = torch.cat((level_start_index, torch.tensor([flatten_feats.shape[1]+1]).to(level_start_index.device)))
        for i, shape in enumerate(spatial_shapes):
            assert len(shape) == 2
            temp = flatten_feats[:, level_start_index[i] : level_start_index[i+1], :]
            feats.append(temp.transpose(1, 2).view(-1, self.hidden_dim, *shape))

        # Sum up the features from different levels
        final_feature = feats[0]
        for feat in feats[1:]:
            final_feature = final_feature + F.interpolate(feat, size=final_feature.shape[-2:], mode='bilinear', align_corners=True)

        matchibility = self.matchibility_head(final_feature)

        return final_feature, matchibility


def build_descriptor(config):
    backbone = build_backbone(config)
    transformer = build_deforamble_transformer(config)
    return RDD_Descriptor(backbone, transformer, config['num_feature_levels'])
imcui/third_party/rdd/RDD/models/detector.py
ADDED
@@ -0,0 +1,141 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet
from typing import Optional, Callable
from ..utils.misc import NestedTensor

class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels,
                 gate: Optional[Callable[..., nn.Module]] = None,
                 norm_layer: Optional[Callable[..., nn.Module]] = None):
        super().__init__()
        if gate is None:
            self.gate = nn.ReLU(inplace=False)
        else:
            self.gate = gate
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self.conv1 = resnet.conv3x3(in_channels, out_channels)
        self.bn1 = norm_layer(out_channels)
        self.conv2 = resnet.conv3x3(out_channels, out_channels)
        self.bn2 = norm_layer(out_channels)

    def forward(self, x):
        x = self.gate(self.bn1(self.conv1(x)))  # B x in_channels x H x W
        x = self.gate(self.bn2(self.conv2(x)))  # B x out_channels x H x W
        return x

class ResBlock(nn.Module):
    expansion: int = 1

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        gate: Optional[Callable[..., nn.Module]] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super(ResBlock, self).__init__()
        if gate is None:
            self.gate = nn.ReLU(inplace=False)
        else:
            self.gate = gate
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('ResBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in ResBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = resnet.conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.conv2 = resnet.conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.gate(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out = out + identity
        out = self.gate(out)

        return out

class RDD_detector(nn.Module):
    def __init__(self, block_dims, hidden_dim=128):
        super().__init__()
        self.gate = nn.ReLU(inplace=False)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.pool4 = nn.MaxPool2d(kernel_size=4, stride=4)
        self.block1 = ConvBlock(3, block_dims[0], self.gate, nn.BatchNorm2d)
        self.block2 = ResBlock(inplanes=block_dims[0], planes=block_dims[1], stride=1,
                               downsample=nn.Conv2d(block_dims[0], block_dims[1], 1),
                               gate=self.gate,
                               norm_layer=nn.BatchNorm2d)
        self.block3 = ResBlock(inplanes=block_dims[1], planes=block_dims[2], stride=1,
                               downsample=nn.Conv2d(block_dims[1], block_dims[2], 1),
                               gate=self.gate,
                               norm_layer=nn.BatchNorm2d)
        self.block4 = ResBlock(inplanes=block_dims[2], planes=block_dims[3], stride=1,
                               downsample=nn.Conv2d(block_dims[2], block_dims[3], 1),
                               gate=self.gate,
                               norm_layer=nn.BatchNorm2d)

        self.conv1 = resnet.conv1x1(block_dims[0], hidden_dim // 4)
        self.conv2 = resnet.conv1x1(block_dims[1], hidden_dim // 4)
        self.conv3 = resnet.conv1x1(block_dims[2], hidden_dim // 4)
        self.conv4 = resnet.conv1x1(block_dims[3], hidden_dim // 4)

        self.upsample2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.upsample4 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)
        self.upsample8 = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True)
        self.upsample32 = nn.Upsample(scale_factor=32, mode='bilinear', align_corners=True)

        self.convhead2 = nn.Sequential(
            resnet.conv1x1(hidden_dim, 1),
            nn.Sigmoid()
        )

    def forward(self, samples: NestedTensor):
        x1 = self.block1(samples.tensors)
        x2 = self.pool2(x1)
        x2 = self.block2(x2)  # B x c2 x H/2 x W/2
        x3 = self.pool4(x2)
        x3 = self.block3(x3)  # B x c3 x H/8 x W/8
        x4 = self.pool4(x3)
        x4 = self.block4(x4)

        x1 = self.gate(self.conv1(x1))  # B x dim//4 x H x W
        x2 = self.gate(self.conv2(x2))  # B x dim//4 x H//2 x W//2
        x3 = self.gate(self.conv3(x3))  # B x dim//4 x H//8 x W//8
        x4 = self.gate(self.conv4(x4))  # B x dim//4 x H//32 x W//32

        x2_up = self.upsample2(x2)
        x3_up = self.upsample8(x3)
        x4_up = self.upsample32(x4)

        x1234 = torch.cat([x1, x2_up, x3_up, x4_up], dim=1)
        scoremap = self.convhead2(x1234)

        return scoremap

def build_detector(config):
    block_dims = config['block_dims']
    return RDD_detector(block_dims, block_dims[-1])
imcui/third_party/rdd/RDD/models/interpolator.py
ADDED
@@ -0,0 +1,33 @@
"""
    "XFeat: Accelerated Features for Lightweight Image Matching, CVPR 2024."
    https://www.verlab.dcc.ufmg.br/descriptors/xfeat_cvpr24/
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

class InterpolateSparse2d(nn.Module):
    """ Efficiently interpolate tensor at given sparse 2D positions. """
    def __init__(self, mode = 'bilinear', align_corners = True):
        super().__init__()
        self.mode = mode
        self.align_corners = align_corners

    def normgrid(self, x, H, W):
        """ Normalize coords to [-1,1]. """
        return 2. * (x/(torch.tensor([W-1, H-1], device = x.device, dtype = x.dtype))) - 1.

    def forward(self, x, pos, H, W):
        """
        Input
            x: [B, C, H, W] feature tensor
            pos: [B, N, 2] tensor of positions
            H, W: int, original resolution of input 2d positions -- used in normalization [-1,1]

        Returns
            [B, N, C] sampled channels at 2d positions
        """
        grid = self.normgrid(pos, H, W).unsqueeze(-2).to(x.dtype)
        x = F.grid_sample(x, grid, mode = self.mode , align_corners = self.align_corners)
        return x.permute(0,2,3,1).squeeze(-2)
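InterpolateSparse2d is how a dense descriptor map is read out at sub-pixel keypoint locations via grid_sample. A short sketch with dummy tensors (illustration only, not part of the commit; the map resolution and keypoint count are made up):

# Hedged sketch: sample a [B, C, H/4, W/4] feature map at image-resolution keypoints.
import torch

H, W = 480, 640
feats = torch.randn(1, 256, H // 4, W // 4)                    # dense descriptor map
kpts = torch.rand(1, 1000, 2) * torch.tensor([W - 1.0, H - 1.0])  # (x, y) in pixels
sampler = InterpolateSparse2d(mode="bilinear")
desc = sampler(feats, kpts, H, W)                              # [1, 1000, 256] per-keypoint descriptors
desc = torch.nn.functional.normalize(desc, dim=-1)             # optional L2 normalization
print(desc.shape)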
imcui/third_party/rdd/RDD/models/ops/functions/__init__.py
ADDED
@@ -0,0 +1,13 @@
# ------------------------------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------------------
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
# ------------------------------------------------------------------------------------------------

# Copyright (c) Facebook, Inc. and its affiliates.
# Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR

from .ms_deform_attn_func import MSDeformAttnFunction
imcui/third_party/rdd/RDD/models/ops/functions/ms_deform_attn_func.py
ADDED
@@ -0,0 +1,72 @@
# ------------------------------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------------------
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
# ------------------------------------------------------------------------------------------------

# Copyright (c) Facebook, Inc. and its affiliates.
# Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR

from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import torch
import torch.nn.functional as F
from torch.autograd import Function
from torch.autograd.function import once_differentiable

try:
    import MultiScaleDeformableAttention as MSDA
except ModuleNotFoundError as e:
    info_string = (
        "\n\nPlease compile MultiScaleDeformableAttention CUDA op with the following commands:\n"
        "\t`cd mask2former/modeling/pixel_decoder/ops`\n"
        "\t`sh make.sh`\n"
    )
    print(info_string)


class MSDeformAttnFunction(Function):
    @staticmethod
    def forward(ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step):
        ctx.im2col_step = im2col_step
        output = MSDA.ms_deform_attn_forward(
            value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ctx.im2col_step)
        ctx.save_for_backward(value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights)
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors
        grad_value, grad_sampling_loc, grad_attn_weight = \
            MSDA.ms_deform_attn_backward(
                value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, grad_output, ctx.im2col_step)

        return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None


def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights):
    # for debug and test only,
    # need to use cuda version instead
    N_, S_, M_, D_ = value.shape
    _, Lq_, M_, L_, P_, _ = sampling_locations.shape
    value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)
    sampling_grids = 2 * sampling_locations - 1
    sampling_value_list = []
    for lid_, (H_, W_) in enumerate(value_spatial_shapes):
        # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_
        value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_)
        # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2
        sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1)
        # N_*M_, D_, Lq_, P_
        sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_,
                                          mode='bilinear', padding_mode='zeros', align_corners=False)
        sampling_value_list.append(sampling_value_l_)
    # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_)
    attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_)
    output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_)
    return output.transpose(1, 2).contiguous()
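Note (editor): the comments above mark ms_deform_attn_core_pytorch as the debug/test reference path. A small shape-check sketch follows; all sizes are invented for illustration, and the import assumes imcui/third_party/rdd is on sys.path.

import torch
from RDD.models.ops.functions.ms_deform_attn_func import ms_deform_attn_core_pytorch

N, M, D = 2, 8, 32                      # batch, heads, channels per head
shapes = [(32, 32), (16, 16)]           # (H_l, W_l) per feature level
S = sum(h * w for h, w in shapes)       # total number of value tokens
Lq, L, P = 100, len(shapes), 4          # queries, levels, sampling points per level

value = torch.randn(N, S, M, D)
sampling_locations = torch.rand(N, Lq, M, L, P, 2)   # normalized to [0, 1]
attention_weights = torch.softmax(torch.randn(N, Lq, M, L * P), -1).view(N, Lq, M, L, P)

out = ms_deform_attn_core_pytorch(value, shapes, sampling_locations, attention_weights)
print(out.shape)                        # torch.Size([2, 100, 256]) == (N, Lq, M * D)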
imcui/third_party/rdd/RDD/models/ops/make.sh
ADDED
@@ -0,0 +1,13 @@
#!/usr/bin/env bash
# ------------------------------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------------------
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
# ------------------------------------------------------------------------------------------------

# Copyright (c) Facebook, Inc. and its affiliates.
# Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR

python setup.py build install
imcui/third_party/rdd/RDD/models/ops/modules/__init__.py
ADDED
@@ -0,0 +1,12 @@
# ------------------------------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------------------
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
# ------------------------------------------------------------------------------------------------

# Copyright (c) Facebook, Inc. and its affiliates.
# Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR

from .ms_deform_attn import MSDeformAttn
imcui/third_party/rdd/RDD/models/ops/modules/ms_deform_attn.py
ADDED
@@ -0,0 +1,125 @@
# ------------------------------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------------------
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
# ------------------------------------------------------------------------------------------------

# Copyright (c) Facebook, Inc. and its affiliates.
# Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR

from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import warnings
import math

import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.init import xavier_uniform_, constant_

from ..functions import MSDeformAttnFunction
from ..functions.ms_deform_attn_func import ms_deform_attn_core_pytorch


def _is_power_of_2(n):
    if (not isinstance(n, int)) or (n < 0):
        raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n)))
    return (n & (n-1) == 0) and n != 0


class MSDeformAttn(nn.Module):
    def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4):
        """
        Multi-Scale Deformable Attention Module
        :param d_model      hidden dimension
        :param n_levels     number of feature levels
        :param n_heads      number of attention heads
        :param n_points     number of sampling points per attention head per feature level
        """
        super().__init__()
        if d_model % n_heads != 0:
            raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads))
        _d_per_head = d_model // n_heads
        # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation
        if not _is_power_of_2(_d_per_head):
            warnings.warn("You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 "
                          "which is more efficient in our CUDA implementation.")

        self.im2col_step = 128

        self.d_model = d_model
        self.n_levels = n_levels
        self.n_heads = n_heads
        self.n_points = n_points

        self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2)
        self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points)
        self.value_proj = nn.Linear(d_model, d_model)
        self.output_proj = nn.Linear(d_model, d_model)

        self._reset_parameters()

    def _reset_parameters(self):
        constant_(self.sampling_offsets.weight.data, 0.)
        thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
        grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
        grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(1, self.n_levels, self.n_points, 1)
        for i in range(self.n_points):
            grid_init[:, :, i, :] *= i + 1
        with torch.no_grad():
            self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
        constant_(self.attention_weights.weight.data, 0.)
        constant_(self.attention_weights.bias.data, 0.)
        xavier_uniform_(self.value_proj.weight.data)
        constant_(self.value_proj.bias.data, 0.)
        xavier_uniform_(self.output_proj.weight.data)
        constant_(self.output_proj.bias.data, 0.)

    def forward(self, query, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None):
        """
        :param query                       (N, Length_{query}, C)
        :param reference_points            (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area
                                           or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes
        :param input_flatten               (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C)
        :param input_spatial_shapes        (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
        :param input_level_start_index     (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}]
        :param input_padding_mask          (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements

        :return output                     (N, Length_{query}, C)
        """
        N, Len_q, _ = query.shape
        N, Len_in, _ = input_flatten.shape
        assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in

        value = self.value_proj(input_flatten)
        if input_padding_mask is not None:
            value = value.masked_fill(input_padding_mask[..., None], float(0))
        value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads)
        sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2)
        attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points)
        attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points)
        # N, Len_q, n_heads, n_levels, n_points, 2
        if reference_points.shape[-1] == 2:
            offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1)
            sampling_locations = reference_points[:, :, None, :, None, :] \
                                 + sampling_offsets / offset_normalizer[None, None, None, :, None, :]
        elif reference_points.shape[-1] == 4:
            sampling_locations = reference_points[:, :, None, :, None, :2] \
                                 + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5
        else:
            raise ValueError(
                'Last dim of reference_points must be 2 or 4, but get {} instead.'.format(reference_points.shape[-1]))
        try:
            output = MSDeformAttnFunction.apply(
                value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step)
        except:
            # CPU
            output = ms_deform_attn_core_pytorch(value, input_spatial_shapes, sampling_locations, attention_weights)
        # # For FLOPs calculation only
        # output = ms_deform_attn_core_pytorch(value, input_spatial_shapes, sampling_locations, attention_weights)
        output = self.output_proj(output)
        return output
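Note (editor): a hedged usage sketch for MSDeformAttn with toy inputs (all sizes invented; assumes the rdd package root is importable). Without the compiled MultiScaleDeformableAttention op, or on CPU tensors, forward() drops into the ms_deform_attn_core_pytorch fallback shown above.

import torch
from RDD.models.ops.modules import MSDeformAttn

shapes = torch.as_tensor([[32, 32], [16, 16]], dtype=torch.long)              # (H_l, W_l) per level
level_start_index = torch.cat((shapes.new_zeros((1,)),
                               (shapes[:, 0] * shapes[:, 1]).cumsum(0)[:-1]))
N, Len_q, C = 1, 100, 256
Len_in = int((shapes[:, 0] * shapes[:, 1]).sum())

attn = MSDeformAttn(d_model=C, n_levels=2, n_heads=8, n_points=4)
query = torch.randn(N, Len_q, C)
input_flatten = torch.randn(N, Len_in, C)
reference_points = torch.rand(N, Len_q, 2)[:, :, None, :].repeat(1, 1, 2, 1)  # (N, Len_q, n_levels, 2)

out = attn(query, reference_points, input_flatten, shapes, level_start_index)
print(out.shape)                                                               # torch.Size([1, 100, 256])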
imcui/third_party/rdd/RDD/models/ops/setup.py
ADDED
@@ -0,0 +1,78 @@
# ------------------------------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------------------
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
# ------------------------------------------------------------------------------------------------

# Copyright (c) Facebook, Inc. and its affiliates.
# Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR

import os
import glob

import torch

from torch.utils.cpp_extension import CUDA_HOME
from torch.utils.cpp_extension import CppExtension
from torch.utils.cpp_extension import CUDAExtension

from setuptools import find_packages
from setuptools import setup

requirements = ["torch", "torchvision"]

def get_extensions():
    this_dir = os.path.dirname(os.path.abspath(__file__))
    extensions_dir = os.path.join(this_dir, "src")

    main_file = glob.glob(os.path.join(extensions_dir, "*.cpp"))
    source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp"))
    source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu"))

    sources = main_file + source_cpu
    extension = CppExtension
    extra_compile_args = {"cxx": []}
    define_macros = []

    # Force cuda since torch ask for a device, not if cuda is in fact available.
    if (os.environ.get('FORCE_CUDA') or torch.cuda.is_available()) and CUDA_HOME is not None:
        extension = CUDAExtension
        sources += source_cuda
        define_macros += [("WITH_CUDA", None)]
        extra_compile_args["nvcc"] = [
            "-DCUDA_HAS_FP16=1",
            "-D__CUDA_NO_HALF_OPERATORS__",
            "-D__CUDA_NO_HALF_CONVERSIONS__",
            "-D__CUDA_NO_HALF2_OPERATORS__",
        ]
    else:
        if CUDA_HOME is None:
            raise NotImplementedError('CUDA_HOME is None. Please set environment variable CUDA_HOME.')
        else:
            raise NotImplementedError('No CUDA runtime is found. Please set FORCE_CUDA=1 or test it by running torch.cuda.is_available().')

    sources = [os.path.join(extensions_dir, s) for s in sources]
    include_dirs = [extensions_dir]
    ext_modules = [
        extension(
            "MultiScaleDeformableAttention",
            sources,
            include_dirs=include_dirs,
            define_macros=define_macros,
            extra_compile_args=extra_compile_args,
        )
    ]
    return ext_modules

setup(
    name="MultiScaleDeformableAttention",
    version="1.0",
    author="Weijie Su",
    url="https://github.com/fundamentalvision/Deformable-DETR",
    description="PyTorch Wrapper for CUDA Functions of Multi-Scale Deformable Attention",
    packages=find_packages(exclude=("configs", "tests",)),
    ext_modules=get_extensions(),
    cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension},
)
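Note (editor): after building with make.sh (i.e. `python setup.py build install`), a quick import check mirrors the try/except guard used in ms_deform_attn_func.py; if the import fails, the pure-PyTorch reference is used instead.

try:
    import MultiScaleDeformableAttention as MSDA
    print("CUDA op available:", hasattr(MSDA, "ms_deform_attn_forward"))
except ModuleNotFoundError:
    print("Extension not built; the PyTorch reference ms_deform_attn_core_pytorch will be used.")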
imcui/third_party/rdd/RDD/models/ops/src/cpu/ms_deform_attn_cpu.cpp
ADDED
@@ -0,0 +1,46 @@
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/

/*!
* Copyright (c) Facebook, Inc. and its affiliates.
* Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
*/

#include <vector>

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>


at::Tensor
ms_deform_attn_cpu_forward(
    const at::Tensor &value,
    const at::Tensor &spatial_shapes,
    const at::Tensor &level_start_index,
    const at::Tensor &sampling_loc,
    const at::Tensor &attn_weight,
    const int im2col_step)
{
    AT_ERROR("Not implement on cpu");
}

std::vector<at::Tensor>
ms_deform_attn_cpu_backward(
    const at::Tensor &value,
    const at::Tensor &spatial_shapes,
    const at::Tensor &level_start_index,
    const at::Tensor &sampling_loc,
    const at::Tensor &attn_weight,
    const at::Tensor &grad_output,
    const int im2col_step)
{
    AT_ERROR("Not implement on cpu");
}
imcui/third_party/rdd/RDD/models/ops/src/cpu/ms_deform_attn_cpu.h
ADDED
@@ -0,0 +1,38 @@
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/

/*!
* Copyright (c) Facebook, Inc. and its affiliates.
* Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
*/

#pragma once
#include <torch/extension.h>

at::Tensor
ms_deform_attn_cpu_forward(
    const at::Tensor &value,
    const at::Tensor &spatial_shapes,
    const at::Tensor &level_start_index,
    const at::Tensor &sampling_loc,
    const at::Tensor &attn_weight,
    const int im2col_step);

std::vector<at::Tensor>
ms_deform_attn_cpu_backward(
    const at::Tensor &value,
    const at::Tensor &spatial_shapes,
    const at::Tensor &level_start_index,
    const at::Tensor &sampling_loc,
    const at::Tensor &attn_weight,
    const at::Tensor &grad_output,
    const int im2col_step);
imcui/third_party/rdd/RDD/models/ops/src/cuda/ms_deform_attn_cuda.cu
ADDED
@@ -0,0 +1,158 @@
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/

/*!
* Copyright (c) Facebook, Inc. and its affiliates.
* Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
*/

#include <vector>
#include "cuda/ms_deform_im2col_cuda.cuh"

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_runtime.h>


at::Tensor ms_deform_attn_cuda_forward(
    const at::Tensor &value,
    const at::Tensor &spatial_shapes,
    const at::Tensor &level_start_index,
    const at::Tensor &sampling_loc,
    const at::Tensor &attn_weight,
    const int im2col_step)
{
    AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
    AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
    AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
    AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
    AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");

    AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
    AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
    AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
    AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
    AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");

    const int batch = value.size(0);
    const int spatial_size = value.size(1);
    const int num_heads = value.size(2);
    const int channels = value.size(3);

    const int num_levels = spatial_shapes.size(0);

    const int num_query = sampling_loc.size(1);
    const int num_point = sampling_loc.size(4);

    const int im2col_step_ = std::min(batch, im2col_step);

    AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);

    auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());

    const int batch_n = im2col_step_;
    auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
    auto per_value_size = spatial_size * num_heads * channels;
    auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
    auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
    for (int n = 0; n < batch/im2col_step_; ++n)
    {
        auto columns = output_n.select(0, n);
        AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] {
            ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
                value.data<scalar_t>() + n * im2col_step_ * per_value_size,
                spatial_shapes.data<int64_t>(),
                level_start_index.data<int64_t>(),
                sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
                attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
                batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
                columns.data<scalar_t>());

        }));
    }

    output = output.view({batch, num_query, num_heads*channels});

    return output;
}


std::vector<at::Tensor> ms_deform_attn_cuda_backward(
    const at::Tensor &value,
    const at::Tensor &spatial_shapes,
    const at::Tensor &level_start_index,
    const at::Tensor &sampling_loc,
    const at::Tensor &attn_weight,
    const at::Tensor &grad_output,
    const int im2col_step)
{

    AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
    AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
    AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
    AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
    AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
    AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");

    AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
    AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
    AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
    AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
    AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
    AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor");

    const int batch = value.size(0);
    const int spatial_size = value.size(1);
    const int num_heads = value.size(2);
    const int channels = value.size(3);

    const int num_levels = spatial_shapes.size(0);

    const int num_query = sampling_loc.size(1);
    const int num_point = sampling_loc.size(4);

    const int im2col_step_ = std::min(batch, im2col_step);

    AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);

    auto grad_value = at::zeros_like(value);
    auto grad_sampling_loc = at::zeros_like(sampling_loc);
    auto grad_attn_weight = at::zeros_like(attn_weight);

    const int batch_n = im2col_step_;
    auto per_value_size = spatial_size * num_heads * channels;
    auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
    auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
    auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});

    for (int n = 0; n < batch/im2col_step_; ++n)
    {
        auto grad_output_g = grad_output_n.select(0, n);
        AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] {
            ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
                grad_output_g.data<scalar_t>(),
                value.data<scalar_t>() + n * im2col_step_ * per_value_size,
                spatial_shapes.data<int64_t>(),
                level_start_index.data<int64_t>(),
                sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
                attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
                batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
                grad_value.data<scalar_t>() + n * im2col_step_ * per_value_size,
                grad_sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
                grad_attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size);

        }));
    }

    return {
        grad_value, grad_sampling_loc, grad_attn_weight
    };
}
imcui/third_party/rdd/RDD/models/ops/src/cuda/ms_deform_attn_cuda.h
ADDED
@@ -0,0 +1,35 @@
/*!
**************************************************************************************************
* Deformable DETR
* Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
**************************************************************************************************
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
**************************************************************************************************
*/

/*!
* Copyright (c) Facebook, Inc. and its affiliates.
* Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
*/

#pragma once
#include <torch/extension.h>

at::Tensor ms_deform_attn_cuda_forward(
    const at::Tensor &value,
    const at::Tensor &spatial_shapes,
    const at::Tensor &level_start_index,
    const at::Tensor &sampling_loc,
    const at::Tensor &attn_weight,
    const int im2col_step);

std::vector<at::Tensor> ms_deform_attn_cuda_backward(
    const at::Tensor &value,
    const at::Tensor &spatial_shapes,
    const at::Tensor &level_start_index,
    const at::Tensor &sampling_loc,
    const at::Tensor &attn_weight,
    const at::Tensor &grad_output,
    const int im2col_step);
imcui/third_party/rdd/RDD/models/ops/src/cuda/ms_deform_im2col_cuda.cuh
ADDED
@@ -0,0 +1,1332 @@
1 |
+
/*!
|
2 |
+
**************************************************************************
|
3 |
+
* Deformable DETR
|
4 |
+
* Copyright (c) 2020 SenseTime. All Rights Reserved.
|
5 |
+
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
6 |
+
**************************************************************************
|
7 |
+
* Modified from DCN (https://github.com/msracver/Deformable-ConvNets)
|
8 |
+
* Copyright (c) 2018 Microsoft
|
9 |
+
**************************************************************************
|
10 |
+
*/
|
11 |
+
|
12 |
+
/*!
|
13 |
+
* Copyright (c) Facebook, Inc. and its affiliates.
|
14 |
+
* Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
|
15 |
+
*/
|
16 |
+
|
17 |
+
#include <cstdio>
|
18 |
+
#include <algorithm>
|
19 |
+
#include <cstring>
|
20 |
+
|
21 |
+
#include <ATen/ATen.h>
|
22 |
+
#include <ATen/cuda/CUDAContext.h>
|
23 |
+
|
24 |
+
#include <THC/THCAtomics.cuh>
|
25 |
+
|
26 |
+
#define CUDA_KERNEL_LOOP(i, n) \
|
27 |
+
for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
|
28 |
+
i < (n); \
|
29 |
+
i += blockDim.x * gridDim.x)
|
30 |
+
|
31 |
+
const int CUDA_NUM_THREADS = 1024;
|
32 |
+
inline int GET_BLOCKS(const int N, const int num_threads)
|
33 |
+
{
|
34 |
+
return (N + num_threads - 1) / num_threads;
|
35 |
+
}
|
36 |
+
|
37 |
+
|
38 |
+
template <typename scalar_t>
|
39 |
+
__device__ scalar_t ms_deform_attn_im2col_bilinear(const scalar_t* &bottom_data,
|
40 |
+
const int &height, const int &width, const int &nheads, const int &channels,
|
41 |
+
const scalar_t &h, const scalar_t &w, const int &m, const int &c)
|
42 |
+
{
|
43 |
+
const int h_low = floor(h);
|
44 |
+
const int w_low = floor(w);
|
45 |
+
const int h_high = h_low + 1;
|
46 |
+
const int w_high = w_low + 1;
|
47 |
+
|
48 |
+
const scalar_t lh = h - h_low;
|
49 |
+
const scalar_t lw = w - w_low;
|
50 |
+
const scalar_t hh = 1 - lh, hw = 1 - lw;
|
51 |
+
|
52 |
+
const int w_stride = nheads * channels;
|
53 |
+
const int h_stride = width * w_stride;
|
54 |
+
const int h_low_ptr_offset = h_low * h_stride;
|
55 |
+
const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
|
56 |
+
const int w_low_ptr_offset = w_low * w_stride;
|
57 |
+
const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
|
58 |
+
const int base_ptr = m * channels + c;
|
59 |
+
|
60 |
+
scalar_t v1 = 0;
|
61 |
+
if (h_low >= 0 && w_low >= 0)
|
62 |
+
{
|
63 |
+
const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
|
64 |
+
v1 = bottom_data[ptr1];
|
65 |
+
}
|
66 |
+
scalar_t v2 = 0;
|
67 |
+
if (h_low >= 0 && w_high <= width - 1)
|
68 |
+
{
|
69 |
+
const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
|
70 |
+
v2 = bottom_data[ptr2];
|
71 |
+
}
|
72 |
+
scalar_t v3 = 0;
|
73 |
+
if (h_high <= height - 1 && w_low >= 0)
|
74 |
+
{
|
75 |
+
const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
|
76 |
+
v3 = bottom_data[ptr3];
|
77 |
+
}
|
78 |
+
scalar_t v4 = 0;
|
79 |
+
if (h_high <= height - 1 && w_high <= width - 1)
|
80 |
+
{
|
81 |
+
const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
|
82 |
+
v4 = bottom_data[ptr4];
|
83 |
+
}
|
84 |
+
|
85 |
+
const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
|
86 |
+
|
87 |
+
const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
|
88 |
+
return val;
|
89 |
+
}
|
90 |
+
|
91 |
+
|
92 |
+
template <typename scalar_t>
|
93 |
+
__device__ void ms_deform_attn_col2im_bilinear(const scalar_t* &bottom_data,
|
94 |
+
const int &height, const int &width, const int &nheads, const int &channels,
|
95 |
+
const scalar_t &h, const scalar_t &w, const int &m, const int &c,
|
96 |
+
const scalar_t &top_grad,
|
97 |
+
const scalar_t &attn_weight,
|
98 |
+
scalar_t* &grad_value,
|
99 |
+
scalar_t* grad_sampling_loc,
|
100 |
+
scalar_t* grad_attn_weight)
|
101 |
+
{
|
102 |
+
const int h_low = floor(h);
|
103 |
+
const int w_low = floor(w);
|
104 |
+
const int h_high = h_low + 1;
|
105 |
+
const int w_high = w_low + 1;
|
106 |
+
|
107 |
+
const scalar_t lh = h - h_low;
|
108 |
+
const scalar_t lw = w - w_low;
|
109 |
+
const scalar_t hh = 1 - lh, hw = 1 - lw;
|
110 |
+
|
111 |
+
const int w_stride = nheads * channels;
|
112 |
+
const int h_stride = width * w_stride;
|
113 |
+
const int h_low_ptr_offset = h_low * h_stride;
|
114 |
+
const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
|
115 |
+
const int w_low_ptr_offset = w_low * w_stride;
|
116 |
+
const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
|
117 |
+
const int base_ptr = m * channels + c;
|
118 |
+
|
119 |
+
const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
|
120 |
+
const scalar_t top_grad_value = top_grad * attn_weight;
|
121 |
+
scalar_t grad_h_weight = 0, grad_w_weight = 0;
|
122 |
+
|
123 |
+
scalar_t v1 = 0;
|
124 |
+
if (h_low >= 0 && w_low >= 0)
|
125 |
+
{
|
126 |
+
const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
|
127 |
+
v1 = bottom_data[ptr1];
|
128 |
+
grad_h_weight -= hw * v1;
|
129 |
+
grad_w_weight -= hh * v1;
|
130 |
+
atomicAdd(grad_value+ptr1, w1*top_grad_value);
|
131 |
+
}
|
132 |
+
scalar_t v2 = 0;
|
133 |
+
if (h_low >= 0 && w_high <= width - 1)
|
134 |
+
{
|
135 |
+
const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
|
136 |
+
v2 = bottom_data[ptr2];
|
137 |
+
grad_h_weight -= lw * v2;
|
138 |
+
grad_w_weight += hh * v2;
|
139 |
+
atomicAdd(grad_value+ptr2, w2*top_grad_value);
|
140 |
+
}
|
141 |
+
scalar_t v3 = 0;
|
142 |
+
if (h_high <= height - 1 && w_low >= 0)
|
143 |
+
{
|
144 |
+
const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
|
145 |
+
v3 = bottom_data[ptr3];
|
146 |
+
grad_h_weight += hw * v3;
|
147 |
+
grad_w_weight -= lh * v3;
|
148 |
+
atomicAdd(grad_value+ptr3, w3*top_grad_value);
|
149 |
+
}
|
150 |
+
scalar_t v4 = 0;
|
151 |
+
if (h_high <= height - 1 && w_high <= width - 1)
|
152 |
+
{
|
153 |
+
const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
|
154 |
+
v4 = bottom_data[ptr4];
|
155 |
+
grad_h_weight += lw * v4;
|
156 |
+
grad_w_weight += lh * v4;
|
157 |
+
atomicAdd(grad_value+ptr4, w4*top_grad_value);
|
158 |
+
}
|
159 |
+
|
160 |
+
const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
|
161 |
+
*grad_attn_weight = top_grad * val;
|
162 |
+
*grad_sampling_loc = width * grad_w_weight * top_grad_value;
|
163 |
+
*(grad_sampling_loc + 1) = height * grad_h_weight * top_grad_value;
|
164 |
+
}
|
165 |
+
|
166 |
+
|
167 |
+
template <typename scalar_t>
|
168 |
+
__device__ void ms_deform_attn_col2im_bilinear_gm(const scalar_t* &bottom_data,
|
169 |
+
const int &height, const int &width, const int &nheads, const int &channels,
|
170 |
+
const scalar_t &h, const scalar_t &w, const int &m, const int &c,
|
171 |
+
const scalar_t &top_grad,
|
172 |
+
const scalar_t &attn_weight,
|
173 |
+
scalar_t* &grad_value,
|
174 |
+
scalar_t* grad_sampling_loc,
|
175 |
+
scalar_t* grad_attn_weight)
|
176 |
+
{
|
177 |
+
const int h_low = floor(h);
|
178 |
+
const int w_low = floor(w);
|
179 |
+
const int h_high = h_low + 1;
|
180 |
+
const int w_high = w_low + 1;
|
181 |
+
|
182 |
+
const scalar_t lh = h - h_low;
|
183 |
+
const scalar_t lw = w - w_low;
|
184 |
+
const scalar_t hh = 1 - lh, hw = 1 - lw;
|
185 |
+
|
186 |
+
const int w_stride = nheads * channels;
|
187 |
+
const int h_stride = width * w_stride;
|
188 |
+
const int h_low_ptr_offset = h_low * h_stride;
|
189 |
+
const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
|
190 |
+
const int w_low_ptr_offset = w_low * w_stride;
|
191 |
+
const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
|
192 |
+
const int base_ptr = m * channels + c;
|
193 |
+
|
194 |
+
const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
|
195 |
+
const scalar_t top_grad_value = top_grad * attn_weight;
|
196 |
+
scalar_t grad_h_weight = 0, grad_w_weight = 0;
|
197 |
+
|
198 |
+
scalar_t v1 = 0;
|
199 |
+
if (h_low >= 0 && w_low >= 0)
|
200 |
+
{
|
201 |
+
const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
|
202 |
+
v1 = bottom_data[ptr1];
|
203 |
+
grad_h_weight -= hw * v1;
|
204 |
+
grad_w_weight -= hh * v1;
|
205 |
+
atomicAdd(grad_value+ptr1, w1*top_grad_value);
|
206 |
+
}
|
207 |
+
scalar_t v2 = 0;
|
208 |
+
if (h_low >= 0 && w_high <= width - 1)
|
209 |
+
{
|
210 |
+
const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
|
211 |
+
v2 = bottom_data[ptr2];
|
212 |
+
grad_h_weight -= lw * v2;
|
213 |
+
grad_w_weight += hh * v2;
|
214 |
+
atomicAdd(grad_value+ptr2, w2*top_grad_value);
|
215 |
+
}
|
216 |
+
scalar_t v3 = 0;
|
217 |
+
if (h_high <= height - 1 && w_low >= 0)
|
218 |
+
{
|
219 |
+
const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
|
220 |
+
v3 = bottom_data[ptr3];
|
221 |
+
grad_h_weight += hw * v3;
|
222 |
+
grad_w_weight -= lh * v3;
|
223 |
+
atomicAdd(grad_value+ptr3, w3*top_grad_value);
|
224 |
+
}
|
225 |
+
scalar_t v4 = 0;
|
226 |
+
if (h_high <= height - 1 && w_high <= width - 1)
|
227 |
+
{
|
228 |
+
const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
|
229 |
+
v4 = bottom_data[ptr4];
|
230 |
+
grad_h_weight += lw * v4;
|
231 |
+
grad_w_weight += lh * v4;
|
232 |
+
atomicAdd(grad_value+ptr4, w4*top_grad_value);
|
233 |
+
}
|
234 |
+
|
235 |
+
const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
|
236 |
+
atomicAdd(grad_attn_weight, top_grad * val);
|
237 |
+
atomicAdd(grad_sampling_loc, width * grad_w_weight * top_grad_value);
|
238 |
+
atomicAdd(grad_sampling_loc + 1, height * grad_h_weight * top_grad_value);
|
239 |
+
}
|
240 |
+
|
241 |
+
|
242 |
+
template <typename scalar_t>
|
243 |
+
__global__ void ms_deformable_im2col_gpu_kernel(const int n,
|
244 |
+
const scalar_t *data_value,
|
245 |
+
const int64_t *data_spatial_shapes,
|
246 |
+
const int64_t *data_level_start_index,
|
247 |
+
const scalar_t *data_sampling_loc,
|
248 |
+
const scalar_t *data_attn_weight,
|
249 |
+
const int batch_size,
|
250 |
+
const int spatial_size,
|
251 |
+
const int num_heads,
|
252 |
+
const int channels,
|
253 |
+
const int num_levels,
|
254 |
+
const int num_query,
|
255 |
+
const int num_point,
|
256 |
+
scalar_t *data_col)
|
257 |
+
{
|
258 |
+
CUDA_KERNEL_LOOP(index, n)
|
259 |
+
{
|
260 |
+
int _temp = index;
|
261 |
+
const int c_col = _temp % channels;
|
262 |
+
_temp /= channels;
|
263 |
+
const int sampling_index = _temp;
|
264 |
+
const int m_col = _temp % num_heads;
|
265 |
+
_temp /= num_heads;
|
266 |
+
const int q_col = _temp % num_query;
|
267 |
+
_temp /= num_query;
|
268 |
+
const int b_col = _temp;
|
269 |
+
|
270 |
+
scalar_t *data_col_ptr = data_col + index;
|
271 |
+
int data_weight_ptr = sampling_index * num_levels * num_point;
|
272 |
+
int data_loc_w_ptr = data_weight_ptr << 1;
|
273 |
+
const int qid_stride = num_heads * channels;
|
274 |
+
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
|
275 |
+
scalar_t col = 0;
|
276 |
+
|
277 |
+
for (int l_col=0; l_col < num_levels; ++l_col)
|
278 |
+
{
|
279 |
+
const int level_start_id = data_level_start_index[l_col];
|
280 |
+
const int spatial_h_ptr = l_col << 1;
|
281 |
+
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
|
282 |
+
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
|
283 |
+
const scalar_t *data_value_ptr = data_value + (data_value_ptr_init_offset + level_start_id * qid_stride);
|
284 |
+
for (int p_col=0; p_col < num_point; ++p_col)
|
285 |
+
{
|
286 |
+
const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
|
287 |
+
const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
|
288 |
+
const scalar_t weight = data_attn_weight[data_weight_ptr];
|
289 |
+
|
290 |
+
const scalar_t h_im = loc_h * spatial_h - 0.5;
|
291 |
+
const scalar_t w_im = loc_w * spatial_w - 0.5;
|
292 |
+
|
293 |
+
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
|
294 |
+
{
|
295 |
+
col += ms_deform_attn_im2col_bilinear(data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col) * weight;
|
296 |
+
}
|
297 |
+
|
298 |
+
data_weight_ptr += 1;
|
299 |
+
data_loc_w_ptr += 2;
|
300 |
+
}
|
301 |
+
}
|
302 |
+
*data_col_ptr = col;
|
303 |
+
}
|
304 |
+
}
|
305 |
+
|
306 |
+
template <typename scalar_t, unsigned int blockSize>
|
307 |
+
__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(const int n,
|
308 |
+
const scalar_t *grad_col,
|
309 |
+
const scalar_t *data_value,
|
310 |
+
const int64_t *data_spatial_shapes,
|
311 |
+
const int64_t *data_level_start_index,
|
312 |
+
const scalar_t *data_sampling_loc,
|
313 |
+
const scalar_t *data_attn_weight,
|
314 |
+
const int batch_size,
|
315 |
+
const int spatial_size,
|
316 |
+
const int num_heads,
|
317 |
+
const int channels,
|
318 |
+
const int num_levels,
|
319 |
+
const int num_query,
|
320 |
+
const int num_point,
|
321 |
+
scalar_t *grad_value,
|
322 |
+
scalar_t *grad_sampling_loc,
|
323 |
+
scalar_t *grad_attn_weight)
|
324 |
+
{
|
325 |
+
CUDA_KERNEL_LOOP(index, n)
|
326 |
+
{
|
327 |
+
__shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
|
328 |
+
__shared__ scalar_t cache_grad_attn_weight[blockSize];
|
329 |
+
unsigned int tid = threadIdx.x;
|
330 |
+
int _temp = index;
|
331 |
+
const int c_col = _temp % channels;
|
332 |
+
_temp /= channels;
|
333 |
+
const int sampling_index = _temp;
|
334 |
+
const int m_col = _temp % num_heads;
|
335 |
+
_temp /= num_heads;
|
336 |
+
const int q_col = _temp % num_query;
|
337 |
+
_temp /= num_query;
|
338 |
+
const int b_col = _temp;
|
339 |
+
|
340 |
+
const scalar_t top_grad = grad_col[index];
|
341 |
+
|
342 |
+
int data_weight_ptr = sampling_index * num_levels * num_point;
|
343 |
+
int data_loc_w_ptr = data_weight_ptr << 1;
|
344 |
+
const int grad_sampling_ptr = data_weight_ptr;
|
345 |
+
grad_sampling_loc += grad_sampling_ptr << 1;
|
346 |
+
grad_attn_weight += grad_sampling_ptr;
|
347 |
+
const int grad_weight_stride = 1;
|
348 |
+
const int grad_loc_stride = 2;
|
349 |
+
const int qid_stride = num_heads * channels;
|
350 |
+
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
|
351 |
+
|
352 |
+
for (int l_col=0; l_col < num_levels; ++l_col)
|
353 |
+
{
|
354 |
+
const int level_start_id = data_level_start_index[l_col];
|
355 |
+
const int spatial_h_ptr = l_col << 1;
|
356 |
+
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
|
357 |
+
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
|
358 |
+
const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
|
359 |
+
const scalar_t *data_value_ptr = data_value + value_ptr_offset;
|
360 |
+
scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
|
361 |
+
|
362 |
+
for (int p_col=0; p_col < num_point; ++p_col)
|
363 |
+
{
|
364 |
+
const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
|
365 |
+
const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
|
366 |
+
const scalar_t weight = data_attn_weight[data_weight_ptr];
|
367 |
+
|
368 |
+
const scalar_t h_im = loc_h * spatial_h - 0.5;
|
369 |
+
const scalar_t w_im = loc_w * spatial_w - 0.5;
|
370 |
+
*(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
|
371 |
+
*(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
|
372 |
+
*(cache_grad_attn_weight+threadIdx.x)=0;
|
373 |
+
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
|
374 |
+
{
|
375 |
+
ms_deform_attn_col2im_bilinear(
|
376 |
+
data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
|
377 |
+
top_grad, weight, grad_value_ptr,
|
378 |
+
cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
|
379 |
+
}
|
380 |
+
|
381 |
+
__syncthreads();
|
382 |
+
if (tid == 0)
|
383 |
+
{
|
384 |
+
scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
|
385 |
+
int sid=2;
|
386 |
+
for (unsigned int tid = 1; tid < blockSize; ++tid)
|
387 |
+
{
|
388 |
+
_grad_w += cache_grad_sampling_loc[sid];
|
389 |
+
_grad_h += cache_grad_sampling_loc[sid + 1];
|
390 |
+
_grad_a += cache_grad_attn_weight[tid];
|
391 |
+
sid += 2;
|
392 |
+
}
|
393 |
+
|
394 |
+
|
395 |
+
*grad_sampling_loc = _grad_w;
|
396 |
+
*(grad_sampling_loc + 1) = _grad_h;
|
397 |
+
*grad_attn_weight = _grad_a;
|
398 |
+
}
|
399 |
+
__syncthreads();
|
400 |
+
|
401 |
+
data_weight_ptr += 1;
|
402 |
+
data_loc_w_ptr += 2;
|
403 |
+
grad_attn_weight += grad_weight_stride;
|
404 |
+
grad_sampling_loc += grad_loc_stride;
|
405 |
+
}
|
406 |
+
}
|
407 |
+
}
|
408 |
+
}
|
409 |
+
|
410 |
+
|
411 |
+
template <typename scalar_t, unsigned int blockSize>
|
412 |
+
__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(const int n,
|
413 |
+
const scalar_t *grad_col,
|
414 |
+
const scalar_t *data_value,
|
415 |
+
const int64_t *data_spatial_shapes,
|
416 |
+
const int64_t *data_level_start_index,
|
417 |
+
const scalar_t *data_sampling_loc,
|
418 |
+
const scalar_t *data_attn_weight,
|
419 |
+
const int batch_size,
|
420 |
+
const int spatial_size,
|
421 |
+
const int num_heads,
|
422 |
+
const int channels,
|
423 |
+
const int num_levels,
|
424 |
+
const int num_query,
|
425 |
+
const int num_point,
|
426 |
+
scalar_t *grad_value,
|
427 |
+
scalar_t *grad_sampling_loc,
|
428 |
+
scalar_t *grad_attn_weight)
|
429 |
+
{
|
430 |
+
CUDA_KERNEL_LOOP(index, n)
|
431 |
+
{
|
432 |
+
__shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
|
433 |
+
__shared__ scalar_t cache_grad_attn_weight[blockSize];
|
434 |
+
unsigned int tid = threadIdx.x;
|
435 |
+
int _temp = index;
|
436 |
+
const int c_col = _temp % channels;
|
437 |
+
_temp /= channels;
|
438 |
+
const int sampling_index = _temp;
|
439 |
+
const int m_col = _temp % num_heads;
|
440 |
+
_temp /= num_heads;
|
441 |
+
const int q_col = _temp % num_query;
|
442 |
+
_temp /= num_query;
|
443 |
+
const int b_col = _temp;
|
444 |
+
|
445 |
+
const scalar_t top_grad = grad_col[index];
|
446 |
+
|
447 |
+
int data_weight_ptr = sampling_index * num_levels * num_point;
|
448 |
+
int data_loc_w_ptr = data_weight_ptr << 1;
|
449 |
+
const int grad_sampling_ptr = data_weight_ptr;
|
450 |
+
grad_sampling_loc += grad_sampling_ptr << 1;
|
451 |
+
grad_attn_weight += grad_sampling_ptr;
|
452 |
+
const int grad_weight_stride = 1;
|
453 |
+
const int grad_loc_stride = 2;
|
454 |
+
const int qid_stride = num_heads * channels;
|
455 |
+
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
|
456 |
+
|
457 |
+
for (int l_col=0; l_col < num_levels; ++l_col)
|
458 |
+
{
|
459 |
+
const int level_start_id = data_level_start_index[l_col];
|
460 |
+
const int spatial_h_ptr = l_col << 1;
|
461 |
+
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
|
462 |
+
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
|
463 |
+
const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
|
464 |
+
const scalar_t *data_value_ptr = data_value + value_ptr_offset;
|
465 |
+
scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
|
466 |
+
|
467 |
+
for (int p_col=0; p_col < num_point; ++p_col)
|
468 |
+
{
|
469 |
+
const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
|
470 |
+
const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
|
471 |
+
const scalar_t weight = data_attn_weight[data_weight_ptr];
|
472 |
+
|
473 |
+
const scalar_t h_im = loc_h * spatial_h - 0.5;
|
474 |
+
const scalar_t w_im = loc_w * spatial_w - 0.5;
|
475 |
+
*(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
|
476 |
+
*(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
|
477 |
+
*(cache_grad_attn_weight+threadIdx.x)=0;
|
478 |
+
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
|
479 |
+
{
|
480 |
+
ms_deform_attn_col2im_bilinear(
|
481 |
+
data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
|
482 |
+
top_grad, weight, grad_value_ptr,
|
483 |
+
cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
|
484 |
+
}
|
485 |
+
|
486 |
+
__syncthreads();
|
487 |
+
|
488 |
+
for (unsigned int s=blockSize/2; s>0; s>>=1)
|
489 |
+
{
|
490 |
+
if (tid < s) {
|
491 |
+
const unsigned int xid1 = tid << 1;
|
492 |
+
const unsigned int xid2 = (tid + s) << 1;
|
493 |
+
cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
|
494 |
+
cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
|
495 |
+
cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
|
496 |
+
}
|
497 |
+
__syncthreads();
|
498 |
+
}
|
499 |
+
|
500 |
+
if (tid == 0)
|
501 |
+
{
|
502 |
+
*grad_sampling_loc = cache_grad_sampling_loc[0];
|
503 |
+
*(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
|
504 |
+
*grad_attn_weight = cache_grad_attn_weight[0];
|
505 |
+
}
|
506 |
+
__syncthreads();
|
507 |
+
|
508 |
+
data_weight_ptr += 1;
|
509 |
+
data_loc_w_ptr += 2;
|
510 |
+
grad_attn_weight += grad_weight_stride;
|
511 |
+
grad_sampling_loc += grad_loc_stride;
|
512 |
+
}
|
513 |
+
}
|
514 |
+
}
|
515 |
+
}
|
516 |
+
|
517 |
+
|
518 |
+
template <typename scalar_t>
|
519 |
+
__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(const int n,
|
520 |
+
const scalar_t *grad_col,
|
521 |
+
const scalar_t *data_value,
|
522 |
+
const int64_t *data_spatial_shapes,
|
523 |
+
const int64_t *data_level_start_index,
|
524 |
+
const scalar_t *data_sampling_loc,
|
525 |
+
const scalar_t *data_attn_weight,
|
526 |
+
const int batch_size,
|
527 |
+
const int spatial_size,
|
528 |
+
const int num_heads,
|
529 |
+
const int channels,
|
530 |
+
const int num_levels,
|
531 |
+
const int num_query,
|
532 |
+
const int num_point,
|
533 |
+
scalar_t *grad_value,
|
534 |
+
scalar_t *grad_sampling_loc,
|
535 |
+
scalar_t *grad_attn_weight)
|
536 |
+
{
|
537 |
+
CUDA_KERNEL_LOOP(index, n)
|
538 |
+
{
|
539 |
+
extern __shared__ int _s[];
|
540 |
+
scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
|
541 |
+
scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
|
542 |
+
unsigned int tid = threadIdx.x;
|
543 |
+
int _temp = index;
|
544 |
+
const int c_col = _temp % channels;
|
545 |
+
_temp /= channels;
|
546 |
+
const int sampling_index = _temp;
|
547 |
+
const int m_col = _temp % num_heads;
|
548 |
+
_temp /= num_heads;
|
549 |
+
const int q_col = _temp % num_query;
|
550 |
+
_temp /= num_query;
|
551 |
+
const int b_col = _temp;
|
552 |
+
|
553 |
+
const scalar_t top_grad = grad_col[index];
|
554 |
+
|
555 |
+
int data_weight_ptr = sampling_index * num_levels * num_point;
|
556 |
+
int data_loc_w_ptr = data_weight_ptr << 1;
|
557 |
+
const int grad_sampling_ptr = data_weight_ptr;
|
558 |
+
grad_sampling_loc += grad_sampling_ptr << 1;
|
559 |
+
grad_attn_weight += grad_sampling_ptr;
|
560 |
+
const int grad_weight_stride = 1;
|
561 |
+
const int grad_loc_stride = 2;
|
562 |
+
const int qid_stride = num_heads * channels;
|
563 |
+
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
|
564 |
+
|
565 |
+
for (int l_col=0; l_col < num_levels; ++l_col)
|
566 |
+
{
|
567 |
+
const int level_start_id = data_level_start_index[l_col];
|
568 |
+
const int spatial_h_ptr = l_col << 1;
|
569 |
+
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
|
570 |
+
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
|
571 |
+
const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
|
572 |
+
const scalar_t *data_value_ptr = data_value + value_ptr_offset;
|
573 |
+
scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
|
574 |
+
|
575 |
+
for (int p_col=0; p_col < num_point; ++p_col)
|
576 |
+
{
|
577 |
+
const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
|
578 |
+
const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
|
579 |
+
const scalar_t weight = data_attn_weight[data_weight_ptr];
|
580 |
+
|
581 |
+
const scalar_t h_im = loc_h * spatial_h - 0.5;
|
582 |
+
const scalar_t w_im = loc_w * spatial_w - 0.5;
|
583 |
+
*(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
|
584 |
+
*(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
|
585 |
+
*(cache_grad_attn_weight+threadIdx.x)=0;
|
586 |
+
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
|
587 |
+
{
|
588 |
+
ms_deform_attn_col2im_bilinear(
|
589 |
+
data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
|
590 |
+
top_grad, weight, grad_value_ptr,
|
591 |
+
cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
|
592 |
+
}
|
593 |
+
|
594 |
+
__syncthreads();
|
595 |
+
if (tid == 0)
|
596 |
+
{
|
597 |
+
scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
|
598 |
+
int sid=2;
|
599 |
+
for (unsigned int tid = 1; tid < blockDim.x; ++tid)
|
600 |
+
{
|
601 |
+
_grad_w += cache_grad_sampling_loc[sid];
|
602 |
+
_grad_h += cache_grad_sampling_loc[sid + 1];
|
603 |
+
_grad_a += cache_grad_attn_weight[tid];
|
604 |
+
sid += 2;
|
605 |
+
}
|
606 |
+
|
607 |
+
|
608 |
+
*grad_sampling_loc = _grad_w;
|
609 |
+
*(grad_sampling_loc + 1) = _grad_h;
|
610 |
+
*grad_attn_weight = _grad_a;
|
611 |
+
}
|
612 |
+
__syncthreads();
|
613 |
+
|
614 |
+
data_weight_ptr += 1;
|
615 |
+
data_loc_w_ptr += 2;
|
616 |
+
grad_attn_weight += grad_weight_stride;
|
617 |
+
grad_sampling_loc += grad_loc_stride;
|
618 |
+
}
|
619 |
+
}
|
620 |
+
}
|
621 |
+
}
|
622 |
+
|
623 |
+
template <typename scalar_t>
|
624 |
+
__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(const int n,
|
625 |
+
const scalar_t *grad_col,
|
626 |
+
const scalar_t *data_value,
|
627 |
+
const int64_t *data_spatial_shapes,
|
628 |
+
const int64_t *data_level_start_index,
|
629 |
+
const scalar_t *data_sampling_loc,
|
630 |
+
const scalar_t *data_attn_weight,
|
631 |
+
const int batch_size,
|
632 |
+
const int spatial_size,
|
633 |
+
const int num_heads,
|
634 |
+
const int channels,
|
635 |
+
const int num_levels,
|
636 |
+
const int num_query,
|
637 |
+
const int num_point,
|
638 |
+
scalar_t *grad_value,
|
639 |
+
scalar_t *grad_sampling_loc,
|
640 |
+
scalar_t *grad_attn_weight)
|
641 |
+
{
|
642 |
+
CUDA_KERNEL_LOOP(index, n)
|
643 |
+
{
|
644 |
+
extern __shared__ int _s[];
|
645 |
+
scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
|
646 |
+
scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
|
647 |
+
unsigned int tid = threadIdx.x;
|
648 |
+
int _temp = index;
|
649 |
+
const int c_col = _temp % channels;
|
650 |
+
_temp /= channels;
|
651 |
+
const int sampling_index = _temp;
|
652 |
+
const int m_col = _temp % num_heads;
|
653 |
+
_temp /= num_heads;
|
654 |
+
const int q_col = _temp % num_query;
|
655 |
+
_temp /= num_query;
|
656 |
+
const int b_col = _temp;
|
657 |
+
|
658 |
+
const scalar_t top_grad = grad_col[index];
|
659 |
+
|
660 |
+
int data_weight_ptr = sampling_index * num_levels * num_point;
|
661 |
+
int data_loc_w_ptr = data_weight_ptr << 1;
|
662 |
+
const int grad_sampling_ptr = data_weight_ptr;
|
663 |
+
grad_sampling_loc += grad_sampling_ptr << 1;
|
664 |
+
grad_attn_weight += grad_sampling_ptr;
|
665 |
+
const int grad_weight_stride = 1;
|
666 |
+
const int grad_loc_stride = 2;
|
667 |
+
const int qid_stride = num_heads * channels;
|
668 |
+
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
|
669 |
+
|
670 |
+
for (int l_col=0; l_col < num_levels; ++l_col)
|
671 |
+
{
|
672 |
+
const int level_start_id = data_level_start_index[l_col];
|
673 |
+
const int spatial_h_ptr = l_col << 1;
|
674 |
+
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
|
675 |
+
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
|
676 |
+
const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
|
677 |
+
const scalar_t *data_value_ptr = data_value + value_ptr_offset;
|
678 |
+
scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
|
679 |
+
|
680 |
+
for (int p_col=0; p_col < num_point; ++p_col)
|
681 |
+
{
|
682 |
+
const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
|
683 |
+
const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
|
684 |
+
const scalar_t weight = data_attn_weight[data_weight_ptr];
|
685 |
+
|
686 |
+
const scalar_t h_im = loc_h * spatial_h - 0.5;
|
687 |
+
const scalar_t w_im = loc_w * spatial_w - 0.5;
|
688 |
+
*(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
|
689 |
+
*(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
|
690 |
+
*(cache_grad_attn_weight+threadIdx.x)=0;
|
691 |
+
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
|
692 |
+
{
|
693 |
+
ms_deform_attn_col2im_bilinear(
|
694 |
+
data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
|
695 |
+
top_grad, weight, grad_value_ptr,
|
696 |
+
cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
|
697 |
+
}
|
698 |
+
|
699 |
+
__syncthreads();
|
700 |
+
|
701 |
+
for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
|
702 |
+
{
|
703 |
+
if (tid < s) {
|
704 |
+
const unsigned int xid1 = tid << 1;
|
705 |
+
const unsigned int xid2 = (tid + s) << 1;
|
706 |
+
cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
|
707 |
+
cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
|
708 |
+
cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
|
709 |
+
if (tid + (s << 1) < spre)
|
710 |
+
{
|
711 |
+
cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
|
712 |
+
cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
|
713 |
+
cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
|
714 |
+
}
|
715 |
+
}
|
716 |
+
__syncthreads();
|
717 |
+
}
|
718 |
+
|
719 |
+
if (tid == 0)
|
720 |
+
{
|
721 |
+
*grad_sampling_loc = cache_grad_sampling_loc[0];
|
722 |
+
*(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
|
723 |
+
*grad_attn_weight = cache_grad_attn_weight[0];
|
724 |
+
}
|
725 |
+
__syncthreads();
|
726 |
+
|
727 |
+
data_weight_ptr += 1;
|
728 |
+
data_loc_w_ptr += 2;
|
729 |
+
grad_attn_weight += grad_weight_stride;
|
730 |
+
grad_sampling_loc += grad_loc_stride;
|
731 |
+
}
|
732 |
+
}
|
733 |
+
}
|
734 |
+
}
|
735 |
+
|
736 |
+
template <typename scalar_t>
|
737 |
+
__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(const int n,
|
738 |
+
const scalar_t *grad_col,
|
739 |
+
const scalar_t *data_value,
|
740 |
+
const int64_t *data_spatial_shapes,
|
741 |
+
const int64_t *data_level_start_index,
|
742 |
+
const scalar_t *data_sampling_loc,
|
743 |
+
const scalar_t *data_attn_weight,
|
744 |
+
const int batch_size,
|
745 |
+
const int spatial_size,
|
746 |
+
const int num_heads,
|
747 |
+
const int channels,
|
748 |
+
const int num_levels,
|
749 |
+
const int num_query,
|
750 |
+
const int num_point,
|
751 |
+
scalar_t *grad_value,
|
752 |
+
scalar_t *grad_sampling_loc,
|
753 |
+
scalar_t *grad_attn_weight)
|
754 |
+
{
|
755 |
+
CUDA_KERNEL_LOOP(index, n)
|
756 |
+
{
|
757 |
+
extern __shared__ int _s[];
|
758 |
+
scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
|
759 |
+
scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
|
760 |
+
unsigned int tid = threadIdx.x;
|
761 |
+
int _temp = index;
|
762 |
+
const int c_col = _temp % channels;
|
763 |
+
_temp /= channels;
|
764 |
+
const int sampling_index = _temp;
|
765 |
+
const int m_col = _temp % num_heads;
|
766 |
+
_temp /= num_heads;
|
767 |
+
const int q_col = _temp % num_query;
|
768 |
+
_temp /= num_query;
|
769 |
+
const int b_col = _temp;
|
770 |
+
|
771 |
+
const scalar_t top_grad = grad_col[index];
|
772 |
+
|
773 |
+
int data_weight_ptr = sampling_index * num_levels * num_point;
|
774 |
+
int data_loc_w_ptr = data_weight_ptr << 1;
|
775 |
+
const int grad_sampling_ptr = data_weight_ptr;
|
776 |
+
grad_sampling_loc += grad_sampling_ptr << 1;
|
777 |
+
grad_attn_weight += grad_sampling_ptr;
|
778 |
+
const int grad_weight_stride = 1;
|
779 |
+
const int grad_loc_stride = 2;
|
780 |
+
const int qid_stride = num_heads * channels;
|
781 |
+
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
|
782 |
+
|
783 |
+
for (int l_col=0; l_col < num_levels; ++l_col)
|
784 |
+
{
|
785 |
+
const int level_start_id = data_level_start_index[l_col];
|
786 |
+
const int spatial_h_ptr = l_col << 1;
|
787 |
+
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
|
788 |
+
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
|
789 |
+
const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
|
790 |
+
const scalar_t *data_value_ptr = data_value + value_ptr_offset;
|
791 |
+
scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
|
792 |
+
|
793 |
+
for (int p_col=0; p_col < num_point; ++p_col)
|
794 |
+
{
|
795 |
+
const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
|
796 |
+
const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
|
797 |
+
const scalar_t weight = data_attn_weight[data_weight_ptr];
|
798 |
+
|
799 |
+
const scalar_t h_im = loc_h * spatial_h - 0.5;
|
800 |
+
const scalar_t w_im = loc_w * spatial_w - 0.5;
|
801 |
+
*(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
|
802 |
+
*(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
|
803 |
+
*(cache_grad_attn_weight+threadIdx.x)=0;
|
804 |
+
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
|
805 |
+
{
|
806 |
+
ms_deform_attn_col2im_bilinear(
|
807 |
+
data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
|
808 |
+
top_grad, weight, grad_value_ptr,
|
809 |
+
cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
|
810 |
+
}
|
811 |
+
|
812 |
+
__syncthreads();
|
813 |
+
|
814 |
+
for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
|
815 |
+
{
|
816 |
+
if (tid < s) {
|
817 |
+
const unsigned int xid1 = tid << 1;
|
818 |
+
const unsigned int xid2 = (tid + s) << 1;
|
819 |
+
cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
|
820 |
+
cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
|
821 |
+
cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
|
822 |
+
if (tid + (s << 1) < spre)
|
823 |
+
{
|
824 |
+
cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
|
825 |
+
cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
|
826 |
+
cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
|
827 |
+
}
|
828 |
+
}
|
829 |
+
__syncthreads();
|
830 |
+
}
|
831 |
+
|
832 |
+
if (tid == 0)
|
833 |
+
{
|
834 |
+
atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]);
|
835 |
+
atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]);
|
836 |
+
atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]);
|
837 |
+
}
|
838 |
+
__syncthreads();
|
839 |
+
|
840 |
+
data_weight_ptr += 1;
|
841 |
+
data_loc_w_ptr += 2;
|
842 |
+
grad_attn_weight += grad_weight_stride;
|
843 |
+
grad_sampling_loc += grad_loc_stride;
|
844 |
+
}
|
845 |
+
}
|
846 |
+
}
|
847 |
+
}
|
848 |
+
|
849 |
+
|
850 |
+
template <typename scalar_t>
|
851 |
+
__global__ void ms_deformable_col2im_gpu_kernel_gm(const int n,
|
852 |
+
const scalar_t *grad_col,
|
853 |
+
const scalar_t *data_value,
|
854 |
+
const int64_t *data_spatial_shapes,
|
855 |
+
const int64_t *data_level_start_index,
|
856 |
+
const scalar_t *data_sampling_loc,
|
857 |
+
const scalar_t *data_attn_weight,
|
858 |
+
const int batch_size,
|
859 |
+
const int spatial_size,
|
860 |
+
const int num_heads,
|
861 |
+
const int channels,
|
862 |
+
const int num_levels,
|
863 |
+
const int num_query,
|
864 |
+
const int num_point,
|
865 |
+
scalar_t *grad_value,
|
866 |
+
scalar_t *grad_sampling_loc,
|
867 |
+
scalar_t *grad_attn_weight)
|
868 |
+
{
|
869 |
+
CUDA_KERNEL_LOOP(index, n)
|
870 |
+
{
|
871 |
+
int _temp = index;
|
872 |
+
const int c_col = _temp % channels;
|
873 |
+
_temp /= channels;
|
874 |
+
const int sampling_index = _temp;
|
875 |
+
const int m_col = _temp % num_heads;
|
876 |
+
_temp /= num_heads;
|
877 |
+
const int q_col = _temp % num_query;
|
878 |
+
_temp /= num_query;
|
879 |
+
const int b_col = _temp;
|
880 |
+
|
881 |
+
const scalar_t top_grad = grad_col[index];
|
882 |
+
|
883 |
+
int data_weight_ptr = sampling_index * num_levels * num_point;
|
884 |
+
int data_loc_w_ptr = data_weight_ptr << 1;
|
885 |
+
const int grad_sampling_ptr = data_weight_ptr;
|
886 |
+
grad_sampling_loc += grad_sampling_ptr << 1;
|
887 |
+
grad_attn_weight += grad_sampling_ptr;
|
888 |
+
const int grad_weight_stride = 1;
|
889 |
+
const int grad_loc_stride = 2;
|
890 |
+
const int qid_stride = num_heads * channels;
|
891 |
+
const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
|
892 |
+
|
893 |
+
for (int l_col=0; l_col < num_levels; ++l_col)
|
894 |
+
{
|
895 |
+
const int level_start_id = data_level_start_index[l_col];
|
896 |
+
const int spatial_h_ptr = l_col << 1;
|
897 |
+
const int spatial_h = data_spatial_shapes[spatial_h_ptr];
|
898 |
+
const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
|
899 |
+
const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
|
900 |
+
const scalar_t *data_value_ptr = data_value + value_ptr_offset;
|
901 |
+
scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
|
902 |
+
|
903 |
+
for (int p_col=0; p_col < num_point; ++p_col)
|
904 |
+
{
|
905 |
+
const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
|
906 |
+
const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
|
907 |
+
const scalar_t weight = data_attn_weight[data_weight_ptr];
|
908 |
+
|
909 |
+
const scalar_t h_im = loc_h * spatial_h - 0.5;
|
910 |
+
const scalar_t w_im = loc_w * spatial_w - 0.5;
|
911 |
+
if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
|
912 |
+
{
|
913 |
+
ms_deform_attn_col2im_bilinear_gm(
|
914 |
+
data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
|
915 |
+
top_grad, weight, grad_value_ptr,
|
916 |
+
grad_sampling_loc, grad_attn_weight);
|
917 |
+
}
|
918 |
+
data_weight_ptr += 1;
|
919 |
+
data_loc_w_ptr += 2;
|
920 |
+
grad_attn_weight += grad_weight_stride;
|
921 |
+
grad_sampling_loc += grad_loc_stride;
|
922 |
+
}
|
923 |
+
}
|
924 |
+
}
|
925 |
+
}
|
926 |
+
|
927 |
+
|
928 |
+
template <typename scalar_t>
|
929 |
+
void ms_deformable_im2col_cuda(cudaStream_t stream,
|
930 |
+
const scalar_t* data_value,
|
931 |
+
const int64_t* data_spatial_shapes,
|
932 |
+
const int64_t* data_level_start_index,
|
933 |
+
const scalar_t* data_sampling_loc,
|
934 |
+
const scalar_t* data_attn_weight,
|
935 |
+
const int batch_size,
|
936 |
+
const int spatial_size,
|
937 |
+
const int num_heads,
|
938 |
+
const int channels,
|
939 |
+
const int num_levels,
|
940 |
+
const int num_query,
|
941 |
+
const int num_point,
|
942 |
+
scalar_t* data_col)
|
943 |
+
{
|
944 |
+
const int num_kernels = batch_size * num_query * num_heads * channels;
|
945 |
+
const int num_actual_kernels = batch_size * num_query * num_heads * channels;
|
946 |
+
const int num_threads = CUDA_NUM_THREADS;
|
947 |
+
ms_deformable_im2col_gpu_kernel<scalar_t>
|
948 |
+
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
|
949 |
+
0, stream>>>(
|
950 |
+
num_kernels, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight,
|
951 |
+
batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, data_col);
|
952 |
+
|
953 |
+
cudaError_t err = cudaGetLastError();
|
954 |
+
if (err != cudaSuccess)
|
955 |
+
{
|
956 |
+
printf("error in ms_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
|
957 |
+
}
|
958 |
+
|
959 |
+
}
|
960 |
+
|
961 |
+
template <typename scalar_t>
|
962 |
+
void ms_deformable_col2im_cuda(cudaStream_t stream,
|
963 |
+
const scalar_t* grad_col,
|
964 |
+
const scalar_t* data_value,
|
965 |
+
const int64_t * data_spatial_shapes,
|
966 |
+
const int64_t * data_level_start_index,
|
967 |
+
const scalar_t * data_sampling_loc,
|
968 |
+
const scalar_t * data_attn_weight,
|
969 |
+
const int batch_size,
|
970 |
+
const int spatial_size,
|
971 |
+
const int num_heads,
|
972 |
+
const int channels,
|
973 |
+
const int num_levels,
|
974 |
+
const int num_query,
|
975 |
+
const int num_point,
|
976 |
+
scalar_t* grad_value,
|
977 |
+
scalar_t* grad_sampling_loc,
|
978 |
+
scalar_t* grad_attn_weight)
|
979 |
+
{
|
980 |
+
const int num_threads = (channels > CUDA_NUM_THREADS)?CUDA_NUM_THREADS:channels;
|
981 |
+
const int num_kernels = batch_size * num_query * num_heads * channels;
|
982 |
+
const int num_actual_kernels = batch_size * num_query * num_heads * channels;
|
983 |
+
if (channels > 1024)
|
984 |
+
{
|
985 |
+
if ((channels & 1023) == 0)
|
986 |
+
{
|
987 |
+
ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks<scalar_t>
|
988 |
+
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
|
989 |
+
num_threads*3*sizeof(scalar_t), stream>>>(
|
990 |
+
num_kernels,
|
991 |
+
grad_col,
|
992 |
+
data_value,
|
993 |
+
data_spatial_shapes,
|
994 |
+
data_level_start_index,
|
995 |
+
data_sampling_loc,
|
996 |
+
data_attn_weight,
|
997 |
+
batch_size,
|
998 |
+
spatial_size,
|
999 |
+
num_heads,
|
1000 |
+
channels,
|
1001 |
+
num_levels,
|
1002 |
+
num_query,
|
1003 |
+
num_point,
|
1004 |
+
grad_value,
|
1005 |
+
grad_sampling_loc,
|
1006 |
+
grad_attn_weight);
|
1007 |
+
}
|
1008 |
+
else
|
1009 |
+
{
|
1010 |
+
ms_deformable_col2im_gpu_kernel_gm<scalar_t>
|
1011 |
+
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
|
1012 |
+
0, stream>>>(
|
1013 |
+
num_kernels,
|
1014 |
+
grad_col,
|
1015 |
+
data_value,
|
1016 |
+
data_spatial_shapes,
|
1017 |
+
data_level_start_index,
|
1018 |
+
data_sampling_loc,
|
1019 |
+
data_attn_weight,
|
1020 |
+
batch_size,
|
1021 |
+
spatial_size,
|
1022 |
+
num_heads,
|
1023 |
+
channels,
|
1024 |
+
num_levels,
|
1025 |
+
num_query,
|
1026 |
+
num_point,
|
1027 |
+
grad_value,
|
1028 |
+
grad_sampling_loc,
|
1029 |
+
grad_attn_weight);
|
1030 |
+
}
|
1031 |
+
}
|
1032 |
+
else{
|
1033 |
+
switch(channels)
|
1034 |
+
{
|
1035 |
+
case 1:
|
1036 |
+
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 1>
|
1037 |
+
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
|
1038 |
+
0, stream>>>(
|
1039 |
+
num_kernels,
|
1040 |
+
grad_col,
|
1041 |
+
data_value,
|
1042 |
+
data_spatial_shapes,
|
1043 |
+
data_level_start_index,
|
1044 |
+
data_sampling_loc,
|
1045 |
+
data_attn_weight,
|
1046 |
+
batch_size,
|
1047 |
+
spatial_size,
|
1048 |
+
num_heads,
|
1049 |
+
channels,
|
1050 |
+
num_levels,
|
1051 |
+
num_query,
|
1052 |
+
num_point,
|
1053 |
+
grad_value,
|
1054 |
+
grad_sampling_loc,
|
1055 |
+
grad_attn_weight);
|
1056 |
+
break;
|
1057 |
+
case 2:
|
1058 |
+
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 2>
|
1059 |
+
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
|
1060 |
+
0, stream>>>(
|
1061 |
+
num_kernels,
|
1062 |
+
grad_col,
|
1063 |
+
data_value,
|
1064 |
+
data_spatial_shapes,
|
1065 |
+
data_level_start_index,
|
1066 |
+
data_sampling_loc,
|
1067 |
+
data_attn_weight,
|
1068 |
+
batch_size,
|
1069 |
+
spatial_size,
|
1070 |
+
num_heads,
|
1071 |
+
channels,
|
1072 |
+
num_levels,
|
1073 |
+
num_query,
|
1074 |
+
num_point,
|
1075 |
+
grad_value,
|
1076 |
+
grad_sampling_loc,
|
1077 |
+
grad_attn_weight);
|
1078 |
+
break;
|
1079 |
+
case 4:
|
1080 |
+
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 4>
|
1081 |
+
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
|
1082 |
+
0, stream>>>(
|
1083 |
+
num_kernels,
|
1084 |
+
grad_col,
|
1085 |
+
data_value,
|
1086 |
+
data_spatial_shapes,
|
1087 |
+
data_level_start_index,
|
1088 |
+
data_sampling_loc,
|
1089 |
+
data_attn_weight,
|
1090 |
+
batch_size,
|
1091 |
+
spatial_size,
|
1092 |
+
num_heads,
|
1093 |
+
channels,
|
1094 |
+
num_levels,
|
1095 |
+
num_query,
|
1096 |
+
num_point,
|
1097 |
+
grad_value,
|
1098 |
+
grad_sampling_loc,
|
1099 |
+
grad_attn_weight);
|
1100 |
+
break;
|
1101 |
+
case 8:
|
1102 |
+
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 8>
|
1103 |
+
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
|
1104 |
+
0, stream>>>(
|
1105 |
+
num_kernels,
|
1106 |
+
grad_col,
|
1107 |
+
data_value,
|
1108 |
+
data_spatial_shapes,
|
1109 |
+
data_level_start_index,
|
1110 |
+
data_sampling_loc,
|
1111 |
+
data_attn_weight,
|
1112 |
+
batch_size,
|
1113 |
+
spatial_size,
|
1114 |
+
num_heads,
|
1115 |
+
channels,
|
1116 |
+
num_levels,
|
1117 |
+
num_query,
|
1118 |
+
num_point,
|
1119 |
+
grad_value,
|
1120 |
+
grad_sampling_loc,
|
1121 |
+
grad_attn_weight);
|
1122 |
+
break;
|
1123 |
+
case 16:
|
1124 |
+
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 16>
|
1125 |
+
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
|
1126 |
+
0, stream>>>(
|
1127 |
+
num_kernels,
|
1128 |
+
grad_col,
|
1129 |
+
data_value,
|
1130 |
+
data_spatial_shapes,
|
1131 |
+
data_level_start_index,
|
1132 |
+
data_sampling_loc,
|
1133 |
+
data_attn_weight,
|
1134 |
+
batch_size,
|
1135 |
+
spatial_size,
|
1136 |
+
num_heads,
|
1137 |
+
channels,
|
1138 |
+
num_levels,
|
1139 |
+
num_query,
|
1140 |
+
num_point,
|
1141 |
+
grad_value,
|
1142 |
+
grad_sampling_loc,
|
1143 |
+
grad_attn_weight);
|
1144 |
+
break;
|
1145 |
+
case 32:
|
1146 |
+
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 32>
|
1147 |
+
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
|
1148 |
+
0, stream>>>(
|
1149 |
+
num_kernels,
|
1150 |
+
grad_col,
|
1151 |
+
data_value,
|
1152 |
+
data_spatial_shapes,
|
1153 |
+
data_level_start_index,
|
1154 |
+
data_sampling_loc,
|
1155 |
+
data_attn_weight,
|
1156 |
+
batch_size,
|
1157 |
+
spatial_size,
|
1158 |
+
num_heads,
|
1159 |
+
channels,
|
1160 |
+
num_levels,
|
1161 |
+
num_query,
|
1162 |
+
num_point,
|
1163 |
+
grad_value,
|
1164 |
+
grad_sampling_loc,
|
1165 |
+
grad_attn_weight);
|
1166 |
+
break;
|
1167 |
+
case 64:
|
1168 |
+
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 64>
|
1169 |
+
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
|
1170 |
+
0, stream>>>(
|
1171 |
+
num_kernels,
|
1172 |
+
grad_col,
|
1173 |
+
data_value,
|
1174 |
+
data_spatial_shapes,
|
1175 |
+
data_level_start_index,
|
1176 |
+
data_sampling_loc,
|
1177 |
+
data_attn_weight,
|
1178 |
+
batch_size,
|
1179 |
+
spatial_size,
|
1180 |
+
num_heads,
|
1181 |
+
channels,
|
1182 |
+
num_levels,
|
1183 |
+
num_query,
|
1184 |
+
num_point,
|
1185 |
+
grad_value,
|
1186 |
+
grad_sampling_loc,
|
1187 |
+
grad_attn_weight);
|
1188 |
+
break;
|
1189 |
+
case 128:
|
1190 |
+
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 128>
|
1191 |
+
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
|
1192 |
+
0, stream>>>(
|
1193 |
+
num_kernels,
|
1194 |
+
grad_col,
|
1195 |
+
data_value,
|
1196 |
+
data_spatial_shapes,
|
1197 |
+
data_level_start_index,
|
1198 |
+
data_sampling_loc,
|
1199 |
+
data_attn_weight,
|
1200 |
+
batch_size,
|
1201 |
+
spatial_size,
|
1202 |
+
num_heads,
|
1203 |
+
channels,
|
1204 |
+
num_levels,
|
1205 |
+
num_query,
|
1206 |
+
num_point,
|
1207 |
+
grad_value,
|
1208 |
+
grad_sampling_loc,
|
1209 |
+
grad_attn_weight);
|
1210 |
+
break;
|
1211 |
+
case 256:
|
1212 |
+
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 256>
|
1213 |
+
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
|
1214 |
+
0, stream>>>(
|
1215 |
+
num_kernels,
|
1216 |
+
grad_col,
|
1217 |
+
data_value,
|
1218 |
+
data_spatial_shapes,
|
1219 |
+
data_level_start_index,
|
1220 |
+
data_sampling_loc,
|
1221 |
+
data_attn_weight,
|
1222 |
+
batch_size,
|
1223 |
+
spatial_size,
|
1224 |
+
num_heads,
|
1225 |
+
channels,
|
1226 |
+
num_levels,
|
1227 |
+
num_query,
|
1228 |
+
num_point,
|
1229 |
+
grad_value,
|
1230 |
+
grad_sampling_loc,
|
1231 |
+
grad_attn_weight);
|
1232 |
+
break;
|
1233 |
+
case 512:
|
1234 |
+
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 512>
|
1235 |
+
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
|
1236 |
+
0, stream>>>(
|
1237 |
+
num_kernels,
|
1238 |
+
grad_col,
|
1239 |
+
data_value,
|
1240 |
+
data_spatial_shapes,
|
1241 |
+
data_level_start_index,
|
1242 |
+
data_sampling_loc,
|
1243 |
+
data_attn_weight,
|
1244 |
+
batch_size,
|
1245 |
+
spatial_size,
|
1246 |
+
num_heads,
|
1247 |
+
channels,
|
1248 |
+
num_levels,
|
1249 |
+
num_query,
|
1250 |
+
num_point,
|
1251 |
+
grad_value,
|
1252 |
+
grad_sampling_loc,
|
1253 |
+
grad_attn_weight);
|
1254 |
+
break;
|
1255 |
+
case 1024:
|
1256 |
+
ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 1024>
|
1257 |
+
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
|
1258 |
+
0, stream>>>(
|
1259 |
+
num_kernels,
|
1260 |
+
grad_col,
|
1261 |
+
data_value,
|
1262 |
+
data_spatial_shapes,
|
1263 |
+
data_level_start_index,
|
1264 |
+
data_sampling_loc,
|
1265 |
+
data_attn_weight,
|
1266 |
+
batch_size,
|
1267 |
+
spatial_size,
|
1268 |
+
num_heads,
|
1269 |
+
channels,
|
1270 |
+
num_levels,
|
1271 |
+
num_query,
|
1272 |
+
num_point,
|
1273 |
+
grad_value,
|
1274 |
+
grad_sampling_loc,
|
1275 |
+
grad_attn_weight);
|
1276 |
+
break;
|
1277 |
+
default:
|
1278 |
+
if (channels < 64)
|
1279 |
+
{
|
1280 |
+
ms_deformable_col2im_gpu_kernel_shm_reduce_v1<scalar_t>
|
1281 |
+
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
|
1282 |
+
num_threads*3*sizeof(scalar_t), stream>>>(
|
1283 |
+
num_kernels,
|
1284 |
+
grad_col,
|
1285 |
+
data_value,
|
1286 |
+
data_spatial_shapes,
|
1287 |
+
data_level_start_index,
|
1288 |
+
data_sampling_loc,
|
1289 |
+
data_attn_weight,
|
1290 |
+
batch_size,
|
1291 |
+
spatial_size,
|
1292 |
+
num_heads,
|
1293 |
+
channels,
|
1294 |
+
num_levels,
|
1295 |
+
num_query,
|
1296 |
+
num_point,
|
1297 |
+
grad_value,
|
1298 |
+
grad_sampling_loc,
|
1299 |
+
grad_attn_weight);
|
1300 |
+
}
|
1301 |
+
else
|
1302 |
+
{
|
1303 |
+
ms_deformable_col2im_gpu_kernel_shm_reduce_v2<scalar_t>
|
1304 |
+
<<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
|
1305 |
+
num_threads*3*sizeof(scalar_t), stream>>>(
|
1306 |
+
num_kernels,
|
1307 |
+
grad_col,
|
1308 |
+
data_value,
|
1309 |
+
data_spatial_shapes,
|
1310 |
+
data_level_start_index,
|
1311 |
+
data_sampling_loc,
|
1312 |
+
data_attn_weight,
|
1313 |
+
batch_size,
|
1314 |
+
spatial_size,
|
1315 |
+
num_heads,
|
1316 |
+
channels,
|
1317 |
+
num_levels,
|
1318 |
+
num_query,
|
1319 |
+
num_point,
|
1320 |
+
grad_value,
|
1321 |
+
grad_sampling_loc,
|
1322 |
+
grad_attn_weight);
|
1323 |
+
}
|
1324 |
+
}
|
1325 |
+
}
|
1326 |
+
cudaError_t err = cudaGetLastError();
|
1327 |
+
if (err != cudaSuccess)
|
1328 |
+
{
|
1329 |
+
printf("error in ms_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
|
1330 |
+
}
|
1331 |
+
|
1332 |
+
}
|
imcui/third_party/rdd/RDD/models/ops/src/ms_deform_attn.h
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/*!
|
2 |
+
**************************************************************************************************
|
3 |
+
* Deformable DETR
|
4 |
+
* Copyright (c) 2020 SenseTime. All Rights Reserved.
|
5 |
+
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
6 |
+
**************************************************************************************************
|
7 |
+
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
|
8 |
+
**************************************************************************************************
|
9 |
+
*/
|
10 |
+
|
11 |
+
/*!
|
12 |
+
* Copyright (c) Facebook, Inc. and its affiliates.
|
13 |
+
* Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
|
14 |
+
*/
|
15 |
+
|
16 |
+
#pragma once
|
17 |
+
|
18 |
+
#include "cpu/ms_deform_attn_cpu.h"
|
19 |
+
|
20 |
+
#ifdef WITH_CUDA
|
21 |
+
#include "cuda/ms_deform_attn_cuda.h"
|
22 |
+
#endif
|
23 |
+
|
24 |
+
|
25 |
+
at::Tensor
|
26 |
+
ms_deform_attn_forward(
|
27 |
+
const at::Tensor &value,
|
28 |
+
const at::Tensor &spatial_shapes,
|
29 |
+
const at::Tensor &level_start_index,
|
30 |
+
const at::Tensor &sampling_loc,
|
31 |
+
const at::Tensor &attn_weight,
|
32 |
+
const int im2col_step)
|
33 |
+
{
|
34 |
+
if (value.type().is_cuda())
|
35 |
+
{
|
36 |
+
#ifdef WITH_CUDA
|
37 |
+
return ms_deform_attn_cuda_forward(
|
38 |
+
value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step);
|
39 |
+
#else
|
40 |
+
AT_ERROR("Not compiled with GPU support");
|
41 |
+
#endif
|
42 |
+
}
|
43 |
+
AT_ERROR("Not implemented on the CPU");
|
44 |
+
}
|
45 |
+
|
46 |
+
std::vector<at::Tensor>
|
47 |
+
ms_deform_attn_backward(
|
48 |
+
const at::Tensor &value,
|
49 |
+
const at::Tensor &spatial_shapes,
|
50 |
+
const at::Tensor &level_start_index,
|
51 |
+
const at::Tensor &sampling_loc,
|
52 |
+
const at::Tensor &attn_weight,
|
53 |
+
const at::Tensor &grad_output,
|
54 |
+
const int im2col_step)
|
55 |
+
{
|
56 |
+
if (value.type().is_cuda())
|
57 |
+
{
|
58 |
+
#ifdef WITH_CUDA
|
59 |
+
return ms_deform_attn_cuda_backward(
|
60 |
+
value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step);
|
61 |
+
#else
|
62 |
+
AT_ERROR("Not compiled with GPU support");
|
63 |
+
#endif
|
64 |
+
}
|
65 |
+
AT_ERROR("Not implemented on the CPU");
|
66 |
+
}
|
67 |
+
|
imcui/third_party/rdd/RDD/models/ops/src/vision.cpp
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/*!
|
2 |
+
**************************************************************************************************
|
3 |
+
* Deformable DETR
|
4 |
+
* Copyright (c) 2020 SenseTime. All Rights Reserved.
|
5 |
+
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
6 |
+
**************************************************************************************************
|
7 |
+
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
|
8 |
+
**************************************************************************************************
|
9 |
+
*/
|
10 |
+
|
11 |
+
/*!
|
12 |
+
* Copyright (c) Facebook, Inc. and its affiliates.
|
13 |
+
* Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
|
14 |
+
*/
|
15 |
+
|
16 |
+
#include "ms_deform_attn.h"
|
17 |
+
|
18 |
+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
19 |
+
m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward");
|
20 |
+
m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward");
|
21 |
+
}
|
imcui/third_party/rdd/RDD/models/ops/test.py
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ------------------------------------------------------------------------------------------------
|
2 |
+
# Deformable DETR
|
3 |
+
# Copyright (c) 2020 SenseTime. All Rights Reserved.
|
4 |
+
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
5 |
+
# ------------------------------------------------------------------------------------------------
|
6 |
+
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
|
7 |
+
# ------------------------------------------------------------------------------------------------
|
8 |
+
|
9 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
10 |
+
# Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
|
11 |
+
|
12 |
+
from __future__ import absolute_import
|
13 |
+
from __future__ import print_function
|
14 |
+
from __future__ import division
|
15 |
+
|
16 |
+
import time
|
17 |
+
import torch
|
18 |
+
import torch.nn as nn
|
19 |
+
from torch.autograd import gradcheck
|
20 |
+
|
21 |
+
from functions.ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch
|
22 |
+
|
23 |
+
|
24 |
+
N, M, D = 1, 2, 2
|
25 |
+
Lq, L, P = 2, 2, 2
|
26 |
+
shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda()
|
27 |
+
level_start_index = torch.cat((shapes.new_zeros((1, )), shapes.prod(1).cumsum(0)[:-1]))
|
28 |
+
S = sum([(H*W).item() for H, W in shapes])
|
29 |
+
|
30 |
+
|
31 |
+
torch.manual_seed(3)
|
32 |
+
|
33 |
+
|
34 |
+
@torch.no_grad()
|
35 |
+
def check_forward_equal_with_pytorch_double():
|
36 |
+
value = torch.rand(N, S, M, D).cuda() * 0.01
|
37 |
+
sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
|
38 |
+
attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
|
39 |
+
attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
|
40 |
+
im2col_step = 2
|
41 |
+
output_pytorch = ms_deform_attn_core_pytorch(value.double(), shapes, sampling_locations.double(), attention_weights.double()).detach().cpu()
|
42 |
+
output_cuda = MSDeformAttnFunction.apply(value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step).detach().cpu()
|
43 |
+
fwdok = torch.allclose(output_cuda, output_pytorch)
|
44 |
+
max_abs_err = (output_cuda - output_pytorch).abs().max()
|
45 |
+
max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max()
|
46 |
+
|
47 |
+
print(f'* {fwdok} check_forward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
|
48 |
+
|
49 |
+
|
50 |
+
@torch.no_grad()
|
51 |
+
def check_forward_equal_with_pytorch_float():
|
52 |
+
value = torch.rand(N, S, M, D).cuda() * 0.01
|
53 |
+
sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
|
54 |
+
attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
|
55 |
+
attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
|
56 |
+
im2col_step = 2
|
57 |
+
output_pytorch = ms_deform_attn_core_pytorch(value, shapes, sampling_locations, attention_weights).detach().cpu()
|
58 |
+
output_cuda = MSDeformAttnFunction.apply(value, shapes, level_start_index, sampling_locations, attention_weights, im2col_step).detach().cpu()
|
59 |
+
fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3)
|
60 |
+
max_abs_err = (output_cuda - output_pytorch).abs().max()
|
61 |
+
max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max()
|
62 |
+
|
63 |
+
print(f'* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
|
64 |
+
|
65 |
+
|
66 |
+
def check_gradient_numerical(channels=4, grad_value=True, grad_sampling_loc=True, grad_attn_weight=True):
|
67 |
+
|
68 |
+
value = torch.rand(N, S, M, channels).cuda() * 0.01
|
69 |
+
sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
|
70 |
+
attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
|
71 |
+
attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
|
72 |
+
im2col_step = 2
|
73 |
+
func = MSDeformAttnFunction.apply
|
74 |
+
|
75 |
+
value.requires_grad = grad_value
|
76 |
+
sampling_locations.requires_grad = grad_sampling_loc
|
77 |
+
attention_weights.requires_grad = grad_attn_weight
|
78 |
+
|
79 |
+
gradok = gradcheck(func, (value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step))
|
80 |
+
|
81 |
+
print(f'* {gradok} check_gradient_numerical(D={channels})')
|
82 |
+
|
83 |
+
|
84 |
+
if __name__ == '__main__':
|
85 |
+
check_forward_equal_with_pytorch_double()
|
86 |
+
check_forward_equal_with_pytorch_float()
|
87 |
+
|
88 |
+
for channels in [30, 32, 64, 71, 1025, 2048, 3096]:
|
89 |
+
check_gradient_numerical(channels, True, True, True)
|
90 |
+
|
91 |
+
|
92 |
+
|
imcui/third_party/rdd/RDD/models/position_encoding.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
from torch import nn
|
4 |
+
from ..utils.misc import NestedTensor
|
5 |
+
|
6 |
+
class PositionEmbeddingSine(nn.Module):
|
7 |
+
"""
|
8 |
+
This is a more standard version of the position embedding, very similar to the one
|
9 |
+
used by the Attention is all you need paper, generalized to work on images.
|
10 |
+
"""
|
11 |
+
def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
|
12 |
+
super().__init__()
|
13 |
+
self.num_pos_feats = num_pos_feats
|
14 |
+
self.temperature = temperature
|
15 |
+
self.normalize = normalize
|
16 |
+
if scale is not None and normalize is False:
|
17 |
+
raise ValueError("normalize should be True if scale is passed")
|
18 |
+
if scale is None:
|
19 |
+
scale = 2 * math.pi
|
20 |
+
self.scale = scale
|
21 |
+
|
22 |
+
def forward(self, tensor_list: NestedTensor):
|
23 |
+
x = tensor_list.tensors
|
24 |
+
mask = tensor_list.mask
|
25 |
+
assert mask is not None
|
26 |
+
not_mask = ~mask
|
27 |
+
y_embed = not_mask.cumsum(1, dtype=torch.float32)
|
28 |
+
x_embed = not_mask.cumsum(2, dtype=torch.float32)
|
29 |
+
if self.normalize:
|
30 |
+
eps = 1e-6
|
31 |
+
y_embed = (y_embed - 0.5) / (y_embed[:, -1:, :] + eps) * self.scale
|
32 |
+
x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * self.scale
|
33 |
+
|
34 |
+
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
|
35 |
+
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
|
36 |
+
|
37 |
+
pos_x = x_embed[:, :, :, None] / dim_t
|
38 |
+
pos_y = y_embed[:, :, :, None] / dim_t
|
39 |
+
pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
|
40 |
+
pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
|
41 |
+
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
|
42 |
+
return pos
|
43 |
+
|
44 |
+
def build_position_encoding(config):
|
45 |
+
N_steps = config['hidden_dim'] // 2
|
46 |
+
# TODO find a better way of exposing other arguments
|
47 |
+
position_embedding = PositionEmbeddingSine(N_steps, normalize=True)
|
48 |
+
return position_embedding
|
imcui/third_party/rdd/RDD/models/soft_detect.py
ADDED
@@ -0,0 +1,250 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ALIKE: https://github.com/Shiaoming/ALIKE
|
2 |
+
import torch
|
3 |
+
from torch import nn
|
4 |
+
import numpy as np
|
5 |
+
import torch.nn.functional as F
|
6 |
+
|
7 |
+
|
8 |
+
# coordinates system
|
9 |
+
# ------------------------------> [ x: range=-1.0~1.0; w: range=0~W ]
|
10 |
+
# | -----------------------------
|
11 |
+
# | | |
|
12 |
+
# | | |
|
13 |
+
# | | |
|
14 |
+
# | | image |
|
15 |
+
# | | |
|
16 |
+
# | | |
|
17 |
+
# | | |
|
18 |
+
# | |---------------------------|
|
19 |
+
# v
|
20 |
+
# [ y: range=-1.0~1.0; h: range=0~H ]
|
21 |
+
|
22 |
+
def simple_nms(scores, nms_radius: int):
|
23 |
+
""" Fast Non-maximum suppression to remove nearby points """
|
24 |
+
assert (nms_radius >= 0)
|
25 |
+
|
26 |
+
def max_pool(x):
|
27 |
+
return torch.nn.functional.max_pool2d(
|
28 |
+
x, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius)
|
29 |
+
|
30 |
+
zeros = torch.zeros_like(scores)
|
31 |
+
max_mask = scores == max_pool(scores)
|
32 |
+
|
33 |
+
for _ in range(2):
|
34 |
+
supp_mask = max_pool(max_mask.float()) > 0
|
35 |
+
supp_scores = torch.where(supp_mask, zeros, scores)
|
36 |
+
new_max_mask = supp_scores == max_pool(supp_scores)
|
37 |
+
max_mask = max_mask | (new_max_mask & (~supp_mask))
|
38 |
+
return torch.where(max_mask, scores, zeros)
|
39 |
+
|
40 |
+
|
41 |
+
"""
|
42 |
+
"XFeat: Accelerated Features for Lightweight Image Matching, CVPR 2024."
|
43 |
+
https://www.verlab.dcc.ufmg.br/descriptors/xfeat_cvpr24/
|
44 |
+
"""
|
45 |
+
|
46 |
+
import torch
|
47 |
+
import torch.nn as nn
|
48 |
+
import torch.nn.functional as F
|
49 |
+
|
50 |
+
class InterpolateSparse2d(nn.Module):
|
51 |
+
""" Efficiently interpolate tensor at given sparse 2D positions. """
|
52 |
+
def __init__(self, mode = 'bicubic', align_corners = False):
|
53 |
+
super().__init__()
|
54 |
+
self.mode = mode
|
55 |
+
self.align_corners = align_corners
|
56 |
+
|
57 |
+
def normgrid(self, x, H, W):
|
58 |
+
""" Normalize coords to [-1,1]. """
|
59 |
+
return 2. * (x/(torch.tensor([W-1, H-1], device = x.device, dtype = x.dtype))) - 1.
|
60 |
+
|
61 |
+
def forward(self, x, pos, H, W):
|
62 |
+
"""
|
63 |
+
Input
|
64 |
+
x: [B, C, H, W] feature tensor
|
65 |
+
pos: [B, N, 2] tensor of positions
|
66 |
+
H, W: int, original resolution of input 2d positions -- used in normalization [-1,1]
|
67 |
+
|
68 |
+
Returns
|
69 |
+
[B, N, C] sampled channels at 2d positions
|
70 |
+
"""
|
71 |
+
grid = self.normgrid(pos, H, W).unsqueeze(-2).to(x.dtype)
|
72 |
+
x = F.grid_sample(x, grid, mode = self.mode , align_corners = False)
|
73 |
+
return x.permute(0,2,3,1).squeeze(-2)
|
74 |
+
|
75 |
+
|
76 |
+
class SoftDetect(nn.Module):
|
77 |
+
def __init__(self, radius=2, top_k=0, scores_th=0.2, n_limit=20000):
|
78 |
+
"""
|
79 |
+
Args:
|
80 |
+
radius: soft detection radius, kernel size is (2 * radius + 1)
|
81 |
+
top_k: top_k > 0: return top k keypoints
|
82 |
+
scores_th: top_k <= 0 threshold mode: scores_th > 0: return keypoints with scores>scores_th
|
83 |
+
else: return keypoints with scores > scores.mean()
|
84 |
+
n_limit: max number of keypoint in threshold mode
|
85 |
+
"""
|
86 |
+
super().__init__()
|
87 |
+
self.radius = radius
|
88 |
+
self.top_k = top_k
|
89 |
+
self.scores_th = scores_th
|
90 |
+
self.n_limit = n_limit
|
91 |
+
self.kernel_size = 2 * self.radius + 1
|
92 |
+
self.temperature = 0.1 # tuned temperature
|
93 |
+
self.unfold = nn.Unfold(kernel_size=self.kernel_size, padding=self.radius)
|
94 |
+
self.sample_descriptor = InterpolateSparse2d('bicubic')
|
95 |
+
# local xy grid
|
96 |
+
x = torch.linspace(-self.radius, self.radius, self.kernel_size)
|
97 |
+
# (kernel_size*kernel_size) x 2 : (w,h)
|
98 |
+
self.hw_grid = torch.stack(torch.meshgrid([x, x])).view(2, -1).t()[:, [1, 0]]
|
99 |
+
|
100 |
+
def detect_keypoints(self, scores_map, normalized_coordinates=True):
|
101 |
+
b, c, h, w = scores_map.shape
|
102 |
+
scores_nograd = scores_map.detach()
|
103 |
+
|
104 |
+
# nms_scores = simple_nms(scores_nograd, self.radius)
|
105 |
+
nms_scores = simple_nms(scores_nograd, 2)
|
106 |
+
|
107 |
+
# remove border
|
108 |
+
nms_scores[:, :, :self.radius + 1, :] = 0
|
109 |
+
nms_scores[:, :, :, :self.radius + 1] = 0
|
110 |
+
nms_scores[:, :, h - self.radius:, :] = 0
|
111 |
+
nms_scores[:, :, :, w - self.radius:] = 0
|
112 |
+
|
113 |
+
# detect keypoints without grad
|
114 |
+
if self.top_k > 0:
|
115 |
+
topk = torch.topk(nms_scores.view(b, -1), self.top_k)
|
116 |
+
indices_keypoints = topk.indices # B x top_k
|
117 |
+
else:
|
118 |
+
if self.scores_th > 0:
|
119 |
+
masks = nms_scores > self.scores_th
|
120 |
+
if masks.sum() == 0:
|
121 |
+
th = scores_nograd.reshape(b, -1).mean(dim=1) # th = self.scores_th
|
122 |
+
masks = nms_scores > th.reshape(b, 1, 1, 1)
|
123 |
+
else:
|
124 |
+
th = scores_nograd.reshape(b, -1).mean(dim=1) # th = self.scores_th
|
125 |
+
masks = nms_scores > th.reshape(b, 1, 1, 1)
|
126 |
+
masks = masks.reshape(b, -1)
|
127 |
+
|
128 |
+
indices_keypoints = [] # list, B x (any size)
|
129 |
+
scores_view = scores_nograd.reshape(b, -1)
|
130 |
+
for mask, scores in zip(masks, scores_view):
|
131 |
+
indices = mask.nonzero(as_tuple=False)[:, 0]
|
132 |
+
if len(indices) > self.n_limit:
|
133 |
+
kpts_sc = scores[indices]
|
134 |
+
sort_idx = kpts_sc.sort(descending=True)[1]
|
135 |
+
sel_idx = sort_idx[:self.n_limit]
|
136 |
+
indices = indices[sel_idx]
|
137 |
+
indices_keypoints.append(indices)
|
138 |
+
|
139 |
+
# detect soft keypoints with grad backpropagation
|
140 |
+
patches = self.unfold(scores_map) # B x (kernel**2) x (H*W)
|
141 |
+
self.hw_grid = self.hw_grid.to(patches) # to device
|
142 |
+
keypoints = []
|
143 |
+
scoredispersitys = []
|
144 |
+
kptscores = []
|
145 |
+
for b_idx in range(b):
|
146 |
+
patch = patches[b_idx].t() # (H*W) x (kernel**2)
|
147 |
+
indices_kpt = indices_keypoints[b_idx] # one dimension vector, say its size is M
|
148 |
+
patch_scores = patch[indices_kpt] # M x (kernel**2)
|
149 |
+
|
150 |
+
# max is detached to prevent undesired backprop loops in the graph
|
151 |
+
max_v = patch_scores.max(dim=1).values.detach()[:, None]
|
152 |
+
x_exp = ((patch_scores - max_v) / self.temperature).exp() # M * (kernel**2), in [0, 1]
|
153 |
+
|
154 |
+
# \frac{ \sum{(i,j) \times \exp(x/T)} }{ \sum{\exp(x/T)} }
|
155 |
+
xy_residual = x_exp @ self.hw_grid / x_exp.sum(dim=1)[:, None] # Soft-argmax, Mx2
|
156 |
+
|
157 |
+
hw_grid_dist2 = torch.norm((self.hw_grid[None, :, :] - xy_residual[:, None, :]) / self.radius,
|
158 |
+
dim=-1) ** 2
|
159 |
+
scoredispersity = (x_exp * hw_grid_dist2).sum(dim=1) / x_exp.sum(dim=1)
|
160 |
+
|
161 |
+
# compute result keypoints
|
162 |
+
keypoints_xy_nms = torch.stack([indices_kpt % w, indices_kpt // w], dim=1) # Mx2
|
163 |
+
keypoints_xy = keypoints_xy_nms + xy_residual
|
164 |
+
if normalized_coordinates:
|
165 |
+
keypoints_xy = keypoints_xy / keypoints_xy.new_tensor([w - 1, h - 1]) * 2 - 1 # (w,h) -> (-1~1,-1~1)
|
166 |
+
|
167 |
+
kptscore = torch.nn.functional.grid_sample(scores_map[b_idx].unsqueeze(0), keypoints_xy.view(1, 1, -1, 2),
|
168 |
+
mode='bilinear', align_corners=True)[0, 0, 0, :] # CxN
|
169 |
+
|
170 |
+
keypoints.append(keypoints_xy)
|
171 |
+
scoredispersitys.append(scoredispersity)
|
172 |
+
kptscores.append(kptscore)
|
173 |
+
|
174 |
+
return keypoints, scoredispersitys, kptscores
|
175 |
+
|
176 |
+
def forward(self, scores_map, normalized_coordinates=True):
|
177 |
+
"""
|
178 |
+
:param scores_map: Bx1xHxW
|
179 |
+
:return: kpts: list[Nx2,...]; kptscores: list[N,....] normalised position: -1.0 ~ 1.0
|
180 |
+
"""
|
181 |
+
B, _, H, W = scores_map.shape
|
182 |
+
|
183 |
+
keypoints, scoredispersitys, kptscores = self.detect_keypoints(scores_map,
|
184 |
+
normalized_coordinates)
|
185 |
+
|
186 |
+
# keypoints: B M 2
|
187 |
+
# scoredispersitys:
|
188 |
+
return keypoints, kptscores, scoredispersitys
|
189 |
+
|
190 |
+
import torch
|
191 |
+
import torch.nn as nn
|
192 |
+
|
193 |
+
class Detect(nn.Module):
|
194 |
+
def __init__(self, stride=4, top_k=0, scores_th=0, n_limit=20000):
|
195 |
+
super().__init__()
|
196 |
+
self.stride = stride
|
197 |
+
self.top_k = top_k
|
198 |
+
self.scores_th = scores_th
|
199 |
+
self.n_limit = n_limit
|
200 |
+
|
201 |
+
def forward(self, scores, coords, w, h):
|
202 |
+
"""
|
203 |
+
scores: B x N x 1 (keypoint confidence scores)
|
204 |
+
coords: B x N x 2 (offsets within stride x stride window)
|
205 |
+
w, h: Image dimensions
|
206 |
+
"""
|
207 |
+
b, n, _ = scores.shape
|
208 |
+
kpts_list = []
|
209 |
+
scores_list = []
|
210 |
+
|
211 |
+
for b_idx in range(b):
|
212 |
+
score = scores[b_idx].squeeze(-1) # Shape: (N,)
|
213 |
+
coord = coords[b_idx] # Shape: (N, 2)
|
214 |
+
|
215 |
+
# Apply score thresholding
|
216 |
+
if self.scores_th >= 0:
|
217 |
+
valid = score > self.scores_th
|
218 |
+
else:
|
219 |
+
valid = score > score.mean()
|
220 |
+
|
221 |
+
valid_indices = valid.nonzero(as_tuple=True)[0] # Get valid indices
|
222 |
+
if valid_indices.numel() == 0:
|
223 |
+
kpts_list.append(torch.empty((0, 2), device=scores.device))
|
224 |
+
scores_list.append(torch.empty((0,), device=scores.device))
|
225 |
+
continue
|
226 |
+
|
227 |
+
# Compute keypoint locations in original image space
|
228 |
+
i_ids = valid_indices # Indices where keypoints exist
|
229 |
+
kpts = torch.stack([i_ids % w, i_ids // w], dim=1).to(torch.float) * self.stride # Grid position
|
230 |
+
kpts += coord[i_ids] * self.stride # Apply offset
|
231 |
+
|
232 |
+
# Normalize keypoints to [-1, 1] range
|
233 |
+
kpts = (kpts / torch.tensor([w - 1, h - 1], device=kpts.device, dtype=kpts.dtype)) * 2 - 1
|
234 |
+
|
235 |
+
# Filter top-k keypoints if needed
|
236 |
+
scores_valid = score[valid_indices]
|
237 |
+
if self.top_k > 0 and len(kpts) > self.top_k:
|
238 |
+
topk = torch.topk(scores_valid, self.top_k, dim=0)
|
239 |
+
kpts = kpts[topk.indices]
|
240 |
+
scores_valid = topk.values
|
241 |
+
elif self.top_k < 0:
|
242 |
+
if len(kpts) > self.n_limit:
|
243 |
+
sorted_idx = scores_valid.argsort(descending=True)[:self.n_limit]
|
244 |
+
kpts = kpts[sorted_idx]
|
245 |
+
scores_valid = scores_valid[sorted_idx]
|
246 |
+
|
247 |
+
kpts_list.append(kpts)
|
248 |
+
scores_list.append(scores_valid)
|
249 |
+
|
250 |
+
return kpts_list, scores_list
|
imcui/third_party/rdd/RDD/utils/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .misc import *
|
imcui/third_party/rdd/RDD/utils/misc.py
ADDED
@@ -0,0 +1,531 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Modified from Deformable DETR
|
2 |
+
# https://github.com/fundamentalvision/Deformable-DETR
|
3 |
+
# ------------------------------------------------------------------------
|
4 |
+
# Deformable DETR
|
5 |
+
# Copyright (c) 2020 SenseTime. All Rights Reserved.
|
6 |
+
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
7 |
+
# ------------------------------------------------------------------------
|
8 |
+
# Modified from DETR (https://github.com/facebookresearch/detr)
|
9 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
10 |
+
# ------------------------------------------------------------------------
|
11 |
+
import os
|
12 |
+
import subprocess
|
13 |
+
import time
|
14 |
+
from collections import defaultdict, deque
|
15 |
+
import datetime
|
16 |
+
import pickle
|
17 |
+
from typing import Optional, List
|
18 |
+
import yaml
|
19 |
+
import torch
|
20 |
+
import torch.nn as nn
|
21 |
+
import torch.distributed as dist
|
22 |
+
from torch import Tensor
|
23 |
+
|
24 |
+
# needed due to empty tensor bug in pytorch and torchvision 0.5
|
25 |
+
import torchvision
|
26 |
+
if float(torchvision.__version__.split('.')[1]) < 0.5:
|
27 |
+
import math
|
28 |
+
from torchvision.ops.misc import _NewEmptyTensorOp
|
29 |
+
def _check_size_scale_factor(dim, size, scale_factor):
|
30 |
+
# type: (int, Optional[List[int]], Optional[float]) -> None
|
31 |
+
if size is None and scale_factor is None:
|
32 |
+
raise ValueError("either size or scale_factor should be defined")
|
33 |
+
if size is not None and scale_factor is not None:
|
34 |
+
raise ValueError("only one of size or scale_factor should be defined")
|
35 |
+
if not (scale_factor is not None and len(scale_factor) != dim):
|
36 |
+
raise ValueError(
|
37 |
+
"scale_factor shape must match input shape. "
|
38 |
+
"Input is {}D, scale_factor size is {}".format(dim, len(scale_factor))
|
39 |
+
)
|
40 |
+
def _output_size(dim, input, size, scale_factor):
|
41 |
+
# type: (int, Tensor, Optional[List[int]], Optional[float]) -> List[int]
|
42 |
+
assert dim == 2
|
43 |
+
_check_size_scale_factor(dim, size, scale_factor)
|
44 |
+
if size is not None:
|
45 |
+
return size
|
46 |
+
# if dim is not 2 or scale_factor is iterable use _ntuple instead of concat
|
47 |
+
assert scale_factor is not None and isinstance(scale_factor, (int, float))
|
48 |
+
scale_factors = [scale_factor, scale_factor]
|
49 |
+
# math.floor might return float in py2.7
|
50 |
+
return [
|
51 |
+
int(math.floor(input.size(i + 2) * scale_factors[i])) for i in range(dim)
|
52 |
+
]
|
53 |
+
elif float(torchvision.__version__.split('.')[1]) < 7:
|
54 |
+
from torchvision.ops import _new_empty_tensor
|
55 |
+
from torchvision.ops.misc import _output_size
|
56 |
+
|
57 |
+
|
58 |
+
class SmoothedValue(object):
|
59 |
+
"""Track a series of values and provide access to smoothed values over a
|
60 |
+
window or the global series average.
|
61 |
+
"""
|
62 |
+
|
63 |
+
def __init__(self, window_size=20, fmt=None):
|
64 |
+
if fmt is None:
|
65 |
+
fmt = "{median:.4f} ({global_avg:.4f})"
|
66 |
+
self.deque = deque(maxlen=window_size)
|
67 |
+
self.total = 0.0
|
68 |
+
self.count = 0
|
69 |
+
self.fmt = fmt
|
70 |
+
|
71 |
+
def update(self, value, n=1):
|
72 |
+
self.deque.append(value)
|
73 |
+
self.count += n
|
74 |
+
self.total += value * n
|
75 |
+
|
76 |
+
def synchronize_between_processes(self):
|
77 |
+
"""
|
78 |
+
Warning: does not synchronize the deque!
|
79 |
+
"""
|
80 |
+
if not is_dist_avail_and_initialized():
|
81 |
+
return
|
82 |
+
t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
|
83 |
+
dist.barrier()
|
84 |
+
dist.all_reduce(t)
|
85 |
+
t = t.tolist()
|
86 |
+
self.count = int(t[0])
|
87 |
+
self.total = t[1]
|
88 |
+
|
89 |
+
@property
|
90 |
+
def median(self):
|
91 |
+
d = torch.tensor(list(self.deque))
|
92 |
+
return d.median().item()
|
93 |
+
|
94 |
+
@property
|
95 |
+
def avg(self):
|
96 |
+
d = torch.tensor(list(self.deque), dtype=torch.float32)
|
97 |
+
return d.mean().item()
|
98 |
+
|
99 |
+
@property
|
100 |
+
def global_avg(self):
|
101 |
+
return self.total / self.count
|
102 |
+
|
103 |
+
@property
|
104 |
+
def max(self):
|
105 |
+
return max(self.deque)
|
106 |
+
|
107 |
+
@property
|
108 |
+
def value(self):
|
109 |
+
return self.deque[-1]
|
110 |
+
|
111 |
+
def __str__(self):
|
112 |
+
return self.fmt.format(
|
113 |
+
median=self.median,
|
114 |
+
avg=self.avg,
|
115 |
+
global_avg=self.global_avg,
|
116 |
+
max=self.max,
|
117 |
+
value=self.value)
|
118 |
+
|
119 |
+
|
120 |
+
def all_gather(data):
|
121 |
+
"""
|
122 |
+
Run all_gather on arbitrary picklable data (not necessarily tensors)
|
123 |
+
Args:
|
124 |
+
data: any picklable object
|
125 |
+
Returns:
|
126 |
+
list[data]: list of data gathered from each rank
|
127 |
+
"""
|
128 |
+
world_size = get_world_size()
|
129 |
+
if world_size == 1:
|
130 |
+
return [data]
|
131 |
+
|
132 |
+
# serialized to a Tensor
|
133 |
+
buffer = pickle.dumps(data)
|
134 |
+
storage = torch.ByteStorage.from_buffer(buffer)
|
135 |
+
tensor = torch.ByteTensor(storage).to("cuda")
|
136 |
+
|
137 |
+
# obtain Tensor size of each rank
|
138 |
+
local_size = torch.tensor([tensor.numel()], device="cuda")
|
139 |
+
size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
|
140 |
+
dist.all_gather(size_list, local_size)
|
141 |
+
size_list = [int(size.item()) for size in size_list]
|
142 |
+
max_size = max(size_list)
|
143 |
+
|
144 |
+
# receiving Tensor from all ranks
|
145 |
+
# we pad the tensor because torch all_gather does not support
|
146 |
+
# gathering tensors of different shapes
|
147 |
+
tensor_list = []
|
148 |
+
for _ in size_list:
|
149 |
+
tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
|
150 |
+
if local_size != max_size:
|
151 |
+
padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
|
152 |
+
tensor = torch.cat((tensor, padding), dim=0)
|
153 |
+
dist.all_gather(tensor_list, tensor)
|
154 |
+
|
155 |
+
data_list = []
|
156 |
+
for size, tensor in zip(size_list, tensor_list):
|
157 |
+
buffer = tensor.cpu().numpy().tobytes()[:size]
|
158 |
+
data_list.append(pickle.loads(buffer))
|
159 |
+
|
160 |
+
return data_list
|
161 |
+
|
162 |
+
|
163 |
+
def reduce_dict(input_dict, average=True):
|
164 |
+
"""
|
165 |
+
Args:
|
166 |
+
input_dict (dict): all the values will be reduced
|
167 |
+
average (bool): whether to do average or sum
|
168 |
+
Reduce the values in the dictionary from all processes so that all processes
|
169 |
+
have the averaged results. Returns a dict with the same fields as
|
170 |
+
input_dict, after reduction.
|
171 |
+
"""
|
172 |
+
world_size = get_world_size()
|
173 |
+
if world_size < 2:
|
174 |
+
return input_dict
|
175 |
+
with torch.no_grad():
|
176 |
+
names = []
|
177 |
+
values = []
|
178 |
+
# sort the keys so that they are consistent across processes
|
179 |
+
for k in sorted(input_dict.keys()):
|
180 |
+
names.append(k)
|
181 |
+
values.append(input_dict[k])
|
182 |
+
values = torch.stack(values, dim=0)
|
183 |
+
dist.all_reduce(values)
|
184 |
+
if average:
|
185 |
+
values /= world_size
|
186 |
+
reduced_dict = {k: v for k, v in zip(names, values)}
|
187 |
+
return reduced_dict
|
188 |
+
|
189 |
+
|
190 |
+
class MetricLogger(object):
|
191 |
+
def __init__(self, delimiter="\t"):
|
192 |
+
self.meters = defaultdict(SmoothedValue)
|
193 |
+
self.delimiter = delimiter
|
194 |
+
|
195 |
+
def update(self, **kwargs):
|
196 |
+
for k, v in kwargs.items():
|
197 |
+
if isinstance(v, torch.Tensor):
|
198 |
+
v = v.item()
|
199 |
+
assert isinstance(v, (float, int))
|
200 |
+
self.meters[k].update(v)
|
201 |
+
|
202 |
+
def __getattr__(self, attr):
|
203 |
+
if attr in self.meters:
|
204 |
+
return self.meters[attr]
|
205 |
+
if attr in self.__dict__:
|
206 |
+
return self.__dict__[attr]
|
207 |
+
raise AttributeError("'{}' object has no attribute '{}'".format(
|
208 |
+
type(self).__name__, attr))
|
209 |
+
|
210 |
+
def __str__(self):
|
211 |
+
loss_str = []
|
212 |
+
for name, meter in self.meters.items():
|
213 |
+
loss_str.append(
|
214 |
+
"{}: {}".format(name, str(meter))
|
215 |
+
)
|
216 |
+
return self.delimiter.join(loss_str)
|
217 |
+
|
218 |
+
def synchronize_between_processes(self):
|
219 |
+
for meter in self.meters.values():
|
220 |
+
meter.synchronize_between_processes()
|
221 |
+
|
222 |
+
def add_meter(self, name, meter):
|
223 |
+
self.meters[name] = meter
|
224 |
+
|
225 |
+
def log_every(self, iterable, print_freq, header=None):
|
226 |
+
i = 0
|
227 |
+
if not header:
|
228 |
+
header = ''
|
229 |
+
start_time = time.time()
|
230 |
+
end = time.time()
|
231 |
+
iter_time = SmoothedValue(fmt='{avg:.4f}')
|
232 |
+
data_time = SmoothedValue(fmt='{avg:.4f}')
|
233 |
+
space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
|
234 |
+
if torch.cuda.is_available():
|
235 |
+
log_msg = self.delimiter.join([
|
236 |
+
header,
|
237 |
+
'[{0' + space_fmt + '}/{1}]',
|
238 |
+
'eta: {eta}',
|
239 |
+
'{meters}',
|
240 |
+
'time: {time}',
|
241 |
+
'data: {data}',
|
242 |
+
'max mem: {memory:.0f}'
|
243 |
+
])
|
244 |
+
else:
|
245 |
+
log_msg = self.delimiter.join([
|
246 |
+
header,
|
247 |
+
'[{0' + space_fmt + '}/{1}]',
|
248 |
+
'eta: {eta}',
|
249 |
+
'{meters}',
|
250 |
+
'time: {time}',
|
251 |
+
'data: {data}'
|
252 |
+
])
|
253 |
+
MB = 1024.0 * 1024.0
|
254 |
+
for obj in iterable:
|
255 |
+
data_time.update(time.time() - end)
|
256 |
+
yield obj
|
257 |
+
iter_time.update(time.time() - end)
|
258 |
+
if i % print_freq == 0 or i == len(iterable) - 1:
|
259 |
+
eta_seconds = iter_time.global_avg * (len(iterable) - i)
|
260 |
+
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
|
261 |
+
if torch.cuda.is_available():
|
262 |
+
print(log_msg.format(
|
263 |
+
i, len(iterable), eta=eta_string,
|
264 |
+
meters=str(self),
|
265 |
+
time=str(iter_time), data=str(data_time),
|
266 |
+
memory=torch.cuda.max_memory_allocated() / MB))
|
267 |
+
else:
|
268 |
+
print(log_msg.format(
|
269 |
+
i, len(iterable), eta=eta_string,
|
270 |
+
meters=str(self),
|
271 |
+
time=str(iter_time), data=str(data_time)))
|
272 |
+
i += 1
|
273 |
+
end = time.time()
|
274 |
+
total_time = time.time() - start_time
|
275 |
+
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
276 |
+
print('{} Total time: {} ({:.4f} s / it)'.format(
|
277 |
+
header, total_time_str, total_time / len(iterable)))
|
278 |
+
|
279 |
+
|
280 |
+
def get_sha():
|
281 |
+
cwd = os.path.dirname(os.path.abspath(__file__))
|
282 |
+
|
283 |
+
def _run(command):
|
284 |
+
return subprocess.check_output(command, cwd=cwd).decode('ascii').strip()
|
285 |
+
sha = 'N/A'
|
286 |
+
diff = "clean"
|
287 |
+
branch = 'N/A'
|
288 |
+
try:
|
289 |
+
sha = _run(['git', 'rev-parse', 'HEAD'])
|
290 |
+
subprocess.check_output(['git', 'diff'], cwd=cwd)
|
291 |
+
diff = _run(['git', 'diff-index', 'HEAD'])
|
292 |
+
diff = "has uncommited changes" if diff else "clean"
|
293 |
+
branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])
|
294 |
+
except Exception:
|
295 |
+
pass
|
296 |
+
message = f"sha: {sha}, status: {diff}, branch: {branch}"
|
297 |
+
return message
|
298 |
+
|
299 |
+
|
300 |
+
def collate_fn(batch):
|
301 |
+
batch = list(zip(*batch))
|
302 |
+
batch[0] = nested_tensor_from_tensor_list(batch[0])
|
303 |
+
return tuple(batch)
|
304 |
+
|
305 |
+
|
306 |
+
def _max_by_axis(the_list):
|
307 |
+
# type: (List[List[int]]) -> List[int]
|
308 |
+
maxes = the_list[0]
|
309 |
+
for sublist in the_list[1:]:
|
310 |
+
for index, item in enumerate(sublist):
|
311 |
+
maxes[index] = max(maxes[index], item)
|
312 |
+
return maxes
|
313 |
+
|
314 |
+
|
315 |
+
def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
|
316 |
+
# TODO make this more general
|
317 |
+
if tensor_list[0].ndim == 3:
|
318 |
+
# TODO make it support different-sized images
|
319 |
+
max_size = _max_by_axis([list(img.shape) for img in tensor_list])
|
320 |
+
# min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
|
321 |
+
batch_shape = [len(tensor_list)] + max_size
|
322 |
+
b, c, h, w = batch_shape
|
323 |
+
dtype = tensor_list[0].dtype
|
324 |
+
device = tensor_list[0].device
|
325 |
+
tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
|
326 |
+
mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
|
327 |
+
for img, pad_img, m in zip(tensor_list, tensor, mask):
|
328 |
+
pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
|
329 |
+
m[: img.shape[1], :img.shape[2]] = False
|
330 |
+
else:
|
331 |
+
raise ValueError('not supported')
|
332 |
+
return NestedTensor(tensor, mask)
|
333 |
+
|
334 |
+
|
335 |
+
class NestedTensor(object):
|
336 |
+
def __init__(self, tensors, mask: Optional[Tensor]):
|
337 |
+
self.tensors = tensors
|
338 |
+
self.mask = mask
|
339 |
+
|
340 |
+
def to(self, device, non_blocking=False):
|
341 |
+
# type: (Device) -> NestedTensor # noqa
|
342 |
+
cast_tensor = self.tensors.to(device, non_blocking=non_blocking)
|
343 |
+
mask = self.mask
|
344 |
+
if mask is not None:
|
345 |
+
assert mask is not None
|
346 |
+
cast_mask = mask.to(device, non_blocking=non_blocking)
|
347 |
+
else:
|
348 |
+
cast_mask = None
|
349 |
+
return NestedTensor(cast_tensor, cast_mask)
|
350 |
+
|
351 |
+
def record_stream(self, *args, **kwargs):
|
352 |
+
self.tensors.record_stream(*args, **kwargs)
|
353 |
+
if self.mask is not None:
|
354 |
+
self.mask.record_stream(*args, **kwargs)
|
355 |
+
|
356 |
+
def decompose(self):
|
357 |
+
return self.tensors, self.mask
|
358 |
+
|
359 |
+
def __repr__(self):
|
360 |
+
return str(self.tensors)
|
361 |
+
|
362 |
+
|
363 |
+
def setup_for_distributed(is_master):
|
364 |
+
"""
|
365 |
+
This function disables printing when not in master process
|
366 |
+
"""
|
367 |
+
import builtins as __builtin__
|
368 |
+
builtin_print = __builtin__.print
|
369 |
+
|
370 |
+
def print(*args, **kwargs):
|
371 |
+
force = kwargs.pop('force', False)
|
372 |
+
if is_master or force:
|
373 |
+
builtin_print(*args, **kwargs)
|
374 |
+
|
375 |
+
__builtin__.print = print
|
376 |
+
|
377 |
+
|
378 |
+
def is_dist_avail_and_initialized():
|
379 |
+
if not dist.is_available():
|
380 |
+
return False
|
381 |
+
if not dist.is_initialized():
|
382 |
+
return False
|
383 |
+
return True
|
384 |
+
|
385 |
+
|
386 |
+
def get_world_size():
|
387 |
+
if not is_dist_avail_and_initialized():
|
388 |
+
return 1
|
389 |
+
return dist.get_world_size()
|
390 |
+
|
391 |
+
|
392 |
+
def get_rank():
|
393 |
+
if not is_dist_avail_and_initialized():
|
394 |
+
return 0
|
395 |
+
return dist.get_rank()
|
396 |
+
|
397 |
+
|
398 |
+
def get_local_size():
|
399 |
+
if not is_dist_avail_and_initialized():
|
400 |
+
return 1
|
401 |
+
return int(os.environ['LOCAL_SIZE'])
|
402 |
+
|
403 |
+
|
404 |
+
def get_local_rank():
|
405 |
+
if not is_dist_avail_and_initialized():
|
406 |
+
return 0
|
407 |
+
return int(os.environ['LOCAL_RANK'])
|
408 |
+
|
409 |
+
|
410 |
+
def is_main_process():
|
411 |
+
return get_rank() == 0
|
412 |
+
|
413 |
+
|
414 |
+
def save_on_master(*args, **kwargs):
|
415 |
+
if is_main_process():
|
416 |
+
torch.save(*args, **kwargs)
|
417 |
+
|
418 |
+
|
419 |
+
def init_distributed_mode(args):
|
420 |
+
if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
|
421 |
+
args.rank = int(os.environ["RANK"])
|
422 |
+
args.world_size = int(os.environ['WORLD_SIZE'])
|
423 |
+
args.gpu = int(os.environ['LOCAL_RANK'])
|
424 |
+
args.dist_url = 'env://'
|
425 |
+
os.environ['LOCAL_SIZE'] = str(torch.cuda.device_count())
|
426 |
+
elif 'SLURM_PROCID' in os.environ:
|
427 |
+
proc_id = int(os.environ['SLURM_PROCID'])
|
428 |
+
ntasks = int(os.environ['SLURM_NTASKS'])
|
429 |
+
node_list = os.environ['SLURM_NODELIST']
|
430 |
+
num_gpus = torch.cuda.device_count()
|
431 |
+
addr = subprocess.getoutput(
|
432 |
+
'scontrol show hostname {} | head -n1'.format(node_list))
|
433 |
+
os.environ['MASTER_PORT'] = os.environ.get('MASTER_PORT', '29500')
|
434 |
+
os.environ['MASTER_ADDR'] = addr
|
435 |
+
os.environ['WORLD_SIZE'] = str(ntasks)
|
436 |
+
os.environ['RANK'] = str(proc_id)
|
437 |
+
os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
|
438 |
+
os.environ['LOCAL_SIZE'] = str(num_gpus)
|
439 |
+
args.dist_url = 'env://'
|
440 |
+
args.world_size = ntasks
|
441 |
+
args.rank = proc_id
|
442 |
+
args.gpu = proc_id % num_gpus
|
443 |
+
else:
|
444 |
+
print('Not using distributed mode')
|
445 |
+
args.distributed = False
|
446 |
+
return
|
447 |
+
|
448 |
+
args.distributed = True
|
449 |
+
|
450 |
+
torch.cuda.set_device(args.gpu)
|
451 |
+
args.dist_backend = 'nccl'
|
452 |
+
print('| distributed init (rank {}): {}'.format(
|
453 |
+
args.rank, args.dist_url), flush=True)
|
454 |
+
torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
|
455 |
+
world_size=args.world_size, rank=args.rank)
|
456 |
+
torch.distributed.barrier()
|
457 |
+
setup_for_distributed(args.rank == 0)
|
458 |
+
|
459 |
+
|
460 |
+
@torch.no_grad()
|
461 |
+
def accuracy(output, target, topk=(1,)):
|
462 |
+
"""Computes the precision@k for the specified values of k"""
|
463 |
+
if target.numel() == 0:
|
464 |
+
return [torch.zeros([], device=output.device)]
|
465 |
+
maxk = max(topk)
|
466 |
+
batch_size = target.size(0)
|
467 |
+
|
468 |
+
_, pred = output.topk(maxk, 1, True, True)
|
469 |
+
pred = pred.t()
|
470 |
+
correct = pred.eq(target.view(1, -1).expand_as(pred))
|
471 |
+
|
472 |
+
res = []
|
473 |
+
for k in topk:
|
474 |
+
correct_k = correct[:k].view(-1).float().sum(0)
|
475 |
+
res.append(correct_k.mul_(100.0 / batch_size))
|
476 |
+
return res
|
477 |
+
|
478 |
+
|
479 |
+
def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None):
|
480 |
+
# type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
|
481 |
+
"""
|
482 |
+
Equivalent to nn.functional.interpolate, but with support for empty batch sizes.
|
483 |
+
This will eventually be supported natively by PyTorch, and this
|
484 |
+
class can go away.
|
485 |
+
"""
|
486 |
+
if float(torchvision.__version__[:3]) < 0.7:
|
487 |
+
if input.numel() > 0:
|
488 |
+
return torch.nn.functional.interpolate(
|
489 |
+
input, size, scale_factor, mode, align_corners
|
490 |
+
)
|
491 |
+
|
492 |
+
output_shape = _output_size(2, input, size, scale_factor)
|
493 |
+
output_shape = list(input.shape[:-2]) + list(output_shape)
|
494 |
+
if float(torchvision.__version__[:3]) < 0.5:
|
495 |
+
return _NewEmptyTensorOp.apply(input, output_shape)
|
496 |
+
return _new_empty_tensor(input, output_shape)
|
497 |
+
else:
|
498 |
+
return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners)
|
499 |
+
|
500 |
+
|
501 |
+
def get_total_grad_norm(parameters, norm_type=2):
|
502 |
+
parameters = list(filter(lambda p: p.grad is not None, parameters))
|
503 |
+
norm_type = float(norm_type)
|
504 |
+
device = parameters[0].grad.device
|
505 |
+
total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]),
|
506 |
+
norm_type)
|
507 |
+
return total_norm
|
508 |
+
|
509 |
+
def inverse_sigmoid(x, eps=1e-5):
|
510 |
+
x = x.clamp(min=0, max=1)
|
511 |
+
x1 = x.clamp(min=eps)
|
512 |
+
x2 = (1 - x).clamp(min=eps)
|
513 |
+
return torch.log(x1/x2)
|
514 |
+
|
515 |
+
|
516 |
+
def to_pixel_coords(flow, h1, w1):
|
517 |
+
flow = (
|
518 |
+
torch.stack(
|
519 |
+
(
|
520 |
+
w1 * (flow[..., 0] + 1) / 2,
|
521 |
+
h1 * (flow[..., 1] + 1) / 2,
|
522 |
+
),
|
523 |
+
axis=-1,
|
524 |
+
)
|
525 |
+
)
|
526 |
+
return flow
|
527 |
+
|
528 |
+
def read_config(file_path):
|
529 |
+
with open(file_path, 'r') as file:
|
530 |
+
config = yaml.safe_load(file)
|
531 |
+
return config
|
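The `misc.py` above is vendored from Deformable DETR mainly for its distributed-training and logging helpers. The sketch below shows how `SmoothedValue` and `MetricLogger` are typically driven in a training loop; it is not part of the diff, and the import path and the stand-in data loader are assumptions for illustration only.

```python
# Minimal usage sketch for the vendored logging helpers (assumed import path).
import torch
from imcui.third_party.rdd.RDD.utils.misc import MetricLogger, SmoothedValue

logger = MetricLogger(delimiter="  ")
logger.add_meter("lr", SmoothedValue(window_size=1, fmt="{value:.6f}"))

fake_loader = [torch.rand(4) for _ in range(100)]  # stands in for a DataLoader
for batch in logger.log_every(fake_loader, print_freq=20, header="Epoch [0]"):
    loss = batch.mean()  # stands in for a real training step
    logger.update(loss=loss.item(), lr=1e-4)

# no-op unless torch.distributed has been initialized
logger.synchronize_between_processes()
print("averaged loss:", logger.meters["loss"].global_avg)
```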
imcui/third_party/rdd/README.md
ADDED
@@ -0,0 +1,197 @@
1 |
+
## RDD: Robust Feature Detector and Descriptor using Deformable Transformer (CVPR 2025)
|
2 |
+
[Gonglin Chen](https://xtcpete.com/) · [Tianwen Fu](https://twfu.me/) · [Haiwei Chen](https://scholar.google.com/citations?user=LVWRssoAAAAJ&hl=en) · [Wenbin Teng](https://wbteng9526.github.io/) · [Hanyuan Xiao](https://corneliushsiao.github.io/index.html) · [Yajie Zhao](https://ict.usc.edu/about-us/leadership/research-leadership/yajie-zhao/)
|
3 |
+
|
4 |
+
[Project Page](https://xtcpete.github.io/rdd/)
|
5 |
+
|
6 |
+
## Table of Contents
|
7 |
+
- [Updates](#updates)
|
8 |
+
- [Installation](#installation)
|
9 |
+
- [Usage](#usage)
|
10 |
+
- [Inference](#inference)
|
11 |
+
- [Evaluation](#evaluation)
|
12 |
+
- [Training](#training)
|
13 |
+
- [Citation](#citation)
|
14 |
+
- [License](#license)
|
15 |
+
- [Acknowledgements](#acknowledgements)
|
16 |
+
|
17 |
+
## Updates
|
18 |
+
|
19 |
+
- SfM reconstruction through [COLMAP](https://github.com/colmap/colmap.git) added. We provide a ready-to-use [notebook](./demo_sfm.ipynb) for a simple example. Code adopted from [hloc](https://github.com/cvg/Hierarchical-Localization.git).
|
20 |
+
|
21 |
+
- Training code and new weights released.
|
22 |
+
|
23 |
+
- We have updated the training code compared to what was described in the paper. In the original setup, RDD was trained on the MegaDepth and Air-to-Ground datasets by resizing all images to the training resolution. In this release, we retrained RDD on MegaDepth only, using a combination of resizing and cropping, a strategy used by [ALIKE](https://github.com/Shiaoming/ALIKE). This change significantly improves robustness.
|
24 |
+
|
25 |
+
<table>
|
26 |
+
<tr>
|
27 |
+
<th></th>
|
28 |
+
<th colspan="3">MegaDepth-1500</th>
|
29 |
+
<th colspan="3">MegaDepth-View</th>
|
30 |
+
<th colspan="3">Air-to-Ground</th>
|
31 |
+
</tr>
|
32 |
+
<tr>
|
33 |
+
<td></td>
|
34 |
+
<td>AUC 5°</td><td>AUC 10°</td><td>AUC 20°</td>
|
35 |
+
<td>AUC 5°</td><td>AUC 10°</td><td>AUC 20°</td>
|
36 |
+
<td>AUC 5°</td><td>AUC 10°</td><td>AUC 20°</td>
|
37 |
+
</tr>
|
38 |
+
<tr>
|
39 |
+
<td>RDD-v2</td>
|
40 |
+
<td>52.4</td><td>68.5</td><td>80.1</td>
|
41 |
+
<td>52.0</td><td>67.1</td><td>78.2</td>
|
42 |
+
<td>45.8</td><td>58.6</td><td>71.0</td>
|
43 |
+
</tr>
|
44 |
+
<tr>
|
45 |
+
<td>RDD-v1</td>
|
46 |
+
<td>48.2</td><td>65.2</td><td>78.3</td>
|
47 |
+
<td>38.3</td><td>53.1</td><td>65.6</td>
|
48 |
+
<td>41.4</td><td>56.0</td><td>67.8</td>
|
49 |
+
</tr>
|
50 |
+
<tr>
|
51 |
+
<td>RDD-v2+LG</td>
|
52 |
+
<td>53.3</td><td>69.8</td><td>82.0</td>
|
53 |
+
<td>59.0</td><td>74.2</td><td>84.0</td>
|
54 |
+
<td>54.8</td><td>69.0</td><td>79.1</td>
|
55 |
+
</tr>
|
56 |
+
<tr>
|
57 |
+
<td>RDD-v1+LG</td>
|
58 |
+
<td>52.3</td><td>68.9</td><td>81.8</td>
|
59 |
+
<td>54.2</td><td>69.3</td><td>80.3</td>
|
60 |
+
<td>55.1</td><td>68.9</td><td>78.9</td>
|
61 |
+
</tr>
|
62 |
+
</table>
|
63 |
+
|
64 |
+
## Installation
|
65 |
+
|
66 |
+
```bash
|
67 |
+
git clone --recursive https://github.com/xtcpete/rdd
|
68 |
+
cd RDD
|
69 |
+
|
70 |
+
# Create conda env
|
71 |
+
conda create -n rdd python=3.10 pip
|
72 |
+
conda activate rdd
|
73 |
+
|
74 |
+
# Install CUDA
|
75 |
+
conda install -c nvidia/label/cuda-11.8.0 cuda-toolkit
|
76 |
+
# Install torch
|
77 |
+
pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu118
|
78 |
+
# Install all dependencies
|
79 |
+
pip install -r requirements.txt
|
80 |
+
# Compile custom operations
|
81 |
+
cd ./RDD/models/ops
|
82 |
+
pip install -e .
|
83 |
+
```
|
84 |
+
|
85 |
+
We provide the [download link](https://drive.google.com/drive/folders/1QgVaqm4iTUCqbWb7_Fi6mX09EHTId0oA?usp=sharing) to:
|
86 |
+
- the MegaDepth-1500 test set
|
87 |
+
- the MegaDepth-View test set
|
88 |
+
- the Air-to-Ground test set
|
89 |
+
- 2 pretrained models: RDD, and a LightGlue model trained for matching RDD features
|
90 |
+
|
91 |
+
Create a `data` folder and unzip the downloaded test data into it.
|
92 |
+
|
93 |
+
Create a `weights` folder, add the downloaded weights to it, and you are ready to go.
|
94 |
+
|
95 |
+
## Usage
|
96 |
+
For your convenience, we provide a ready-to-use [notebook](./demo_matching.ipynb) for some examples.
|
97 |
+
|
98 |
+
### Inference
|
99 |
+
|
100 |
+
```python
|
101 |
+
from RDD.RDD import build
|
102 |
+
|
103 |
+
RDD_model = build()
|
104 |
+
|
105 |
+
output = RDD_model.extract(torch.randn(1, 3, 480, 640))
|
106 |
+
```
|
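For two-image matching, the benchmark scripts in this PR wrap the model in `RDD_helper`. The snippet below is a hedged sketch of that usage: the image paths are placeholders, while the `thr`/`resize` values and the `build(weights=...)` call mirror the benchmark scripts.

```python
import cv2
import torch
from RDD.RDD import build
from RDD.RDD_helper import RDD_helper

model = build(weights="./weights/RDD-v2.pth")
model.eval()
helper = RDD_helper(model)

im_A = cv2.imread("assets/example_A.jpg")  # placeholder paths
im_B = cv2.imread("assets/example_B.jpg")

with torch.no_grad():
    # sparse matching; match_dense / match_lg give the dense and
    # LightGlue variants used in the benchmarks
    kpts0, kpts1, conf = helper.match(im_A, im_B, thr=0.01, resize=1600)

print(len(kpts0), "matches, mean confidence", float(conf.mean()))
```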
107 |
+
|
108 |
+
### Evaluation
|
109 |
+
|
110 |
+
Please note that due to different GPU architectures and the stochastic nature of RANSAC, you may observe slightly different results; however, they should be very close to those reported in the paper. To reproduce the numbers in the paper, use the v1 weights instead.
|
111 |
+
|
112 |
+
Results can be visualized by passing the `--plot` argument.
|
113 |
+
|
114 |
+
**MegaDepth-1500**
|
115 |
+
|
116 |
+
```bash
|
117 |
+
# Sparse matching
|
118 |
+
python ./benchmarks/mega_1500.py
|
119 |
+
|
120 |
+
# Dense matching
|
121 |
+
python ./benchmarks/mega_1500.py --method dense
|
122 |
+
|
123 |
+
# LightGlue
|
124 |
+
python ./benchmarks/mega_1500.py --method lightglue
|
125 |
+
```
|
126 |
+
|
127 |
+
**MegaDepth-View**
|
128 |
+
|
129 |
+
```bash
|
130 |
+
# Sparse matching
|
131 |
+
python ./benchmarks/mega_view.py
|
132 |
+
|
133 |
+
# Dense matching
|
134 |
+
python ./benchmarks/mega_view.py --method dense
|
135 |
+
|
136 |
+
# LightGlue
|
137 |
+
python ./benchmarks/mega_view.py --method lightglue
|
138 |
+
```
|
139 |
+
|
140 |
+
**Air-to-Ground**
|
141 |
+
|
142 |
+
```bash
|
143 |
+
# Sparse matching
|
144 |
+
python ./benchmarks/air_ground.py
|
145 |
+
|
146 |
+
# Dense matching
|
147 |
+
python ./benchmarks/air_ground.py --method dense
|
148 |
+
|
149 |
+
# LightGlue
|
150 |
+
python ./benchmarks/air_ground.py --method lightglue
|
151 |
+
```
|
152 |
+
|
153 |
+
### Training
|
154 |
+
|
155 |
+
1. Download the MegaDepth dataset using [download.sh](./data/megadepth/download.sh) and the megadepth_indices from [LoFTR](https://github.com/zju3dv/LoFTR/blob/master/docs/TRAINING.md#download-datasets). The MegaDepth root folder should then look like the following:
|
156 |
+
```bash
|
157 |
+
./data/megadepth/megadepth_indices # indices
|
158 |
+
./data/megadepth/depth_undistorted # depth maps
|
159 |
+
./data/megadepth/Undistorted_SfM # images and poses
|
160 |
+
./data/megadepth/scene_info # indices for training LightGlue
|
161 |
+
```
|
162 |
+
2. Then you can train RDD in two steps: the descriptor first
|
163 |
+
```bash
|
164 |
+
# distributed training with 8 gpus
|
165 |
+
python -m training.train --ckpt_save_path ./ckpt_descriptor --distributed --batch_size 32
|
166 |
+
|
167 |
+
# single gpu
|
168 |
+
python -m training.train --ckpt_save_path ./ckpt_descriptor
|
169 |
+
```
|
170 |
+
and then the detector
|
171 |
+
```bash
|
172 |
+
python -m training.train --ckpt_save_path ./ckpt_detector --weights ./ckpt_descriptor/RDD_best.pth --train_detector --training_res 480
|
173 |
+
```
|
174 |
+
|
175 |
+
I am working on re-collecting the Air-to-Ground dataset because of licensing issues.
|
176 |
+
|
177 |
+
## Citation
|
178 |
+
```
|
179 |
+
@inproceedings{gonglin2025rdd,
|
180 |
+
title = {RDD: Robust Feature Detector and Descriptor using Deformable Transformer},
|
181 |
+
author = {Chen, Gonglin and Fu, Tianwen and Chen, Haiwei and Teng, Wenbin and Xiao, Hanyuan and Zhao, Yajie},
|
182 |
+
booktitle = {IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
|
183 |
+
year = {2025}
|
184 |
+
}
|
185 |
+
```
|
186 |
+
|
187 |
+
|
188 |
+
## License
|
189 |
+
[](LICENSE)
|
190 |
+
|
191 |
+
## Acknowledgements
|
192 |
+
|
193 |
+
We thank these great repositories: [ALIKE](https://github.com/Shiaoming/ALIKE), [LoFTR](https://github.com/zju3dv/LoFTR), [DeDoDe](https://github.com/Parskatt/DeDoDe), [XFeat](https://github.com/verlab/accelerated_features), [LightGlue](https://github.com/cvg/LightGlue), [Kornia](https://github.com/kornia/kornia), and [Deformable DETR](https://github.com/fundamentalvision/Deformable-DETR), and many other inspiring works in the community.
|
194 |
+
|
195 |
+
LightGlue is trained with [Glue Factory](https://github.com/cvg/glue-factory).
|
196 |
+
|
197 |
+
Supported by the Intelligence Advanced Research Projects Activity (IARPA) via Department of Interior/Interior Business Center (DOI/IBC) contract number 140D0423C0075. The U.S. Government is authorized to reproduce and distribute reprints for governmental purposes notwithstanding any copyright annotation thereon. Disclaimer: The views and conclusions contained herein are those of the authors and should not be interpreted as necessarily representing the official policies or endorsements, either expressed or implied, of IARPA, DOI/IBC, or the U.S. Government. We would like to thank Yayue Chen for her help with visualization.
|
imcui/third_party/rdd/benchmarks/air_ground.py
ADDED
@@ -0,0 +1,247 @@
1 |
+
import sys
|
2 |
+
sys.path.append(".")
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
from PIL import Image
|
6 |
+
import tqdm
|
7 |
+
import cv2
|
8 |
+
import argparse
|
9 |
+
import matplotlib.pyplot as plt
|
10 |
+
import matplotlib
|
11 |
+
from RDD.RDD_helper import RDD_helper
|
12 |
+
from RDD.RDD import build
|
13 |
+
import os
|
14 |
+
from benchmarks.utils import pose_auc, angle_error_vec, angle_error_mat, symmetric_epipolar_distance, compute_symmetrical_epipolar_errors, compute_pose_error, compute_relative_pose, estimate_pose, dynamic_alpha
|
15 |
+
|
16 |
+
def make_matching_figure(
|
17 |
+
img0, img1, mkpts0, mkpts1, color,
|
18 |
+
kpts0=None, kpts1=None, text=[], dpi=75, path=None):
|
19 |
+
# draw image pair
|
20 |
+
assert mkpts0.shape[0] == mkpts1.shape[0], f'mkpts0: {mkpts0.shape[0]} v.s. mkpts1: {mkpts1.shape[0]}'
|
21 |
+
fig, axes = plt.subplots(1, 2, figsize=(10, 6), dpi=dpi)
|
22 |
+
axes[0].imshow(img0, cmap='gray')
|
23 |
+
axes[1].imshow(img1, cmap='gray')
|
24 |
+
for i in range(2): # clear all frames
|
25 |
+
axes[i].get_yaxis().set_ticks([])
|
26 |
+
axes[i].get_xaxis().set_ticks([])
|
27 |
+
for spine in axes[i].spines.values():
|
28 |
+
spine.set_visible(False)
|
29 |
+
plt.tight_layout(pad=1)
|
30 |
+
|
31 |
+
if kpts0 is not None:
|
32 |
+
assert kpts1 is not None
|
33 |
+
axes[0].scatter(kpts0[:, 0], kpts0[:, 1], c='w', s=2)
|
34 |
+
axes[1].scatter(kpts1[:, 0], kpts1[:, 1], c='w', s=2)
|
35 |
+
|
36 |
+
# draw matches
|
37 |
+
if mkpts0.shape[0] != 0 and mkpts1.shape[0] != 0:
|
38 |
+
fig.canvas.draw()
|
39 |
+
transFigure = fig.transFigure.inverted()
|
40 |
+
fkpts0 = transFigure.transform(axes[0].transData.transform(mkpts0))
|
41 |
+
fkpts1 = transFigure.transform(axes[1].transData.transform(mkpts1))
|
42 |
+
fig.lines = [matplotlib.lines.Line2D((fkpts0[i, 0], fkpts1[i, 0]),
|
43 |
+
(fkpts0[i, 1], fkpts1[i, 1]),
|
44 |
+
transform=fig.transFigure, c=color[i], linewidth=1)
|
45 |
+
for i in range(len(mkpts0))]
|
46 |
+
|
47 |
+
axes[0].scatter(mkpts0[:, 0], mkpts0[:, 1], c=color, s=4)
|
48 |
+
axes[1].scatter(mkpts1[:, 0], mkpts1[:, 1], c=color, s=4)
|
49 |
+
|
50 |
+
# put txts
|
51 |
+
txt_color = 'k' if img0[:100, :200].mean() > 200 else 'w'
|
52 |
+
fig.text(
|
53 |
+
0.01, 0.99, '\n'.join(text), transform=fig.axes[0].transAxes,
|
54 |
+
fontsize=15, va='top', ha='left', color=txt_color)
|
55 |
+
|
56 |
+
# save or return figure
|
57 |
+
if path:
|
58 |
+
plt.savefig(str(path), bbox_inches='tight', pad_inches=0)
|
59 |
+
plt.close()
|
60 |
+
else:
|
61 |
+
return fig
|
62 |
+
|
63 |
+
def error_colormap(err, thr, alpha=1.0):
|
64 |
+
assert alpha <= 1.0 and alpha > 0, f"Invalid alpha value: {alpha}"
|
65 |
+
x = 1 - np.clip(err / (thr * 2), 0, 1)
|
66 |
+
return np.clip(
|
67 |
+
np.stack([2-x*2, x*2, np.zeros_like(x), np.ones_like(x)*alpha], -1), 0, 1)
|
68 |
+
|
69 |
+
def _make_evaluation_figure(img0, img1, kpts0, kpts1, epi_errs, e_t, e_R, alpha='dynamic', path=None):
|
70 |
+
conf_thr = 1e-4
|
71 |
+
|
72 |
+
img0 = np.array(img0)
|
73 |
+
img1 = np.array(img1)
|
74 |
+
|
75 |
+
kpts0 = kpts0
|
76 |
+
kpts1 = kpts1
|
77 |
+
|
78 |
+
epi_errs = epi_errs.cpu().numpy()
|
79 |
+
correct_mask = epi_errs < conf_thr
|
80 |
+
precision = np.mean(correct_mask) if len(correct_mask) > 0 else 0
|
81 |
+
n_correct = np.sum(correct_mask)
|
82 |
+
|
83 |
+
# recall might be larger than 1, since the calculation of conf_matrix_gt
|
84 |
+
# uses groundtruth depths and camera poses, but epipolar distance is used here.
|
85 |
+
|
86 |
+
# matching info
|
87 |
+
if alpha == 'dynamic':
|
88 |
+
alpha = dynamic_alpha(len(correct_mask))
|
89 |
+
color = error_colormap(epi_errs, conf_thr, alpha=alpha)
|
90 |
+
|
91 |
+
text = [
|
92 |
+
f'#Matches {len(kpts0)}',
|
93 |
+
f'Precision({conf_thr:.2e}) ({100 * precision:.1f}%): {n_correct}/{len(kpts0)}',
|
94 |
+
f'e_t: {e_t:.2f} | e_R: {e_R:.2f}',
|
95 |
+
]
|
96 |
+
|
97 |
+
# make the figure
|
98 |
+
figure = make_matching_figure(img0, img1, kpts0, kpts1,
|
99 |
+
color, text=text, path=path)
|
100 |
+
return figure
|
101 |
+
|
102 |
+
class AirGroundPoseMNNBenchmark:
|
103 |
+
def __init__(self, data_root="./data/air_ground", scene_names = None) -> None:
|
104 |
+
if scene_names is None:
|
105 |
+
self.scene_names = [
|
106 |
+
"indices.npz",
|
107 |
+
]
|
108 |
+
# self.scene_names = ["0022_0.5_0.7.npz",]
|
109 |
+
else:
|
110 |
+
self.scene_names = scene_names
|
111 |
+
self.scenes = [
|
112 |
+
np.load(f"{data_root}/{scene}", allow_pickle=True)
|
113 |
+
for scene in self.scene_names
|
114 |
+
]
|
115 |
+
self.data_root = data_root
|
116 |
+
|
117 |
+
def benchmark(self, model_helper, model_name = None, scale_intrinsics = False, calibrated = True, plot_every_iter=10, plot=False, method='sparse'):
|
118 |
+
with torch.no_grad():
|
119 |
+
data_root = self.data_root
|
120 |
+
tot_e_t, tot_e_R, tot_e_pose = [], [], []
|
121 |
+
thresholds = [5, 10, 20]
|
122 |
+
for scene_ind in range(len(self.scenes)):
|
123 |
+
import os
|
124 |
+
scene_name = os.path.splitext(self.scene_names[scene_ind])[0]
|
125 |
+
scene = self.scenes[scene_ind]
|
126 |
+
indices = scene['pair_info']
|
127 |
+
idx = 0
|
128 |
+
for pair in tqdm.tqdm(indices):
|
129 |
+
|
130 |
+
pairs = pair['pair_names']
|
131 |
+
K0 = pair['intrinsic'][0].copy().astype(np.float32)
|
132 |
+
T0 = pair['pose'][0].copy().astype(np.float32)
|
133 |
+
R0, t0 = T0[:3, :3], T0[:3, 3]
|
134 |
+
K1 = pair['intrinsic'][1].copy().astype(np.float32)
|
135 |
+
T1 = pair['pose'][1].copy().astype(np.float32)
|
136 |
+
R1, t1 = T1[:3, :3], T1[:3, 3]
|
137 |
+
R, t = compute_relative_pose(R0, t0, R1, t1)
|
138 |
+
T0_to_1 = np.concatenate((R,t[:,None]), axis=-1)
|
139 |
+
im_A_path = f"{data_root}/images/{pairs[0]}"
|
140 |
+
im_B_path = f"{data_root}/images/{pairs[1]}"
|
141 |
+
|
142 |
+
im_A = cv2.imread(im_A_path)
|
143 |
+
im_B = cv2.imread(im_B_path)
|
144 |
+
|
145 |
+
if method == 'dense':
|
146 |
+
kpts0, kpts1, conf = model_helper.match_dense(im_A, im_B, thr=0.01, resize=1600)
|
147 |
+
elif method == 'lightglue':
|
148 |
+
kpts0, kpts1, conf = model_helper.match_lg(im_A, im_B, thr=0.01, resize=1600)
|
149 |
+
elif method == 'sparse':
|
150 |
+
kpts0, kpts1, conf = model_helper.match(im_A, im_B, thr=0.01, resize=1600)
|
151 |
+
else:
|
152 |
+
raise ValueError(f"Invalid method {method}")
|
153 |
+
|
154 |
+
im_A = Image.open(im_A_path)
|
155 |
+
w0, h0 = im_A.size
|
156 |
+
im_B = Image.open(im_B_path)
|
157 |
+
w1, h1 = im_B.size
|
158 |
+
if scale_intrinsics:
|
159 |
+
scale0 = 840 / max(w0, h0)
|
160 |
+
scale1 = 840 / max(w1, h1)
|
161 |
+
w0, h0 = scale0 * w0, scale0 * h0
|
162 |
+
w1, h1 = scale1 * w1, scale1 * h1
|
163 |
+
K0, K1 = K0.copy(), K1.copy()
|
164 |
+
K0[:2] = K0[:2] * scale0
|
165 |
+
K1[:2] = K1[:2] * scale1
|
166 |
+
|
167 |
+
threshold = 0.5
|
168 |
+
if calibrated:
|
169 |
+
norm_threshold = threshold / (np.mean(np.abs(K0[:2, :2])) + np.mean(np.abs(K1[:2, :2])))
|
170 |
+
ret = estimate_pose(
|
171 |
+
kpts0,
|
172 |
+
kpts1,
|
173 |
+
K0,
|
174 |
+
K1,
|
175 |
+
norm_threshold,
|
176 |
+
conf=0.99999,
|
177 |
+
)
|
178 |
+
if ret is not None:
|
179 |
+
R_est, t_est, mask = ret
|
180 |
+
T0_to_1_est = np.concatenate((R_est, t_est), axis=-1) #
|
181 |
+
T0_to_1 = np.concatenate((R, t[:,None]), axis=-1)
|
182 |
+
e_t, e_R = compute_pose_error(T0_to_1_est, R, t)
|
183 |
+
|
184 |
+
epi_errs = compute_symmetrical_epipolar_errors(T0_to_1, kpts0, kpts1, K0, K1)
|
185 |
+
if scene_ind % plot_every_iter == 0 and plot:
|
186 |
+
|
187 |
+
if not os.path.exists(f'outputs/air_ground/{model_name}_{method}'):
|
188 |
+
os.mkdir(f'outputs/air_ground/{model_name}_{method}')
|
189 |
+
name = f'outputs/air_ground/{model_name}_{method}/{scene_name}_{idx}.png'
|
190 |
+
_make_evaluation_figure(im_A, im_B, kpts0, kpts1, epi_errs, e_t, e_R, path=name)
|
191 |
+
e_pose = max(e_t, e_R)
|
192 |
+
|
193 |
+
tot_e_t.append(e_t)
|
194 |
+
tot_e_R.append(e_R)
|
195 |
+
tot_e_pose.append(e_pose)
|
196 |
+
idx += 1
|
197 |
+
|
198 |
+
tot_e_pose = np.array(tot_e_pose)
|
199 |
+
auc = pose_auc(tot_e_pose, thresholds)
|
200 |
+
acc_5 = (tot_e_pose < 5).mean()
|
201 |
+
acc_10 = (tot_e_pose < 10).mean()
|
202 |
+
acc_15 = (tot_e_pose < 15).mean()
|
203 |
+
acc_20 = (tot_e_pose < 20).mean()
|
204 |
+
map_5 = acc_5
|
205 |
+
map_10 = np.mean([acc_5, acc_10])
|
206 |
+
map_20 = np.mean([acc_5, acc_10, acc_15, acc_20])
|
207 |
+
print(f"{model_name} auc: {auc}")
|
208 |
+
return {
|
209 |
+
"auc_5": auc[0],
|
210 |
+
"auc_10": auc[1],
|
211 |
+
"auc_20": auc[2],
|
212 |
+
"map_5": map_5,
|
213 |
+
"map_10": map_10,
|
214 |
+
"map_20": map_20,
|
215 |
+
}
|
216 |
+
|
217 |
+
|
218 |
+
|
219 |
+
def parse_arguments():
|
220 |
+
parser = argparse.ArgumentParser(description="Testing script.")
|
221 |
+
|
222 |
+
parser.add_argument("--data_root", type=str, default="./data/air_ground", help="Path to the Air-to-Ground test dataset.")
|
223 |
+
|
224 |
+
parser.add_argument("--weights", type=str, default="./weights/RDD-v2.pth", help="Path to the model checkpoint.")
|
225 |
+
|
226 |
+
parser.add_argument("--plot", action="store_true", help="Whether to plot the results.")
|
227 |
+
|
228 |
+
parser.add_argument("--method", type=str, default="sparse", help="Method for matching.")
|
229 |
+
|
230 |
+
return parser.parse_args()
|
231 |
+
|
232 |
+
if __name__ == "__main__":
|
233 |
+
args = parse_arguments()
|
234 |
+
|
235 |
+
if not os.path.exists('outputs'):
|
236 |
+
os.mkdir('outputs')
|
237 |
+
if not os.path.exists(f'outputs/air_ground'):
|
238 |
+
os.mkdir(f'outputs/air_ground')
|
239 |
+
model = build(weights=args.weights)
|
240 |
+
benchmark = AirGroundPoseMNNBenchmark(data_root=args.data_root)
|
241 |
+
model.eval()
|
242 |
+
model_helper = RDD_helper(model)
|
243 |
+
with torch.no_grad():
|
244 |
+
method = args.method
|
245 |
+
out = benchmark.benchmark(model_helper, model_name='RDD', plot_every_iter=1, plot=args.plot, method=method)
|
246 |
+
with open(f'outputs/air_ground/RDD_{method}.txt', 'w') as f:
|
247 |
+
f.write(str(out))
|
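The pose-estimation step in `air_ground.py` (and in the other benchmarks) divides the 0.5-pixel RANSAC threshold by the mean absolute value of the upper-left 2x2 blocks of both intrinsic matrices, so the threshold is expressed in normalized image coordinates before essential-matrix fitting. A small worked example with illustrative focal lengths (not taken from the dataset):

```python
import numpy as np

K0 = np.array([[1200.0, 0.0, 640.0],
               [0.0, 1200.0, 480.0],
               [0.0, 0.0, 1.0]])
K1 = np.array([[1000.0, 0.0, 512.0],
               [0.0, 1000.0, 384.0],
               [0.0, 0.0, 1.0]])

threshold = 0.5  # pixels
# mean(|K[:2, :2]|) averages fx, fy and the two zero entries, i.e. ~(fx + fy) / 4
norm_threshold = threshold / (np.mean(np.abs(K0[:2, :2])) + np.mean(np.abs(K1[:2, :2])))
print(norm_threshold)  # ~4.5e-4 for these focal lengths
```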
imcui/third_party/rdd/benchmarks/mega_1500.py
ADDED
@@ -0,0 +1,255 @@
1 |
+
import sys
|
2 |
+
sys.path.append(".")
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
from PIL import Image
|
6 |
+
import tqdm
|
7 |
+
import cv2
|
8 |
+
import argparse
|
9 |
+
from RDD.RDD_helper import RDD_helper
|
10 |
+
from RDD.RDD import build
|
11 |
+
import matplotlib.pyplot as plt
|
12 |
+
import matplotlib
|
13 |
+
import os
|
14 |
+
from benchmarks.utils import pose_auc, angle_error_vec, angle_error_mat, symmetric_epipolar_distance, compute_symmetrical_epipolar_errors, compute_pose_error, compute_relative_pose, estimate_pose, dynamic_alpha
|
15 |
+
|
16 |
+
def make_matching_figure(
|
17 |
+
img0, img1, mkpts0, mkpts1, color,
|
18 |
+
kpts0=None, kpts1=None, text=[], dpi=75, path=None):
|
19 |
+
# draw image pair
|
20 |
+
assert mkpts0.shape[0] == mkpts1.shape[0], f'mkpts0: {mkpts0.shape[0]} v.s. mkpts1: {mkpts1.shape[0]}'
|
21 |
+
fig, axes = plt.subplots(1, 2, figsize=(10, 6), dpi=dpi)
|
22 |
+
axes[0].imshow(img0, cmap='gray')
|
23 |
+
axes[1].imshow(img1, cmap='gray')
|
24 |
+
for i in range(2): # clear all frames
|
25 |
+
axes[i].get_yaxis().set_ticks([])
|
26 |
+
axes[i].get_xaxis().set_ticks([])
|
27 |
+
for spine in axes[i].spines.values():
|
28 |
+
spine.set_visible(False)
|
29 |
+
plt.tight_layout(pad=1)
|
30 |
+
|
31 |
+
if kpts0 is not None:
|
32 |
+
assert kpts1 is not None
|
33 |
+
axes[0].scatter(kpts0[:, 0], kpts0[:, 1], c='w', s=2)
|
34 |
+
axes[1].scatter(kpts1[:, 0], kpts1[:, 1], c='w', s=2)
|
35 |
+
|
36 |
+
# draw matches
|
37 |
+
if mkpts0.shape[0] != 0 and mkpts1.shape[0] != 0:
|
38 |
+
fig.canvas.draw()
|
39 |
+
transFigure = fig.transFigure.inverted()
|
40 |
+
fkpts0 = transFigure.transform(axes[0].transData.transform(mkpts0))
|
41 |
+
fkpts1 = transFigure.transform(axes[1].transData.transform(mkpts1))
|
42 |
+
fig.lines = [matplotlib.lines.Line2D((fkpts0[i, 0], fkpts1[i, 0]),
|
43 |
+
(fkpts0[i, 1], fkpts1[i, 1]),
|
44 |
+
transform=fig.transFigure, c=color[i], linewidth=1)
|
45 |
+
for i in range(len(mkpts0))]
|
46 |
+
|
47 |
+
axes[0].scatter(mkpts0[:, 0], mkpts0[:, 1], c=color, s=4)
|
48 |
+
axes[1].scatter(mkpts1[:, 0], mkpts1[:, 1], c=color, s=4)
|
49 |
+
|
50 |
+
# put txts
|
51 |
+
txt_color = 'k' if img0[:100, :200].mean() > 200 else 'w'
|
52 |
+
fig.text(
|
53 |
+
0.01, 0.99, '\n'.join(text), transform=fig.axes[0].transAxes,
|
54 |
+
fontsize=15, va='top', ha='left', color=txt_color)
|
55 |
+
|
56 |
+
# save or return figure
|
57 |
+
if path:
|
58 |
+
plt.savefig(str(path), bbox_inches='tight', pad_inches=0)
|
59 |
+
plt.close()
|
60 |
+
else:
|
61 |
+
return fig
|
62 |
+
|
63 |
+
def error_colormap(err, thr, alpha=1.0):
|
64 |
+
assert alpha <= 1.0 and alpha > 0, f"Invalid alpha value: {alpha}"
|
65 |
+
x = 1 - np.clip(err / (thr * 2), 0, 1)
|
66 |
+
return np.clip(
|
67 |
+
np.stack([2-x*2, x*2, np.zeros_like(x), np.ones_like(x)*alpha], -1), 0, 1)
|
68 |
+
|
69 |
+
def _make_evaluation_figure(img0, img1, kpts0, kpts1, epi_errs, e_t, e_R, alpha='dynamic', path=None):
|
70 |
+
conf_thr = 1e-4
|
71 |
+
|
72 |
+
img0 = np.array(img0)
|
73 |
+
img1 = np.array(img1)
|
74 |
+
|
75 |
+
kpts0 = kpts0
|
76 |
+
kpts1 = kpts1
|
77 |
+
|
78 |
+
epi_errs = epi_errs.cpu().numpy()
|
79 |
+
correct_mask = epi_errs < conf_thr
|
80 |
+
precision = np.mean(correct_mask) if len(correct_mask) > 0 else 0
|
81 |
+
n_correct = np.sum(correct_mask)
|
82 |
+
|
83 |
+
# recall might be larger than 1, since the calculation of conf_matrix_gt
|
84 |
+
# uses groundtruth depths and camera poses, but epipolar distance is used here.
|
85 |
+
|
86 |
+
# matching info
|
87 |
+
if alpha == 'dynamic':
|
88 |
+
alpha = dynamic_alpha(len(correct_mask))
|
89 |
+
color = error_colormap(epi_errs, conf_thr, alpha=alpha)
|
90 |
+
|
91 |
+
text = [
|
92 |
+
f'#Matches {len(kpts0)}',
|
93 |
+
f'Precision({conf_thr:.2e}) ({100 * precision:.1f}%): {n_correct}/{len(kpts0)}',
|
94 |
+
f'e_t: {e_t:.2f} | e_R: {e_R:.2f}',
|
95 |
+
]
|
96 |
+
|
97 |
+
# make the figure
|
98 |
+
figure = make_matching_figure(img0, img1, kpts0, kpts1,
|
99 |
+
color, text=text, path=path)
|
100 |
+
return figure
|
101 |
+
|
102 |
+
class MegaDepthPoseMNNBenchmark:
|
103 |
+
def __init__(self, data_root="./megadepth_test_1500", scene_names = None) -> None:
|
104 |
+
if scene_names is None:
|
105 |
+
self.scene_names = [
|
106 |
+
"0015_0.1_0.3.npz",
|
107 |
+
"0015_0.3_0.5.npz",
|
108 |
+
"0022_0.1_0.3.npz",
|
109 |
+
"0022_0.3_0.5.npz",
|
110 |
+
"0022_0.5_0.7.npz",
|
111 |
+
]
|
112 |
+
|
113 |
+
else:
|
114 |
+
self.scene_names = scene_names
|
115 |
+
self.scenes = [
|
116 |
+
np.load(f"{data_root}/{scene}", allow_pickle=True)
|
117 |
+
for scene in self.scene_names
|
118 |
+
]
|
119 |
+
self.data_root = data_root
|
120 |
+
|
121 |
+
def benchmark(self, model_helper, model_name = None, scale_intrinsics = False, calibrated = True, plot_every_iter=1, plot=False, method='sparse'):
|
122 |
+
|
123 |
+
with torch.no_grad():
|
124 |
+
data_root = self.data_root
|
125 |
+
tot_e_t, tot_e_R, tot_e_pose = [], [], []
|
126 |
+
thresholds = [5, 10, 20]
|
127 |
+
for scene_ind in range(len(self.scenes)):
|
128 |
+
import os
|
129 |
+
scene_name = os.path.splitext(self.scene_names[scene_ind])[0]
|
130 |
+
print(f"Processing {scene_name}")
|
131 |
+
scene = self.scenes[scene_ind]
|
132 |
+
pairs = scene["pair_infos"]
|
133 |
+
intrinsics = scene["intrinsics"]
|
134 |
+
poses = scene["poses"]
|
135 |
+
im_paths = scene["image_paths"]
|
136 |
+
pair_inds = range(len(pairs))
|
137 |
+
for pairind in tqdm.tqdm(pair_inds):
|
138 |
+
idx0, idx1 = pairs[pairind][0]
|
139 |
+
K0 = intrinsics[idx0].copy()
|
140 |
+
T0 = poses[idx0].copy()
|
141 |
+
R0, t0 = T0[:3, :3], T0[:3, 3]
|
142 |
+
K1 = intrinsics[idx1].copy()
|
143 |
+
T1 = poses[idx1].copy()
|
144 |
+
R1, t1 = T1[:3, :3], T1[:3, 3]
|
145 |
+
R, t = compute_relative_pose(R0, t0, R1, t1)
|
146 |
+
T0_to_1 = np.concatenate((R,t[:,None]), axis=-1)
|
147 |
+
im_A_path = f"{data_root}/{im_paths[idx0]}"
|
148 |
+
im_B_path = f"{data_root}/{im_paths[idx1]}"
|
149 |
+
|
150 |
+
im_A = cv2.imread(im_A_path)
|
151 |
+
im_B = cv2.imread(im_B_path)
|
152 |
+
|
153 |
+
if method == 'dense':
|
154 |
+
kpts0, kpts1, conf = model_helper.match_dense(im_A, im_B, thr=0.01, resize=1600)
|
155 |
+
elif method == 'lightglue':
|
156 |
+
kpts0, kpts1, conf = model_helper.match_lg(im_A, im_B, thr=0.01, resize=1600)
|
157 |
+
elif method == 'sparse':
|
158 |
+
kpts0, kpts1, conf = model_helper.match(im_A, im_B, thr=0.01, resize=1600)
|
159 |
+
else:
|
160 |
+
kpts0, kpts1, conf = model_helper.match_3rd_party(im_A, im_B, thr=0.01, resize=1600, model=method)
|
161 |
+
|
162 |
+
im_A = Image.open(im_A_path)
|
163 |
+
w0, h0 = im_A.size
|
164 |
+
im_B = Image.open(im_B_path)
|
165 |
+
w1, h1 = im_B.size
|
166 |
+
if scale_intrinsics:
|
167 |
+
scale0 = 840 / max(w0, h0)
|
168 |
+
scale1 = 840 / max(w1, h1)
|
169 |
+
w0, h0 = scale0 * w0, scale0 * h0
|
170 |
+
w1, h1 = scale1 * w1, scale1 * h1
|
171 |
+
K0, K1 = K0.copy(), K1.copy()
|
172 |
+
K0[:2] = K0[:2] * scale0
|
173 |
+
K1[:2] = K1[:2] * scale1
|
174 |
+
|
175 |
+
|
176 |
+
threshold = 0.5
|
177 |
+
if calibrated:
|
178 |
+
norm_threshold = threshold / (np.mean(np.abs(K0[:2, :2])) + np.mean(np.abs(K1[:2, :2])))
|
179 |
+
ret = estimate_pose(
|
180 |
+
kpts0,
|
181 |
+
kpts1,
|
182 |
+
K0,
|
183 |
+
K1,
|
184 |
+
norm_threshold,
|
185 |
+
conf=0.99999,
|
186 |
+
)
|
187 |
+
if ret is not None:
|
188 |
+
R_est, t_est, mask = ret
|
189 |
+
T0_to_1_est = np.concatenate((R_est, t_est), axis=-1) #
|
190 |
+
T0_to_1 = np.concatenate((R, t[:,None]), axis=-1)
|
191 |
+
e_t, e_R = compute_pose_error(T0_to_1_est, R, t)
|
192 |
+
|
193 |
+
epi_errs = compute_symmetrical_epipolar_errors(T0_to_1, kpts0, kpts1, K0, K1)
|
194 |
+
if scene_ind % plot_every_iter == 0 and plot:
|
195 |
+
|
196 |
+
if not os.path.exists(f'outputs/mega_1500/{model_name}_{method}'):
|
197 |
+
os.mkdir(f'outputs/mega_1500/{model_name}_{method}')
|
198 |
+
name = f'outputs/mega_1500/{model_name}_{method}/{scene_name}_{pairind}.png'
|
199 |
+
_make_evaluation_figure(im_A, im_B, kpts0, kpts1, epi_errs, e_t, e_R, path=name)
|
200 |
+
e_pose = max(e_t, e_R)
|
201 |
+
|
202 |
+
tot_e_t.append(e_t)
|
203 |
+
tot_e_R.append(e_R)
|
204 |
+
tot_e_pose.append(e_pose)
|
205 |
+
|
206 |
+
tot_e_pose = np.array(tot_e_pose)
|
207 |
+
auc = pose_auc(tot_e_pose, thresholds)
|
208 |
+
acc_5 = (tot_e_pose < 5).mean()
|
209 |
+
acc_10 = (tot_e_pose < 10).mean()
|
210 |
+
acc_15 = (tot_e_pose < 15).mean()
|
211 |
+
acc_20 = (tot_e_pose < 20).mean()
|
212 |
+
map_5 = acc_5
|
213 |
+
map_10 = np.mean([acc_5, acc_10])
|
214 |
+
map_20 = np.mean([acc_5, acc_10, acc_15, acc_20])
|
215 |
+
print(f"{model_name} auc: {auc}")
|
216 |
+
return {
|
217 |
+
"auc_5": auc[0],
|
218 |
+
"auc_10": auc[1],
|
219 |
+
"auc_20": auc[2],
|
220 |
+
"map_5": map_5,
|
221 |
+
"map_10": map_10,
|
222 |
+
"map_20": map_20,
|
223 |
+
}
|
224 |
+
|
225 |
+
|
226 |
+
def parse_arguments():
|
227 |
+
parser = argparse.ArgumentParser(description="Testing script.")
|
228 |
+
|
229 |
+
parser.add_argument("--data_root", type=str, default="./data/megadepth_test_1500", help="Path to the MegaDepth dataset.")
|
230 |
+
|
231 |
+
parser.add_argument("--weights", type=str, default="./weights/RDD-v2.pth", help="Path to the model checkpoint.")
|
232 |
+
|
233 |
+
parser.add_argument("--plot", action="store_true", help="Whether to plot the results.")
|
234 |
+
|
235 |
+
parser.add_argument("--method", type=str, default="sparse", help="Method for matching.")
|
236 |
+
|
237 |
+
return parser.parse_args()
|
238 |
+
|
239 |
+
if __name__ == "__main__":
|
240 |
+
args = parse_arguments()
|
241 |
+
if not os.path.exists('outputs'):
|
242 |
+
os.mkdir('outputs')
|
243 |
+
|
244 |
+
if not os.path.exists(f'outputs/mega_1500'):
|
245 |
+
os.mkdir(f'outputs/mega_1500')
|
246 |
+
|
247 |
+
model = build(weights=args.weights)
|
248 |
+
benchmark = MegaDepthPoseMNNBenchmark(data_root=args.data_root)
|
249 |
+
model.eval()
|
250 |
+
model_helper = RDD_helper(model)
|
251 |
+
with torch.no_grad():
|
252 |
+
method = args.method
|
253 |
+
out = benchmark.benchmark(model_helper, model_name='RDD', plot_every_iter=1, plot=args.plot, method=method)
|
254 |
+
with open(f'outputs/mega_1500/RDD_{method}.txt', 'w') as f:
|
255 |
+
f.write(str(out))
|
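All three benchmarks build the ground-truth relative transform from two absolute world-to-camera poses via `compute_relative_pose(R0, t0, R1, t1)`. That helper lives in `benchmarks/utils.py`, which is not part of this diff, so the following is only a sketch of the standard composition it is expected to perform, together with a quick numerical self-check.

```python
import numpy as np

def relative_pose_sketch(R0, t0, R1, t1):
    # x_cam1 = R1 @ x_world + t1 and x_world = R0.T @ (x_cam0 - t0)
    # => R_0to1 = R1 @ R0.T, t_0to1 = t1 - R_0to1 @ t0
    R = R1 @ R0.T
    t = t1 - R @ t0
    return R, t

rng = np.random.default_rng(0)

def random_rotation():
    Q, _ = np.linalg.qr(rng.normal(size=(3, 3)))
    return Q * np.sign(np.linalg.det(Q))  # force det(+1)

R0, t0 = random_rotation(), rng.normal(size=3)
R1, t1 = random_rotation(), rng.normal(size=3)
R, t = relative_pose_sketch(R0, t0, R1, t1)

x_world = rng.normal(size=3)
assert np.allclose(R @ (R0 @ x_world + t0) + t, R1 @ x_world + t1)
```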
imcui/third_party/rdd/benchmarks/mega_view.py
ADDED
@@ -0,0 +1,250 @@
1 |
+
import sys
|
2 |
+
sys.path.append(".")
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
from PIL import Image
|
6 |
+
import tqdm
|
7 |
+
import cv2
|
8 |
+
import argparse
|
9 |
+
import matplotlib.pyplot as plt
|
10 |
+
import matplotlib
|
11 |
+
from RDD.RDD import build
|
12 |
+
from RDD.RDD_helper import RDD_helper
|
13 |
+
import os
|
14 |
+
from benchmarks.utils import pose_auc, angle_error_vec, angle_error_mat, symmetric_epipolar_distance, compute_symmetrical_epipolar_errors, compute_pose_error, compute_relative_pose, estimate_pose, dynamic_alpha
|
15 |
+
|
16 |
+
def make_matching_figure(
|
17 |
+
img0, img1, mkpts0, mkpts1, color,
|
18 |
+
kpts0=None, kpts1=None, text=[], dpi=75, path=None):
|
19 |
+
# draw image pair
|
20 |
+
assert mkpts0.shape[0] == mkpts1.shape[0], f'mkpts0: {mkpts0.shape[0]} v.s. mkpts1: {mkpts1.shape[0]}'
|
21 |
+
fig, axes = plt.subplots(1, 2, figsize=(10, 6), dpi=dpi)
|
22 |
+
axes[0].imshow(img0, cmap='gray')
|
23 |
+
axes[1].imshow(img1, cmap='gray')
|
24 |
+
for i in range(2): # clear all frames
|
25 |
+
axes[i].get_yaxis().set_ticks([])
|
26 |
+
axes[i].get_xaxis().set_ticks([])
|
27 |
+
for spine in axes[i].spines.values():
|
28 |
+
spine.set_visible(False)
|
29 |
+
plt.tight_layout(pad=1)
|
30 |
+
|
31 |
+
if kpts0 is not None:
|
32 |
+
assert kpts1 is not None
|
33 |
+
axes[0].scatter(kpts0[:, 0], kpts0[:, 1], c='w', s=2)
|
34 |
+
axes[1].scatter(kpts1[:, 0], kpts1[:, 1], c='w', s=2)
|
35 |
+
|
36 |
+
# draw matches
|
37 |
+
if mkpts0.shape[0] != 0 and mkpts1.shape[0] != 0:
|
38 |
+
fig.canvas.draw()
|
39 |
+
transFigure = fig.transFigure.inverted()
|
40 |
+
fkpts0 = transFigure.transform(axes[0].transData.transform(mkpts0))
|
41 |
+
fkpts1 = transFigure.transform(axes[1].transData.transform(mkpts1))
|
42 |
+
fig.lines = [matplotlib.lines.Line2D((fkpts0[i, 0], fkpts1[i, 0]),
|
43 |
+
(fkpts0[i, 1], fkpts1[i, 1]),
|
44 |
+
transform=fig.transFigure, c=color[i], linewidth=1)
|
45 |
+
for i in range(len(mkpts0))]
|
46 |
+
|
47 |
+
axes[0].scatter(mkpts0[:, 0], mkpts0[:, 1], c=color, s=4)
|
48 |
+
axes[1].scatter(mkpts1[:, 0], mkpts1[:, 1], c=color, s=4)
|
49 |
+
|
50 |
+
# put txts
|
51 |
+
txt_color = 'k' if img0[:100, :200].mean() > 200 else 'w'
|
52 |
+
fig.text(
|
53 |
+
0.01, 0.99, '\n'.join(text), transform=fig.axes[0].transAxes,
|
54 |
+
fontsize=15, va='top', ha='left', color=txt_color)
|
55 |
+
|
56 |
+
# save or return figure
|
57 |
+
if path:
|
58 |
+
plt.savefig(str(path), bbox_inches='tight', pad_inches=0)
|
59 |
+
plt.close()
|
60 |
+
else:
|
61 |
+
return fig
|
62 |
+
|
63 |
+
def error_colormap(err, thr, alpha=1.0):
|
64 |
+
assert alpha <= 1.0 and alpha > 0, f"Invalid alpha value: {alpha}"
|
65 |
+
x = 1 - np.clip(err / (thr * 2), 0, 1)
|
66 |
+
return np.clip(
|
67 |
+
np.stack([2-x*2, x*2, np.zeros_like(x), np.ones_like(x)*alpha], -1), 0, 1)
|
68 |
+
|
69 |
+
def _make_evaluation_figure(img0, img1, kpts0, kpts1, epi_errs, e_t, e_R, alpha='dynamic', path=None):
|
70 |
+
conf_thr = 1e-4
|
71 |
+
|
72 |
+
img0 = np.array(img0)
|
73 |
+
img1 = np.array(img1)
|
74 |
+
|
75 |
+
kpts0 = kpts0
|
76 |
+
kpts1 = kpts1
|
77 |
+
|
78 |
+
epi_errs = epi_errs.cpu().numpy()
|
79 |
+
correct_mask = epi_errs < conf_thr
|
80 |
+
precision = np.mean(correct_mask) if len(correct_mask) > 0 else 0
|
81 |
+
n_correct = np.sum(correct_mask)
|
82 |
+
|
83 |
+
# recall might be larger than 1, since the calculation of conf_matrix_gt
|
84 |
+
# uses groundtruth depths and camera poses, but epipolar distance is used here.
|
85 |
+
|
86 |
+
# matching info
|
87 |
+
if alpha == 'dynamic':
|
88 |
+
alpha = dynamic_alpha(len(correct_mask))
|
89 |
+
color = error_colormap(epi_errs, conf_thr, alpha=alpha)
|
90 |
+
|
91 |
+
text = [
|
92 |
+
f'#Matches {len(kpts0)}',
|
93 |
+
f'Precision({conf_thr:.2e}) ({100 * precision:.1f}%): {n_correct}/{len(kpts0)}',
|
94 |
+
f'e_t: {e_t:.2f} | e_R: {e_R:.2f}',
|
95 |
+
]
|
96 |
+
|
97 |
+
# make the figure
|
98 |
+
figure = make_matching_figure(img0, img1, kpts0, kpts1,
|
99 |
+
color, text=text, path=path)
|
100 |
+
return figure
|
101 |
+
|
102 |
+
class MegaDepthPoseMNNBenchmark:
|
103 |
+
def __init__(self, data_root="./megadepth_test_1500", scene_names = None) -> None:
|
104 |
+
if scene_names is None:
|
105 |
+
self.scene_names = [
|
106 |
+
"hard_indices.npz",
|
107 |
+
]
|
108 |
+
# self.scene_names = ["0022_0.5_0.7.npz",]
|
109 |
+
else:
|
110 |
+
self.scene_names = scene_names
|
111 |
+
self.scenes = [
|
112 |
+
np.load(f"{data_root}/{scene}", allow_pickle=True)
|
113 |
+
for scene in self.scene_names
|
114 |
+
]
|
115 |
+
self.data_root = data_root
|
116 |
+
|
117 |
+
def benchmark(self, model_helper, model_name = None, scale_intrinsics = False, calibrated = True, plot_every_iter=1, plot=False, method='sparse'):
|
118 |
+
with torch.no_grad():
|
119 |
+
data_root = self.data_root
|
120 |
+
tot_e_t, tot_e_R, tot_e_pose = [], [], []
|
121 |
+
thresholds = [5, 10, 20]
|
122 |
+
for scene_ind in range(len(self.scenes)):
|
123 |
+
scene_name = os.path.splitext(self.scene_names[scene_ind])[0]
|
124 |
+
scene = self.scenes[scene_ind]
|
125 |
+
indices = scene['indices']
|
126 |
+
idx = 0
|
127 |
+
|
128 |
+
for pair in tqdm.tqdm(indices):
|
129 |
+
|
130 |
+
pairs = pair['pair_names']
|
131 |
+
K0 = pair['intrisinic'][0].copy().astype(np.float32)
|
132 |
+
T0 = pair['pose'][0].copy().astype(np.float32)
|
133 |
+
R0, t0 = T0[:3, :3], T0[:3, 3]
|
134 |
+
K1 = pair['intrisinic'][1].copy().astype(np.float32)
|
135 |
+
T1 = pair['pose'][1].copy().astype(np.float32)
|
136 |
+
R1, t1 = T1[:3, :3], T1[:3, 3]
|
137 |
+
R, t = compute_relative_pose(R0, t0, R1, t1)
|
138 |
+
T0_to_1 = np.concatenate((R,t[:,None]), axis=-1)
|
139 |
+
im_A_path = f"{data_root}/images/{pairs[0]}"
|
140 |
+
im_B_path = f"{data_root}/images/{pairs[1]}"
|
141 |
+
|
142 |
+
im_A = cv2.imread(im_A_path)
|
143 |
+
im_B = cv2.imread(im_B_path)
|
144 |
+
|
145 |
+
if method == 'dense':
|
146 |
+
kpts0, kpts1, conf = model_helper.match_dense(im_A, im_B, thr=0.01, resize=1600)
|
147 |
+
elif method == 'lightglue':
|
148 |
+
kpts0, kpts1, conf = model_helper.match_lg(im_A, im_B, thr=0.01, resize=1600)
|
149 |
+
elif method == 'sparse':
|
150 |
+
kpts0, kpts1, conf = model_helper.match(im_A, im_B, thr=0.01, resize=1600)
|
151 |
+
else:
|
152 |
+
raise ValueError(f"Invalid method {method}")
|
153 |
+
|
154 |
+
im_A = Image.open(im_A_path)
|
155 |
+
w0, h0 = im_A.size
|
156 |
+
im_B = Image.open(im_B_path)
|
157 |
+
w1, h1 = im_B.size
|
158 |
+
|
159 |
+
if scale_intrinsics:
|
160 |
+
scale0 = 840 / max(w0, h0)
|
161 |
+
scale1 = 840 / max(w1, h1)
|
162 |
+
w0, h0 = scale0 * w0, scale0 * h0
|
163 |
+
w1, h1 = scale1 * w1, scale1 * h1
|
164 |
+
K0, K1 = K0.copy(), K1.copy()
|
165 |
+
K0[:2] = K0[:2] * scale0
|
166 |
+
K1[:2] = K1[:2] * scale1
|
167 |
+
|
168 |
+
threshold = 0.5
|
169 |
+
if calibrated:
|
170 |
+
norm_threshold = threshold / (np.mean(np.abs(K0[:2, :2])) + np.mean(np.abs(K1[:2, :2])))
|
171 |
+
ret = estimate_pose(
|
172 |
+
kpts0,
|
173 |
+
kpts1,
|
174 |
+
K0,
|
175 |
+
K1,
|
176 |
+
norm_threshold,
|
177 |
+
conf=0.99999,
|
178 |
+
)
|
179 |
+
if ret is not None:
|
180 |
+
R_est, t_est, mask = ret
|
181 |
+
T0_to_1_est = np.concatenate((R_est, t_est), axis=-1) #
|
182 |
+
T0_to_1 = np.concatenate((R, t[:,None]), axis=-1)
|
183 |
+
e_t, e_R = compute_pose_error(T0_to_1_est, R, t)
|
184 |
+
|
185 |
+
epi_errs = compute_symmetrical_epipolar_errors(T0_to_1, kpts0, kpts1, K0, K1)
|
186 |
+
if scene_ind % plot_every_iter == 0 and plot:
|
187 |
+
|
188 |
+
if not os.path.exists(f'outputs/mega_view/{model_name}_{method}'):
|
189 |
+
os.mkdir(f'outputs/mega_view/{model_name}_{method}')
|
190 |
+
name = f'outputs/mega_view/{model_name}_{method}/{scene_name}_{idx}.png'
|
191 |
+
_make_evaluation_figure(im_A, im_B, kpts0, kpts1, epi_errs, e_t, e_R, path=name)
|
192 |
+
e_pose = max(e_t, e_R)
|
193 |
+
|
194 |
+
tot_e_t.append(e_t)
|
195 |
+
tot_e_R.append(e_R)
|
196 |
+
tot_e_pose.append(e_pose)
|
197 |
+
idx += 1
|
198 |
+
|
199 |
+
tot_e_pose = np.array(tot_e_pose)
|
200 |
+
auc = pose_auc(tot_e_pose, thresholds)
|
201 |
+
acc_5 = (tot_e_pose < 5).mean()
|
202 |
+
acc_10 = (tot_e_pose < 10).mean()
|
203 |
+
acc_15 = (tot_e_pose < 15).mean()
|
204 |
+
acc_20 = (tot_e_pose < 20).mean()
|
205 |
+
map_5 = acc_5
|
206 |
+
map_10 = np.mean([acc_5, acc_10])
|
207 |
+
map_20 = np.mean([acc_5, acc_10, acc_15, acc_20])
|
208 |
+
print(f"{model_name} auc: {auc}")
|
209 |
+
return {
|
210 |
+
"auc_5": auc[0],
|
211 |
+
"auc_10": auc[1],
|
212 |
+
"auc_20": auc[2],
|
213 |
+
"map_5": map_5,
|
214 |
+
"map_10": map_10,
|
215 |
+
"map_20": map_20,
|
216 |
+
}
|
217 |
+
|
218 |
+
|
219 |
+
|
220 |
+
def parse_arguments():
|
221 |
+
parser = argparse.ArgumentParser(description="Testing script.")
|
222 |
+
|
223 |
+
parser.add_argument("--data_root", type=str, default="./data/megadepth_view", help="Path to the MegaDepth dataset.")
|
224 |
+
|
225 |
+
parser.add_argument("--weights", type=str, default="./weights/RDD-v2.pth", help="Path to the model checkpoint.")
|
226 |
+
|
227 |
+
parser.add_argument("--plot", action="store_true", help="Whether to plot the results.")
|
228 |
+
|
229 |
+
parser.add_argument("--method", type=str, default="sparse", help="Method for matching.")
|
230 |
+
|
231 |
+
return parser.parse_args()
|
232 |
+
|
233 |
+
if __name__ == "__main__":
|
234 |
+
args = parse_arguments()
|
235 |
+
if not os.path.exists('outputs'):
|
236 |
+
os.mkdir('outputs')
|
237 |
+
|
238 |
+
if not os.path.exists(f'outputs/mega_view'):
|
239 |
+
os.mkdir(f'outputs/mega_view')
|
240 |
+
model = build(weights=args.weights)
|
241 |
+
benchmark = MegaDepthPoseMNNBenchmark(data_root=args.data_root)
|
242 |
+
model.eval()
|
243 |
+
model_helper = RDD_helper(model)
|
244 |
+
with torch.no_grad():
|
245 |
+
method = args.method
|
246 |
+
out = benchmark.benchmark(model_helper, model_name='RDD', plot_every_iter=1, plot=args.plot, method=method)
|
247 |
+
with open(f'outputs/mega_view/RDD_{method}.txt', 'w') as f:
|
248 |
+
f.write(str(out))
|
249 |
+
|
250 |
+
|
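For reference, a minimal sketch of driving the three `--method` modes of this benchmark from Python rather than the shell; the flags and default paths mirror the argparse definitions above, and running from the rdd project root is an assumption here (illustrative only, not part of the diff).

import subprocess

# Run the MegaDepth-view benchmark once per matching mode exposed by RDD_helper
# (sparse MNN matching, dense matching, and LightGlue matching).
for method in ("sparse", "dense", "lightglue"):
    subprocess.run(
        [
            "python", "benchmarks/mega_view.py",
            "--data_root", "./data/megadepth_view",
            "--weights", "./weights/RDD-v2.pth",
            "--method", method,
        ],
        check=True,
    )
# Each run prints the pose AUCs and writes the metrics dict to
# outputs/mega_view/RDD_<method>.txt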
imcui/third_party/rdd/benchmarks/utils.py
ADDED
@@ -0,0 +1,112 @@
import bisect  # needed by dynamic_alpha below

import numpy as np
import torch
from kornia.geometry.epipolar import numeric
from kornia.geometry.conversions import convert_points_to_homogeneous
import cv2


def pose_auc(errors, thresholds):
    sort_idx = np.argsort(errors)
    errors = np.array(errors.copy())[sort_idx]
    recall = (np.arange(len(errors)) + 1) / len(errors)
    errors = np.r_[0.0, errors]
    recall = np.r_[0.0, recall]
    aucs = []
    for t in thresholds:
        last_index = np.searchsorted(errors, t)
        r = np.r_[recall[:last_index], recall[last_index - 1]]
        e = np.r_[errors[:last_index], t]
        aucs.append(np.trapz(r, x=e) / t)
    return aucs


def angle_error_vec(v1, v2):
    n = np.linalg.norm(v1) * np.linalg.norm(v2)
    return np.rad2deg(np.arccos(np.clip(np.dot(v1, v2) / n, -1.0, 1.0)))


def angle_error_mat(R1, R2):
    cos = (np.trace(np.dot(R1.T, R2)) - 1) / 2
    cos = np.clip(cos, -1.0, 1.0)  # numerical errors can make it out of bounds
    return np.rad2deg(np.abs(np.arccos(cos)))


def symmetric_epipolar_distance(pts0, pts1, E, K0, K1):
    """Squared symmetric epipolar distance.
    This can be seen as a biased estimation of the reprojection error.
    Args:
        pts0 (torch.Tensor): [N, 2]
        E (torch.Tensor): [3, 3]
    """
    pts0 = (pts0 - K0[[0, 1], [2, 2]][None]) / K0[[0, 1], [0, 1]][None]
    pts1 = (pts1 - K1[[0, 1], [2, 2]][None]) / K1[[0, 1], [0, 1]][None]
    pts0 = convert_points_to_homogeneous(pts0)
    pts1 = convert_points_to_homogeneous(pts1)

    Ep0 = pts0 @ E.T  # [N, 3]
    p1Ep0 = torch.sum(pts1 * Ep0, -1)  # [N,]
    Etp1 = pts1 @ E  # [N, 3]

    d = p1Ep0**2 * (1.0 / (Ep0[:, 0]**2 + Ep0[:, 1]**2) + 1.0 / (Etp1[:, 0]**2 + Etp1[:, 1]**2))  # N
    return d


def compute_symmetrical_epipolar_errors(T_0to1, pts0, pts1, K0, K1, device='cuda'):
    """
    Update:
        data (dict): {"epi_errs": [M]}
    """
    pts0 = torch.tensor(pts0, device=device)
    pts1 = torch.tensor(pts1, device=device)
    K0 = torch.tensor(K0, device=device)
    K1 = torch.tensor(K1, device=device)
    T_0to1 = torch.tensor(T_0to1, device=device)
    Tx = numeric.cross_product_matrix(T_0to1[:3, 3])
    E_mat = Tx @ T_0to1[:3, :3]

    epi_err = symmetric_epipolar_distance(pts0, pts1, E_mat, K0, K1)
    return epi_err


def compute_pose_error(T_0to1, R, t):
    R_gt = T_0to1[:3, :3]
    t_gt = T_0to1[:3, 3]
    error_t = angle_error_vec(t.squeeze(), t_gt)
    error_t = np.minimum(error_t, 180 - error_t)  # ambiguity of E estimation
    error_R = angle_error_mat(R, R_gt)
    return error_t, error_R


def compute_relative_pose(R1, t1, R2, t2):
    rots = R2 @ (R1.T)
    trans = -rots @ t1 + t2
    return rots, trans


def estimate_pose(kpts0, kpts1, K0, K1, norm_thresh, conf=0.99999):
    if len(kpts0) < 5:
        return None
    K0inv = np.linalg.inv(K0[:2, :2])
    K1inv = np.linalg.inv(K1[:2, :2])

    kpts0 = (K0inv @ (kpts0 - K0[None, :2, 2]).T).T
    kpts1 = (K1inv @ (kpts1 - K1[None, :2, 2]).T).T
    E, mask = cv2.findEssentialMat(
        kpts0, kpts1, np.eye(3), threshold=norm_thresh, prob=conf
    )

    ret = None
    if E is not None:
        best_num_inliers = 0

        for _E in np.split(E, len(E) / 3):
            n, R, t, _ = cv2.recoverPose(_E, kpts0, kpts1, np.eye(3), 1e9, mask=mask)
            if n > best_num_inliers:
                best_num_inliers = n
                ret = (R, t, mask.ravel() > 0)
    return ret


def dynamic_alpha(n_matches,
                  milestones=[0, 300, 1000, 2000],
                  alphas=[1.0, 0.8, 0.4, 0.2]):
    if n_matches == 0:
        return 1.0
    ranges = list(zip(alphas, alphas[1:] + [None]))
    loc = bisect.bisect_right(milestones, n_matches) - 1
    _range = ranges[loc]
    if _range[1] is None:
        return _range[0]
    return _range[1] + (milestones[loc + 1] - n_matches) / (
        milestones[loc + 1] - milestones[loc]) * (_range[0] - _range[1])
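As a quick sanity check of the helpers above, a toy sketch (illustrative values only; it assumes the functions above are importable, e.g. from `benchmarks.utils`): `compute_relative_pose` composes two absolute poses into the 0-to-1 relative pose, and `pose_auc` turns per-pair pose errors in degrees into the AUC values reported by the benchmarks.

import numpy as np

# Camera 0 at the origin, camera 1 rotated 10 degrees about z and shifted along x.
theta = np.deg2rad(10)
R0, t0 = np.eye(3), np.zeros(3)
R1 = np.array([[np.cos(theta), -np.sin(theta), 0.0],
               [np.sin(theta),  np.cos(theta), 0.0],
               [0.0, 0.0, 1.0]])
t1 = np.array([1.0, 0.0, 0.0])

R, t = compute_relative_pose(R0, t0, R1, t1)
# With R0 = I the relative rotation is just R1, so the rotation error is ~0 degrees.
assert angle_error_mat(R, R1) < 1e-6

# Pose AUC over toy per-pair errors (each error is max(e_t, e_R) in degrees).
errors = np.array([1.0, 3.0, 7.0, 12.0, 45.0])
aucs = pose_auc(errors, [5, 10, 20])  # three values in [0, 1], one per threshold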
imcui/third_party/rdd/configs/default.yaml
ADDED
@@ -0,0 +1,19 @@
activation: relu
block_dims:
- 8
- 16
- 32
- 64
d_model: 256
detection_threshold: 0.1
device: cuda
dim_feedforward: 1024
dropout: 0.1
enc_n_points: 8
hidden_dim: 256
lr_backbone: 2.0e-05
nhead: 8
num_encoder_layers: 4
num_feature_levels: 5
top_k: 4096
train_detector: False
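Finally, a minimal sketch of consuming a config like `configs/default.yaml` with standard PyYAML; how RDD's builder actually ingests these keys is not shown in this diff, so the loading below is illustrative only.

import yaml

with open("configs/default.yaml") as f:
    cfg = yaml.safe_load(f)

# A few fields, as parsed by PyYAML: ints, floats and booleans come back typed.
print(cfg["d_model"], cfg["num_encoder_layers"], cfg["top_k"])  # 256 4 4096
assert cfg["lr_backbone"] == 2e-05 and cfg["train_detector"] is False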